from typing import Optional

import numpy as np
from torch.utils.data import ConcatDataset, Dataset, WeightedRandomSampler

from .builder import DATASETS, build_dataset


@DATASETS.register_module()
class MixedDataset(Dataset):
    """Mixed Dataset.

    Args:
        configs (list): Configs of the datasets to be mixed.
        partition (list): Expected proportion of each dataset in a batch.
        num_data (int | None, optional): If not None, the length of the
            mixed dataset is set to this fixed value. Otherwise, it is set
            to the size of the largest single dataset. Default: None.
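
    Example:
        A minimal usage sketch; ``cfg_a`` and ``cfg_b`` are placeholders for
        any two dataset configs accepted by ``build_dataset``::

            >>> cfg = dict(
            ...     type='MixedDataset',
            ...     configs=[cfg_a, cfg_b],
            ...     partition=[0.7, 0.3],  # 70% / 30% of samples per batch
            ... )
            >>> dataset = build_dataset(cfg)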
    """
    def __init__(self,
                 configs: list,
                 partition: list,
                 num_data: Optional[int] = None):
        """Load data from multiple datasets."""
        assert len(configs) == len(partition)
        assert min(partition) >= 0
        datasets = [build_dataset(cfg) for cfg in configs]
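        # Concatenate into a single dataset; the sampler weights below index
        # into this flat, concatenated index space.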
        self.dataset = ConcatDataset(datasets)
        if num_data is not None:
            self.length = num_data
        else:
            self.length = max(len(ds) for ds in datasets)
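        # Weight each sample by p / len(ds) so that every dataset's total
        # sampling mass equals its partition ratio, regardless of its size.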
        weights = [
            np.ones(len(ds)) * p / len(ds)
            for (p, ds) in zip(partition, datasets)
        ]
        weights = np.concatenate(weights, axis=0)
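        # Draw one weighted index per call; `__getitem__` uses this sampler
        # instead of the incoming index to realize the mixing ratio.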
        self.sampler = WeightedRandomSampler(weights, 1)

    def __len__(self):
        """Get the size of the dataset."""
        return self.length

    def __getitem__(self, idx):
        """Sample one item from the mixed datasets with the given
        proportions.

        Note that ``idx`` is ignored: each call draws a fresh index from
        the weighted sampler, so the mixing ratio holds in expectation
        over a batch rather than deterministically per index.
        """
        idx_new = list(self.sampler)[0]
        return self.dataset[idx_new]