File size: 4,780 Bytes
4409449 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import numpy as np
import torch
from os.path import join as pjoin
from .humanml.utils.word_vectorizer import WordVectorizer
from .humanml.scripts.motion_process import (process_file, recover_from_ric)
from . import BASEDataModule
from .humanml import Text2MotionDatasetEval, Text2MotionDataset, Text2MotionDatasetCB, MotionDataset, MotionDatasetVQ, Text2MotionDatasetToken, Text2MotionDatasetM2T
from .utils import humanml3d_collate
class HumanML3DDataModule(BASEDataModule):
    """Data module for the HumanML3D dataset.

    Selects the train/eval dataset classes according to ``cfg.TRAIN.STAGE``
    and provides helpers to normalize/denormalize motion features and to
    convert between features and joints.
    """

    def __init__(self, cfg, **kwargs):
        """Read dataset settings from ``cfg`` and choose the dataset classes.

        Args:
            cfg: Experiment config. Reads ``DATASET.HUMANML3D.*``, ``DEBUG``,
                ``TRAIN.STAGE``; ``model.params.motion_vae.target`` for the
                "vae" stage; ``DATASET.CODE_PATH`` / ``DATASET.TASK_PATH`` for
                stages containing "lm". Mutated in place:
                ``DATASET.JOINT_TYPE`` and ``DATASET.NFEATS`` are written.
            **kwargs: Not used directly here; presumably recorded by
                ``save_hyperparameters`` (TODO confirm).
        """
        super().__init__(collate_fn=humanml3d_collate)
        self.cfg = cfg
        self.save_hyperparameters(logger=False)

        # Basic info of the dataset
        cfg.DATASET.JOINT_TYPE = 'humanml3d'
        self.name = "humanml3d"
        self.njoints = 22

        # Path to the dataset
        data_root = cfg.DATASET.HUMANML3D.ROOT
        self.hparams.data_root = data_root
        self.hparams.text_dir = pjoin(data_root, "texts")
        self.hparams.motion_dir = pjoin(data_root, 'new_joint_vecs')

        # NOTE: the statistics directory is a relative path, so it resolves
        # against the current working directory, not against data_root.
        meta_dir = 'assets/meta'
        # Mean and std of the dataset (used by normalize/denormalize)
        self.hparams.mean = np.load(pjoin(meta_dir, "mean.npy"))
        self.hparams.std = np.load(pjoin(meta_dir, "std.npy"))
        # Mean and std for fair evaluation (used by renorm4t2m)
        self.hparams.mean_eval = np.load(pjoin(meta_dir, "mean_eval.npy"))
        self.hparams.std_eval = np.load(pjoin(meta_dir, "std_eval.npy"))

        # Length of the dataset
        hml_cfg = cfg.DATASET.HUMANML3D
        self.hparams.max_motion_length = hml_cfg.MAX_MOTION_LEN
        self.hparams.min_motion_length = hml_cfg.MIN_MOTION_LEN
        self.hparams.max_text_len = hml_cfg.MAX_TEXT_LEN
        self.hparams.unit_length = hml_cfg.UNIT_LEN

        # Additional parameters
        stage = cfg.TRAIN.STAGE
        self.hparams.debug = cfg.DEBUG
        self.hparams.stage = stage

        # Dataset switch. Evaluation uses Text2MotionDatasetEval unless a
        # stage below overrides it.
        self.DatasetEval = Text2MotionDatasetEval
        if stage == "vae":
            # Pick the VQ dataset when the VAE target's class name is "vqvae"
            # (case-insensitive).
            vae_name = cfg.model.params.motion_vae.target.split('.')[-1].lower()
            if vae_name == "vqvae":
                # Presumably a crop-window length read by MotionDatasetVQ
                # via hparams -- TODO confirm.
                self.hparams.win_size = 64
                self.Dataset = MotionDatasetVQ
            else:
                self.Dataset = MotionDataset
        elif 'lm' in stage:
            # Substring match: every stage name containing "lm" lands here.
            self.hparams.code_path = cfg.DATASET.CODE_PATH
            self.hparams.task_path = cfg.DATASET.TASK_PATH
            self.hparams.std_text = hml_cfg.STD_TEXT
            self.Dataset = Text2MotionDatasetCB
        elif stage == "token":
            self.Dataset = Text2MotionDatasetToken
            self.DatasetEval = Text2MotionDatasetToken
        elif stage == "m2t":
            self.Dataset = Text2MotionDatasetM2T
            self.DatasetEval = Text2MotionDatasetM2T
        else:
            self.Dataset = Text2MotionDataset

        # Get additional info of the dataset
        self.nfeats = 263
        cfg.DATASET.NFEATS = self.nfeats

    def _stats_like(self, features, mean, std):
        """Return ``mean`` and ``std`` as tensors with the dtype/device of ``features``."""
        return torch.tensor(mean).to(features), torch.tensor(std).to(features)

    def feats2joints(self, features):
        """Denormalize ``features`` and recover joints with ``recover_from_ric``.

        Args:
            features: Normalized feature tensor; presumably ``(..., nfeats)``
                with ``nfeats == 263`` (TODO confirm).

        Returns:
            The result of ``recover_from_ric`` for ``self.njoints`` joints.
        """
        return recover_from_ric(self.denormalize(features), self.njoints)

    def joints2feats(self, features):
        """Convert joint data into features via ``process_file``.

        ``features`` presumably holds joint positions despite its name (kept
        for caller compatibility). Only the first element of
        ``process_file``'s return value is returned.
        """
        features = process_file(features, self.njoints)[0]
        return features

    def normalize(self, features):
        """Return ``(features - mean) / std`` using the dataset statistics."""
        mean, std = self._stats_like(features, self.hparams.mean, self.hparams.std)
        return (features - mean) / std

    def denormalize(self, features):
        """Return ``features * std + mean`` (inverse of ``normalize``)."""
        mean, std = self._stats_like(features, self.hparams.mean, self.hparams.std)
        return features * std + mean

    def renorm4t2m(self, features):
        """Re-normalize features from dataset stats to evaluation stats.

        Undoes normalization with ``mean``/``std`` and re-applies it with
        ``mean_eval``/``std_eval``, so that the t2m evaluators can be used.
        """
        features = self.denormalize(features)
        eval_mean, eval_std = self._stats_like(
            features, self.hparams.mean_eval, self.hparams.std_eval)
        return (features - eval_mean) / eval_std

    def mm_mode(self, mm_on=True):
        """Switch the test dataset into or out of MM-evaluation mode.

        ("MM" is presumably the multimodality metric -- TODO confirm.)

        Enabling restricts ``self.test_dataset.name_list`` to
        ``cfg.METRIC.MM_NUM_SAMPLES`` names sampled without replacement from
        the full list. Disabling restores the full list. Enabling twice
        resamples from the full list rather than from the current subset,
        and disabling without a prior enable is a no-op apart from setting
        ``is_mm`` to False.

        Args:
            mm_on: ``True`` to enter MM mode, ``False`` to leave it.

        Raises:
            ValueError: from ``np.random.choice`` when ``MM_NUM_SAMPLES``
                exceeds the number of test names. In that case no state is
                changed.
        """
        # getattr: is_mm may not exist yet if the base class does not set it.
        already_mm = getattr(self, "is_mm", False)
        if mm_on:
            # Never snapshot the MM subset as the "full" list on re-entry.
            full_list = self.name_list if already_mm else self.test_dataset.name_list
            # Sample first so a failure leaves the module state untouched.
            mm_list = np.random.choice(full_list,
                                       self.cfg.METRIC.MM_NUM_SAMPLES,
                                       replace=False)
            self.name_list = full_list
            self.mm_list = mm_list
            self.test_dataset.name_list = mm_list
            self.is_mm = True
        else:
            if already_mm:
                self.test_dataset.name_list = self.name_list
            self.is_mm = False
|