# Repository page header (scraped): uploaded by orhir in commit 241adf2
# ("Upload 97 files"); raw / history / blame links; "No virus"; 1.92 kB.
import copy

from mmcv.utils import build_from_cfg
from mmpose.datasets.builder import DATASETS
from mmpose.datasets.dataset_wrappers import RepeatDataset
from torch.utils.data.dataset import ConcatDataset
def _concat_cfg(cfg):
replace = ['ann_file', 'img_prefix']
channels = ['num_joints', 'dataset_channel']
concat_cfg = []
for i in range(len(cfg['type'])):
cfg_tmp = cfg.deepcopy()
cfg_tmp['type'] = cfg['type'][i]
for item in replace:
assert item in cfg_tmp
assert len(cfg['type']) == len(cfg[item]), (cfg[item])
cfg_tmp[item] = cfg[item][i]
for item in channels:
assert item in cfg_tmp['data_cfg']
assert len(cfg['type']) == len(cfg['data_cfg'][item])
cfg_tmp['data_cfg'][item] = cfg['data_cfg'][item][i]
concat_cfg.append(cfg_tmp)
return concat_cfg
def _check_vaild(cfg):
replace = ['num_joints', 'dataset_channel']
if isinstance(cfg['data_cfg'][replace[0]], (list, tuple)):
for item in replace:
cfg['data_cfg'][item] = cfg['data_cfg'][item][0]
return cfg
def build_dataset(cfg, default_args=None):
    """Build a dataset, possibly wrapped or concatenated, from a config dict.

    Three config shapes are handled:

    * ``cfg['type']`` is a list/tuple: the config describes several datasets.
      ``_concat_cfg`` splits it, each part is built recursively, and the
      results are joined with ``ConcatDataset``.
    * ``cfg['type'] == 'RepeatDataset'``: ``cfg['dataset']`` is built
      recursively and wrapped in ``RepeatDataset`` with ``cfg['times']``.
    * Anything else: ``_check_vaild`` normalizes the config, and the dataset
      is created from the ``DATASETS`` registry via ``build_from_cfg``.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        default_args (dict, optional): Default initialization arguments,
            forwarded unchanged to every nested build. Default: None.

    Returns:
        Dataset: The constructed dataset.
    """
    dataset_type = cfg['type']

    if isinstance(dataset_type, (list, tuple)):
        # Multi-dataset config; the original note says this path is taken in
        # training (type=TransformerPoseDataset).
        parts = [
            build_dataset(part_cfg, default_args)
            for part_cfg in _concat_cfg(cfg)
        ]
        return ConcatDataset(parts)

    if dataset_type == 'RepeatDataset':
        inner = build_dataset(cfg['dataset'], default_args)
        return RepeatDataset(inner, cfg['times'])

    checked_cfg = _check_vaild(cfg)
    return build_from_cfg(checked_cfg, DATASETS, default_args)