Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,506 Bytes
c968fc3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import random
from torch.utils.data import ConcatDataset, Dataset
from torch.utils.data.sampler import (
BatchSampler,
RandomSampler,
Sampler,
SequentialSampler,
)
class ScheduledSampler(Sampler):
    """Yield indices of a ConcatDataset grouped into single-dataset batches.

    Every sub-dataset of ``concat_dataset`` is cut into chunks of
    ``batch_size`` indices (offset into the concatenated index space), the
    chunks of all sub-datasets are shuffled with ``random.shuffle`` and then
    flattened into one index stream.

    Args:
        concat_dataset (ConcatDataset): a concatenated dataset consisting of
            all datasets.
        batch_size (int): positive number of indices per chunk.
        holistic_shuffle (bool): when True the indices inside each sub-dataset
            are visited with ``SequentialSampler``; when False they are drawn
            with ``RandomSampler``. The chunks themselves are shuffled in
            both cases.
        logger (logging.Logger, optional): logger used for the "dataset is
            smaller than the batch size" warning. Falls back to
            ``logging.getLogger(__name__)`` when omitted.
        type (str): ``"valid"`` keeps the trailing incomplete chunk of every
            sub-dataset and silences the warning; any other value drops it.

    Raises:
        ValueError: if an argument has the wrong type or ``batch_size`` is
            not positive.

    Usage (illustrative; chunk order is random):
        With batch_size = 3, holistic_shuffle = True and sub-datasets
        [0, 1, 2], [3, 4, 5], [6, 7, 8], one possible index stream is
        [3, 4, 5, 0, 1, 2, 6, 7, 8].
    """

    def __init__(
        self, concat_dataset, batch_size, holistic_shuffle, logger=None, type="train"
    ):
        # The ``type`` parameter shadows the builtin inside this method, so the
        # error messages read the class through ``__class__`` instead of
        # calling ``type(...)`` (which would be a call on a str).
        if not isinstance(concat_dataset, ConcatDataset):
            raise ValueError(
                "concat_dataset must be an instance of ConcatDataset, but got {}".format(
                    concat_dataset.__class__
                )
            )
        if not isinstance(batch_size, int):
            raise ValueError(
                "batch_size must be an integer, but got {}".format(
                    batch_size.__class__
                )
            )
        if batch_size <= 0:
            # A zero batch size would otherwise surface later as a
            # ZeroDivisionError in __len__.
            raise ValueError(
                "batch_size must be a positive integer, but got {}".format(batch_size)
            )
        if not isinstance(holistic_shuffle, bool):
            raise ValueError(
                "holistic_shuffle must be a boolean, but got {}".format(
                    holistic_shuffle.__class__
                )
            )

        self.concat_dataset = concat_dataset
        self.batch_size = batch_size
        self.holistic_shuffle = holistic_shuffle
        self.type = type

        # Validation keeps incomplete chunks, so short datasets are only a
        # concern (and only reported) for the other split types.
        if type == "valid":
            return

        if logger is None:
            import logging

            logger = logging.getLogger(__name__)

        for dataset in concat_dataset.datasets:
            dataset_len = len(dataset)
            if dataset_len >= batch_size:
                continue
            logger.warning(
                "The {} dataset {} has a length of {}, which is smaller than the batch size {}. This may cause unexpected behavior.".format(
                    type, dataset.get_dataset_name(), dataset_len, batch_size
                )
            )

    def __len__(self):
        """Return the number of indices produced by one pass of ``__iter__``."""
        if self.type == "valid":
            # Nothing is dropped for validation: every sample is emitted once.
            return len(self.concat_dataset)
        # Drop-last per sub-dataset: only complete chunks are counted.
        full_chunks = sum(
            len(dataset) // self.batch_size
            for dataset in self.concat_dataset.datasets
        )
        return full_chunks * self.batch_size

    def __iter__(self):
        """Return an iterator over the shuffled, chunk-wise index stream."""
        # NOTE(review): True -> sequential / False -> random looks inverted
        # with respect to the parameter name; kept as-is pending confirmation.
        sampler_cls = SequentialSampler if self.holistic_shuffle else RandomSampler
        keep_remainder = self.type == "valid"

        # Offset of each sub-dataset inside the concatenated index space.
        offsets = [0] + self.concat_dataset.cumulative_sizes[:-1]

        output_batches = []
        for dataset, offset in zip(self.concat_dataset.datasets, offsets):
            cur_batch = []
            for idx in sampler_cls(dataset):
                cur_batch.append(idx + offset)
                if len(cur_batch) == self.batch_size:
                    output_batches.append(cur_batch)
                    cur_batch = []
            # The incomplete trailing chunk is dropped unless validating.
            if keep_remainder and cur_batch:
                output_batches.append(cur_batch)

        random.shuffle(output_batches)
        return iter([index for batch in output_batches for index in batch])
def build_samplers(concat_dataset: Dataset, cfg, logger, type):
    """Create the index sampler and the batch sampler wrapping it.

    Args:
        concat_dataset: dataset handed to ``ScheduledSampler``.
        cfg: config object; ``cfg.train.batch_size``,
            ``cfg.train.sampler.holistic_shuffle`` and (unless ``type`` is
            ``"valid"``) ``cfg.train.sampler.drop_last`` are read.
        logger: logger forwarded to ``ScheduledSampler``.
        type: split name; ``"valid"`` forces ``drop_last=False`` on the
            batch sampler.

    Returns:
        tuple: ``(sampler, batch_sampler)``.
    """
    train_cfg = cfg.train
    batch_size = train_cfg.batch_size

    sampler = ScheduledSampler(
        concat_dataset,
        batch_size,
        train_cfg.sampler.holistic_shuffle,
        logger,
        type,
    )

    # Validation never drops the last batch; other splits follow the config.
    if type == "valid":
        drop_last = False
    else:
        drop_last = train_cfg.sampler.drop_last

    batch_sampler = BatchSampler(sampler, batch_size, drop_last)
    return sampler, batch_sampler
|