|
from torch.utils.data import DataLoader |
|
from data_loaders.tensors import collate as all_collate |
|
from data_loaders.tensors import t2m_collate |
|
|
|
def get_dataset_class(name):
    """Return the dataset class registered under *name*.

    Each dataset module is imported inside its own branch, so only the
    module backing the requested dataset is loaded.

    Args:
        name: One of "amass", "uestc", "humanact12", "humanml" or "kit".

    Returns:
        The dataset class itself (not an instance).

    Raises:
        ValueError: If *name* is not one of the supported identifiers.
    """
    if name == "amass":
        from .amass import AMASS
        return AMASS

    if name == "uestc":
        from .a2m.uestc import UESTC
        return UESTC

    if name == "humanact12":
        from .a2m.humanact12poses import HumanAct12Poses
        return HumanAct12Poses

    if name == "humanml":
        from data_loaders.humanml.data.dataset import HumanML3D
        return HumanML3D

    if name == "kit":
        from data_loaders.humanml.data.dataset import KIT
        return KIT

    # Nothing matched: fail loudly rather than returning None.
    raise ValueError(f'Unsupported dataset name [{name}]')
|
|
|
def get_collate_fn(name, hml_mode='train'):
    """Select the collate function for a dataset / mode combination.

    Args:
        name: Dataset identifier.
        hml_mode: Mode string; the value 'gt' takes precedence over *name*.

    Returns:
        The ``collate_fn`` from ``data_loaders.humanml.data.dataset`` when
        ``hml_mode == 'gt'``; otherwise ``t2m_collate`` for "humanml" and
        "kit", and the generic ``all_collate`` for every other name.
    """
    # The 'gt' check runs first, so it applies whatever the dataset name is.
    if hml_mode == 'gt':
        from data_loaders.humanml.data.dataset import collate_fn as gt_collate
        return gt_collate

    return t2m_collate if name in ("humanml", "kit") else all_collate
|
|
|
|
|
def get_dataset(name, num_frames, split='train', hml_mode='train'):
    """Instantiate the dataset registered under *name*.

    Args:
        name: Dataset identifier, resolved through ``get_dataset_class``.
        num_frames: Forwarded to the dataset constructor as ``num_frames``.
        split: Forwarded to the dataset constructor as ``split``.
        hml_mode: Forwarded as ``mode``, but only for "humanml" and "kit".

    Returns:
        The constructed dataset instance.
    """
    dataset_cls = get_dataset_class(name)

    ctor_kwargs = {"split": split, "num_frames": num_frames}
    # Only the "humanml" and "kit" constructors are given a ``mode`` argument.
    if name in ("humanml", "kit"):
        ctor_kwargs["mode"] = hml_mode

    return dataset_cls(**ctor_kwargs)
|
|
|
|
|
def get_dataset_loader(name, batch_size, num_frames, split='train', hml_mode='train',
                       *, shuffle=True, num_workers=8, drop_last=True):
    """Build a ``DataLoader`` for the dataset registered under *name*.

    The dataset comes from ``get_dataset`` and the collate function from
    ``get_collate_fn``. The loader options that used to be hard-coded are
    keyword-only parameters whose defaults reproduce the previous behavior,
    so existing callers are unaffected.

    Args:
        name: Dataset identifier.
        batch_size: Forwarded to ``DataLoader`` as ``batch_size``.
        num_frames: Forwarded to the dataset constructor.
        split: Forwarded to the dataset constructor.
        hml_mode: Forwarded to both ``get_dataset`` and ``get_collate_fn``.
        shuffle: Forwarded to ``DataLoader`` as ``shuffle``.
        num_workers: Forwarded to ``DataLoader`` as ``num_workers``.
        drop_last: Forwarded to ``DataLoader`` as ``drop_last``.

    Returns:
        The configured ``DataLoader``.
    """
    dataset = get_dataset(name, num_frames, split, hml_mode)
    collate = get_collate_fn(name, hml_mode)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=drop_last,
        collate_fn=collate,
    )