import random
import numpy as np
from torch.utils.data.dataset import Dataset
from config.config import cfg

class MultipleDatasets(Dataset):
    """Wrap several datasets and, for every index, draw the sample from one of
    them at random according to the probabilities given in `partition`."""

    def __init__(self,
                 dbs,
                 partition,
                 make_same_len=True,
                 total_len=None,
                 verbose=False):
        self.dbs = dbs
        self.db_num = len(self.dbs)
        self.max_db_data_num = max([len(db) for db in dbs])
        self.db_len_cumsum = np.cumsum([len(db) for db in dbs])
        self.make_same_len = make_same_len
        # Sort by value so the entries form ascending cumulative sampling thresholds.
        self.partition = {k: v for k, v in sorted(partition.items(), key=lambda item: item[1])}
        # Look up each dataset by its class name, matching the keys of `partition`.
        self.dataset = {db.__class__.__name__: db for db in dbs}

        if verbose:
            print('datasets:', [len(self.dbs[i]) for i in range(self.db_num)])
            print(f'Sample Ratio: {self.partition}')

    def __len__(self):
        # The epoch length is tied to the largest constituent dataset.
        return self.max_db_data_num

    def __getitem__(self, index):
        # Pick a dataset according to the cumulative thresholds in `partition`,
        # then wrap `index` into that dataset's own range.
        p = np.random.rand()
        for name, threshold in self.partition.items():
            if p <= threshold:
                db = self.dataset[name]
                return db[index % len(db)]
        # Defensive fallback: if the thresholds do not reach 1.0 (e.g. rounding),
        # use the last dataset instead of silently returning None.
        db = self.dataset[list(self.partition)[-1]]
        return db[index % len(db)]
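
# A minimal usage sketch (hypothetical dataset classes and ratios, not taken from
# this repository). `partition` maps each dataset's class name to an ascending
# cumulative sampling probability, with the last entry expected to be 1.0:
#
#   from torch.utils.data import DataLoader
#   dbs = [Human36M('train'), MSCOCO('train')]        # hypothetical datasets
#   partition = {'Human36M': 0.6, 'MSCOCO': 1.0}      # 60% / 40% sampling split
#   trainset = MultipleDatasets(dbs, partition, verbose=True)
#   loader = DataLoader(trainset, batch_size=32, shuffle=True)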


class MultipleDatasets_debug(Dataset):
    """Combine several datasets, either equalising how often each one is sampled
    (`make_same_len=True`) or simply concatenating them."""

    def __init__(self, dbs, make_same_len=True, total_len=None, verbose=False):
        self.dbs = dbs
        self.db_num = len(self.dbs)
        self.max_db_data_num = max([len(db) for db in dbs])
        self.db_len_cumsum = np.cumsum([len(db) for db in dbs])
        self.make_same_len = make_same_len

        if total_len == 'auto':
            # 'auto': use the combined length of all datasets as the total.
            self.total_len = self.db_len_cumsum[-1]
            self.auto_total_len = True
        else:
            self.total_len = total_len
            self.auto_total_len = False

        if total_len is not None:
            # With a fixed total length, every dataset contributes an equal share.
            self.per_db_len = self.total_len // self.db_num
        if verbose:
            print('datasets:', [len(self.dbs[i]) for i in range(self.db_num)])
            print(f'Auto total length: {self.auto_total_len}, {self.total_len}')

    def __len__(self):
        # all dbs have the same length
        if self.make_same_len:
            if self.total_len is None:
                # match the longest length
                return self.max_db_data_num * self.db_num
            else:
                # each dataset has the same length and total len is fixed
                return self.total_len
        else:
            # each db has different length, simply concat
            return sum([len(db) for db in self.dbs])

    def __getitem__(self, index):
        if self.make_same_len:
            if self.total_len is None:
                # match the longest length
                db_idx = index // self.max_db_data_num
                data_idx = index % self.max_db_data_num
                if data_idx >= len(self.dbs[db_idx]) * (
                        self.max_db_data_num //
                        len(self.dbs[db_idx])):  # last batch: random sampling
                    data_idx = random.randint(0, len(self.dbs[db_idx]) - 1)
                else:  # before last batch: use modular
                    data_idx = data_idx % len(self.dbs[db_idx])
            else:
                db_idx = index // self.per_db_len
                data_idx = index % self.per_db_len
                if db_idx > (self.db_num - 1):
                    # last batch: randomly choose one dataset
                    db_idx = random.randint(0, self.db_num - 1)

                if len(self.dbs[db_idx]) < self.per_db_len and \
                        data_idx >= len(self.dbs[db_idx]) * (self.per_db_len // len(self.dbs[db_idx])):
                    # last batch: random sampling in this dataset
                    data_idx = random.randint(0, len(self.dbs[db_idx]) - 1)
                else:
                    # before last batch: use modular
                    data_idx = data_idx % len(self.dbs[db_idx])

        else:
            for i in range(self.db_num):
                if index < self.db_len_cumsum[i]:
                    db_idx = i
                    break
            if db_idx == 0:
                data_idx = index
            else:
                data_idx = index - self.db_len_cumsum[db_idx - 1]

        return self.dbs[db_idx][data_idx]
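

if __name__ == '__main__':
    # Quick self-check with tiny synthetic datasets. The sizes, total_len and
    # the expected printed values below are arbitrary assumptions for illustration.
    import torch
    from torch.utils.data import TensorDataset

    small = TensorDataset(torch.arange(10))
    large = TensorDataset(torch.arange(100))

    # Equalised sampling with a fixed overall length: each dataset covers
    # total_len // db_num = 20 of the 40 indices.
    combined = MultipleDatasets_debug([small, large], make_same_len=True,
                                      total_len=40, verbose=True)
    print(len(combined))                  # 40
    print(combined[0], combined[39])      # one item from each dataset

    # Plain concatenation: lengths simply add up.
    concat = MultipleDatasets_debug([small, large], make_same_len=False)
    print(len(concat))                    # 110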