File size: 1,438 Bytes
d7e58f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import numpy as np
from torch.utils.data import Dataset

from .builder import DATASETS, build_dataset


@DATASETS.register_module()
class AdversarialDataset(Dataset):
    """Mix dataset for adversarial training in 3D human mesh estimation.

    Each item pairs one sample of the training dataset (selected by index)
    with one sample of the adversarial dataset (drawn uniformly at random,
    independently of the index). Both are returned in a single dict; keys
    that come from the adversarial sample are prefixed with ``'adv_'``.

    Args:
        train_dataset (:obj:`Dataset`): Dataset for 3D human mesh
            estimation. It is passed through ``build_dataset``.
        adv_dataset (:obj:`Dataset`): Dataset for adversarial learning.
            It is passed through ``build_dataset``.

    Raises:
        ValueError: If the built adversarial dataset is empty, because no
            adversarial sample could ever be drawn from it.
    """

    def __init__(self, train_dataset: Dataset, adv_dataset: Dataset):
        super().__init__()
        # NOTE(review): both arguments go through ``build_dataset``, so they
        # are presumably config dicts rather than built datasets -- confirm
        # against the builder and consider updating the annotations.
        self.train_dataset = build_dataset(train_dataset)
        self.adv_dataset = build_dataset(adv_dataset)
        self.num_train_data = len(self.train_dataset)
        self.num_adv_data = len(self.adv_dataset)
        # Fail fast here; otherwise every __getitem__ call would raise an
        # opaque ``ValueError: high <= low`` from ``np.random.randint``.
        if self.num_adv_data == 0:
            raise ValueError(
                'AdversarialDataset requires a non-empty adv_dataset.')

    def __len__(self):
        """Return the number of training samples.

        The length is driven by the training dataset only; the adversarial
        dataset is sampled with replacement on every access.
        """
        return self.num_train_data

    def __getitem__(self, idx: int):
        """Get training sample ``idx`` merged with a random adversarial one.

        Args:
            idx (int): Index into the training dataset.

        Returns:
            dict: The training sample, updated in place with every entry of
            a randomly drawn adversarial sample under the key
            ``'adv_' + key``. An existing training key of the same name is
            overwritten.
        """
        data = self.train_dataset[idx]
        # Uniform draw over [0, num_adv_data); ``high`` is exclusive.
        # NOTE(review): this uses numpy's global RNG. With forked DataLoader
        # workers, each worker should be seeded differently (e.g. via
        # ``worker_init_fn``) to avoid identical draws -- confirm.
        adv_idx = np.random.randint(low=0, high=self.num_adv_data, dtype=int)
        adv_data = self.adv_dataset[adv_idx]
        for key, value in adv_data.items():
            data['adv_' + key] = value
        return data