AiOS / detrsmpl /data /datasets /mixed_dataset.py
ttxskk
update
d7e58f0
raw
history blame
1.67 kB
from typing import Optional, Union
import numpy as np
from torch.utils.data import ConcatDataset, Dataset, WeightedRandomSampler
from .builder import DATASETS, build_dataset
@DATASETS.register_module()
class MixedDataset(Dataset):
    """Mixed Dataset.

    Builds several datasets, concatenates them, and serves samples drawn at
    random so that each constituent dataset is picked with a probability
    proportional to its entry in ``partition``, independent of its size.

    Args:
        configs (list): the list of configs of the different datasets. Each
            entry is passed to ``build_dataset``.
        partition (list): the (non-negative) sampling ratio of each dataset.
            Must have one entry per config; the values do not need to sum
            to one.
        num_data (int | None, optional): if num_data is not None, the
            reported length of this dataset is set to this fixed value.
            Otherwise, the length is set to the maximum size of each single
            dataset. Default: None.

    Raises:
        ValueError: if ``configs`` is empty, if ``partition`` and
            ``configs`` differ in length, if any ratio is negative, or if
            no ratio is positive.
    """

    def __init__(self,
                 configs: list,
                 partition: list,
                 num_data: Optional[int] = None):
        """Load data from multiple datasets."""
        # Validate with real exceptions: ``assert`` is stripped under -O.
        if not configs:
            raise ValueError('configs must contain at least one dataset.')
        if len(partition) != len(configs):
            # zip() below would silently truncate and the weight vector
            # would no longer line up with the concatenated dataset.
            raise ValueError(
                f'partition has {len(partition)} entries but configs has '
                f'{len(configs)}; they must have the same length.')
        if min(partition) < 0:
            raise ValueError('partition must not contain negative values.')
        if max(partition) <= 0:
            # An all-zero weight vector cannot be sampled from.
            raise ValueError(
                'partition must contain at least one positive value.')

        datasets = [build_dataset(cfg) for cfg in configs]
        self.dataset = ConcatDataset(datasets)

        if num_data is not None:
            self.length = num_data
        else:
            self.length = max(len(ds) for ds in datasets)

        # Every sample of dataset i gets weight p_i / len(ds_i), so the total
        # weight of dataset i is p_i regardless of how many samples it has.
        weights = [
            np.ones(len(ds)) * p / len(ds)
            for (p, ds) in zip(partition, datasets)
        ]
        weights = np.concatenate(weights, axis=0)
        # The sampler yields a single index per iteration pass.
        self.sampler = WeightedRandomSampler(weights, 1)

    def __len__(self):
        """Get the size of the dataset."""
        return self.length

    def __getitem__(self, idx):
        """Sample the data from multiple datasets with the given proportion.

        Note:
            ``idx`` is ignored: every call draws a fresh random index from
            the weighted sampler, so repeated calls with the same ``idx``
            may return different samples.
        """
        # Take the single drawn index without materializing a list.
        idx_new = next(iter(self.sampler))
        return self.dataset[idx_new]