Spaces:
Sleeping
Sleeping
File size: 1,438 Bytes
d7e58f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import numpy as np
from torch.utils.data import Dataset
from .builder import DATASETS, build_dataset
@DATASETS.register_module()
class AdversarialDataset(Dataset):
    """Mix Dataset for the adversarial training in 3D human mesh estimation
    task.

    The dataset combines data from two datasets and returns a dict
    containing data from both. Its length is the length of the training
    dataset. Each training item is paired with one item drawn uniformly at
    random from the adversarial dataset, and the adversarial item's keys
    are copied into the result with an ``adv_`` prefix.

    Args:
        train_dataset (:obj:`Dataset`): Dataset for 3D human mesh estimation.
        adv_dataset (:obj:`Dataset`): Dataset for adversarial learning.

    Raises:
        ValueError: If the built adversarial dataset is empty, since no
            adversarial sample could ever be drawn from it.
    """

    def __init__(self, train_dataset: Dataset, adv_dataset: Dataset):
        super().__init__()
        # NOTE(review): both arguments are passed through ``build_dataset``,
        # so they are presumably dataset configs rather than already-built
        # ``Dataset`` objects -- confirm against the builder.
        self.train_dataset = build_dataset(train_dataset)
        self.adv_dataset = build_dataset(adv_dataset)
        self.num_train_data = len(self.train_dataset)
        self.num_adv_data = len(self.adv_dataset)
        # Fail fast: with an empty adversarial dataset every __getitem__ call
        # would otherwise die inside np.random.randint(low=0, high=0) with an
        # opaque "high <= low" error.
        if self.num_adv_data == 0:
            raise ValueError(
                'AdversarialDataset requires a non-empty adv_dataset.')

    def __len__(self):
        """Get the size of the dataset, i.e. the size of the train dataset."""
        return self.num_train_data

    def __getitem__(self, idx: int):
        """Given index, get the data from train dataset and randomly sample an
        item from adversarial dataset.

        The adversarial index is drawn from NumPy's global random state, so
        reproducibility depends on how ``np.random`` is seeded (including in
        data-loader worker processes).

        Args:
            idx (int): Index into the train dataset.

        Returns:
            The sample returned by the train dataset, updated in place with
            every key ``k`` of the adversarial sample stored as ``'adv_' + k``.
        """
        data = self.train_dataset[idx]
        # Uniform draw over [0, num_adv_data); independent of ``idx``.
        adv_idx = np.random.randint(low=0, high=self.num_adv_data, dtype=int)
        adv_data = self.adv_dataset[adv_idx]
        for k, v in adv_data.items():
            data['adv_' + k] = v
        return data
|