AiOS / detrsmpl /data /datasets /adversarial_dataset.py
ttxskk
update
d7e58f0
raw
history blame
1.44 kB
import numpy as np
from torch.utils.data import Dataset
from .builder import DATASETS, build_dataset
@DATASETS.register_module()
class AdversarialDataset(Dataset):
"""Mix Dataset for the adversarial training in 3D human mesh estimation
task.
The dataset combines data from two datasets and
return a dict containing data from two datasets.
Args:
train_dataset (:obj:`Dataset`): Dataset for 3D human mesh estimation.
adv_dataset (:obj:`Dataset`): Dataset for adversarial learning.
"""
def __init__(self, train_dataset: Dataset, adv_dataset: Dataset):
super().__init__()
self.train_dataset = build_dataset(train_dataset)
self.adv_dataset = build_dataset(adv_dataset)
self.num_train_data = len(self.train_dataset)
self.num_adv_data = len(self.adv_dataset)
def __len__(self):
"""Get the size of the dataset."""
return self.num_train_data
def __getitem__(self, idx: int):
"""Given index, get the data from train dataset and randomly sample an
item from adversarial dataset.
Return a dict containing data from train and adversarial dataset.
"""
data = self.train_dataset[idx]
adv_idx = np.random.randint(low=0, high=self.num_adv_data, dtype=int)
adv_data = self.adv_dataset[adv_idx]
for k, v in adv_data.items():
data['adv_' + k] = v
return data