import random

import numpy as np
from torch.utils.data.dataset import Dataset


class MultipleDatasets(Dataset):
    """Samples from several datasets according to a sampling partition.

    `partition` maps each dataset's class name to a cumulative sampling
    probability; entries are sorted ascending and the largest value is
    expected to be 1.0 so that every draw falls into some dataset.
    """

    def __init__(self, dbs, partition, make_same_len=True, total_len=None, verbose=False):
        self.dbs = dbs
        self.db_num = len(self.dbs)
        self.max_db_data_num = max([len(db) for db in dbs])
        self.db_len_cumsum = np.cumsum([len(db) for db in dbs])
        self.make_same_len = make_same_len
        # Sort the partition by cumulative probability so that __getitem__
        # can scan it in ascending order.
        self.partition = {k: v for k, v in sorted(partition.items(), key=lambda item: item[1])}
        self.dataset = {db.__class__.__name__: db for db in dbs}

        if verbose:
            print('datasets:', [len(self.dbs[i]) for i in range(self.db_num)])
            print(f'Sample Ratio: {self.partition}')

    def __len__(self):
        return self.max_db_data_num

    def __getitem__(self, index):
        # Draw p in [0, 1) and return an item from the first dataset whose
        # cumulative probability covers it; the index wraps around shorter
        # datasets via the modulo.
        p = np.random.rand()
        keys = list(self.partition.keys())
        for key, cum_prob in self.partition.items():
            if p <= cum_prob:
                db = self.dataset[key]
                return db[index % len(db)]
        # Fallback in case the largest cumulative probability falls short of
        # 1.0 (e.g. floating-point rounding); the original code implicitly
        # returned None here.
        db = self.dataset[keys[-1]]
        return db[index % len(db)]


class MultipleDatasets_debug(Dataset):
    """Combines several datasets, either equalizing their lengths
    (`make_same_len=True`) or simply concatenating them."""

    def __init__(self, dbs, make_same_len=True, total_len=None, verbose=False):
        self.dbs = dbs
        self.db_num = len(self.dbs)
        self.max_db_data_num = max([len(db) for db in dbs])
        self.db_len_cumsum = np.cumsum([len(db) for db in dbs])
        self.make_same_len = make_same_len

        if total_len == 'auto':
            # Use the summed length of all datasets.
            self.total_len = self.db_len_cumsum[-1]
            self.auto_total_len = True
        else:
            self.total_len = total_len
            self.auto_total_len = False
        if total_len is not None:
            # Each dataset gets an equal share of the fixed total length.
            self.per_db_len = self.total_len // self.db_num

        if verbose:
            print('datasets:', [len(self.dbs[i]) for i in range(self.db_num)])
            print(f'Auto total length: {self.auto_total_len}, {self.total_len}')

    def __len__(self):
        if self.make_same_len:
            if self.total_len is None:
                # Every dataset is stretched to the longest one's length.
                return self.max_db_data_num * self.db_num
            else:
                # Each dataset has the same length and the total is fixed.
                return self.total_len
        else:
            # Each dataset keeps its own length; simply concatenate.
            return sum([len(db) for db in self.dbs])

    def __getitem__(self, index):
        if self.make_same_len:
            if self.total_len is None:
                # Match the longest dataset's length.
                db_idx = index // self.max_db_data_num
                data_idx = index % self.max_db_data_num
                if data_idx >= len(self.dbs[db_idx]) * (self.max_db_data_num // len(self.dbs[db_idx])):
                    # Trailing remainder that does not divide evenly:
                    # sample randomly within this dataset.
                    data_idx = random.randint(0, len(self.dbs[db_idx]) - 1)
                else:
                    # Otherwise wrap around with the modulo.
                    data_idx = data_idx % len(self.dbs[db_idx])
            else:
                db_idx = index // self.per_db_len
                data_idx = index % self.per_db_len
                if db_idx > (self.db_num - 1):
                    # Leftover indices when total_len is not divisible by
                    # db_num: pick a dataset at random.
                    db_idx = random.randint(0, self.db_num - 1)
                if len(self.dbs[db_idx]) < self.per_db_len and \
                        data_idx >= len(self.dbs[db_idx]) * (self.per_db_len // len(self.dbs[db_idx])):
                    # Trailing remainder that does not divide evenly:
                    # sample randomly within this dataset.
                    data_idx = random.randint(0, len(self.dbs[db_idx]) - 1)
                else:
                    # Otherwise wrap around with the modulo.
                    data_idx = data_idx % len(self.dbs[db_idx])
        else:
            # Concatenation mode: find the dataset whose cumulative length
            # range contains the index.
            for i in range(self.db_num):
                if index < self.db_len_cumsum[i]:
                    db_idx = i
                    break
            if db_idx == 0:
                data_idx = index
            else:
                data_idx = index - self.db_len_cumsum[db_idx - 1]

        return self.dbs[db_idx][data_idx]
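

# --- Usage sketch (assumptions, not part of the original module) ------------
# A minimal sketch of how these classes might be driven. The dummy dataset
# classes below are hypothetical stand-ins; real use would pass actual torch
# Dataset instances. For MultipleDatasets, the partition values are assumed
# to be cumulative probabilities whose largest entry is 1.0.
if __name__ == '__main__':
    class DatasetA(Dataset):
        def __len__(self):
            return 10

        def __getitem__(self, index):
            return 'A', index

    class DatasetB(Dataset):
        def __len__(self):
            return 4

        def __getitem__(self, index):
            return 'B', index

    dbs = [DatasetA(), DatasetB()]

    # DatasetA is drawn with probability 0.7, DatasetB with the remaining
    # 0.3 (cumulative boundaries 0.7 and 1.0).
    mixed = MultipleDatasets(dbs, partition={'DatasetA': 0.7, 'DatasetB': 1.0}, verbose=True)
    print(len(mixed), mixed[0])

    # Debug variant: 'auto' sets total_len to the summed dataset lengths,
    # giving each dataset an equal share of indices.
    debug = MultipleDatasets_debug(dbs, total_len='auto', verbose=True)
    print(len(debug), debug[0], debug[len(debug) - 1])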