# Provenance (from the hosting page this file was copied from):
# author youngseng, commit da855ff ("Upload 187 files"), file size 1.69 kB.
from torch.utils.data import DataLoader
from data_loaders.tensors import collate as all_collate
from data_loaders.tensors import t2m_collate
def get_dataset_class(name):
    """Resolve a dataset name to its dataset class (not an instance).

    Every import lives inside its own branch, so only the module backing the
    requested dataset is actually loaded.

    Args:
        name: one of "amass", "uestc", "humanact12", "humanml", "kit".

    Returns:
        The dataset class registered for ``name``.

    Raises:
        ValueError: if ``name`` is not one of the supported dataset names.
    """
    if name == "amass":
        from .amass import AMASS as dataset_cls
    elif name == "uestc":
        from .a2m.uestc import UESTC as dataset_cls
    elif name == "humanact12":
        from .a2m.humanact12poses import HumanAct12Poses as dataset_cls
    elif name == "humanml":
        from data_loaders.humanml.data.dataset import HumanML3D as dataset_cls
    elif name == "kit":
        from data_loaders.humanml.data.dataset import KIT as dataset_cls
    else:
        raise ValueError(f'Unsupported dataset name [{name}]')
    return dataset_cls
def get_collate_fn(name, hml_mode='train'):
    """Pick the batch collate function to pass to ``DataLoader(collate_fn=...)``.

    Args:
        name: dataset name (as accepted by ``get_dataset_class``).
        hml_mode: when equal to ``'gt'``, the HumanML dataset module's own
            ``collate_fn`` is returned regardless of ``name``.

    Returns:
        ``collate_fn`` from ``data_loaders.humanml.data.dataset`` for
        ``hml_mode == 'gt'``; otherwise ``t2m_collate`` for "humanml"/"kit"
        and ``all_collate`` for every other name.
    """
    if hml_mode == 'gt':
        # Imported only on this path, so other modes never load this module.
        from data_loaders.humanml.data.dataset import collate_fn as t2m_eval_collate
        return t2m_eval_collate
    is_text_to_motion = name in ("humanml", "kit")
    return t2m_collate if is_text_to_motion else all_collate
def get_dataset(name, num_frames, split='train', hml_mode='train'):
    """Instantiate the dataset registered under ``name``.

    Args:
        name: dataset name resolved through ``get_dataset_class``.
        num_frames: forwarded to the dataset constructor as ``num_frames=``.
        split: forwarded to the dataset constructor as ``split=``.
        hml_mode: forwarded as ``mode=`` only for "humanml" and "kit"; the
            other dataset classes are constructed without it.

    Returns:
        The constructed dataset instance.
    """
    dataset_cls = get_dataset_class(name)
    ctor_kwargs = {'split': split, 'num_frames': num_frames}
    if name in ("humanml", "kit"):
        # Only the HumanML-family constructors are called with a mode.
        ctor_kwargs['mode'] = hml_mode
    return dataset_cls(**ctor_kwargs)
def get_dataset_loader(name, batch_size, num_frames, split='train', hml_mode='train',
                       *, shuffle=True, num_workers=8, drop_last=True):
    """Build a ``torch.utils.data.DataLoader`` for the named dataset.

    The dataset comes from ``get_dataset`` and the collate function from
    ``get_collate_fn``, both driven by ``name`` and ``hml_mode``.

    Args:
        name: dataset name (e.g. "humanml", "kit", "uestc", ...).
        batch_size: samples per batch.
        num_frames: forwarded to the dataset constructor.
        split: data split forwarded to the dataset constructor.
        hml_mode: HumanML loading mode; also selects the collate function.
        shuffle: reshuffle every epoch. Defaults to True for every split, so
            callers iterating an eval split in order should pass False.
        num_workers: DataLoader worker subprocesses (0 loads in the main
            process, which is handy for debugging).
        drop_last: drop the final incomplete batch when the dataset size is
            not divisible by ``batch_size``.

    Returns:
        The configured ``DataLoader``.
    """
    # The former hard-coded values are kept as defaults, so existing callers
    # are unaffected. The new options are keyword-only so positional calls
    # cannot bind them by accident.
    dataset = get_dataset(name, num_frames, split, hml_mode)
    collate = get_collate_fn(name, hml_mode)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=drop_last,
        collate_fn=collate,
    )