File size: 1,608 Bytes
4409449 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import random
import codecs as cs
import numpy as np
from torch.utils import data
from rich.progress import track
from os.path import join as pjoin
from .dataset_m import MotionDataset
from .dataset_t2m import Text2MotionDataset
class MotionDatasetVQ(Text2MotionDataset):
    """Dataset that serves fixed-size, randomly positioned motion windows.

    Builds on ``Text2MotionDataset`` and, after the parent has loaded its
    data, discards every motion with fewer than ``win_size`` frames so that
    a full window can always be cut from whatever remains.
    """

    def __init__(
        self,
        data_root,
        split,
        mean,
        std,
        max_motion_length,
        min_motion_length,
        win_size,
        unit_length=4,
        fps=20,
        tmpFile=True,
        tiny=False,
        debug=False,
        **kwargs,
    ):
        """Load the data through the parent class, then drop short motions.

        Args:
            win_size: Number of frames in every window returned by
                ``__getitem__``; motions shorter than this are removed.

        All other arguments are forwarded unchanged to
        ``Text2MotionDataset.__init__``; see that class for their meaning.
        """
        super().__init__(data_root, split, mean, std, max_motion_length,
                         min_motion_length, unit_length, fps, tmpFile, tiny,
                         debug, **kwargs)

        # Filter out the motions that are too short.
        # NOTE(review): relies on the parent having populated
        # ``self.name_list`` and ``self.data_dict`` -- confirm against
        # Text2MotionDataset.
        self.window_size = win_size

        # Single pass that keeps the original ordering. Removing rejected
        # names from a copy with ``list.remove`` costs a linear scan per
        # removal, which made this filter quadratic.
        kept_names = []
        for name in self.name_list:
            if self.data_dict[name]["motion"].shape[0] < self.window_size:
                self.data_dict.pop(name)
            else:
                kept_names.append(name)
        self.name_list = kept_names

    def __len__(self):
        """Return the number of motions that survived the length filter."""
        return len(self.name_list)

    def __getitem__(self, item):
        """Return one randomly positioned, normalized window of a motion.

        The motion is selected by ``self.pointer + item``.
        NOTE(review): ``self.pointer``, ``self.mean`` and ``self.std`` are
        not set in this class; presumably the parent provides them -- confirm.

        Returns:
            A 7-tuple ``(None, motion, length, None, None, None, None)``
            where ``motion`` is the normalized window of
            ``self.window_size`` frames and ``length`` is the ``"length"``
            value stored with the motion (not the window size). The ``None``
            slots presumably keep the tuple layout expected by a shared
            collate function -- TODO confirm.
        """
        entry = self.data_dict[self.name_list[self.pointer + item]]
        motion, length = entry["motion"], entry["length"]

        # random.randint includes both endpoints, so the last admissible
        # start (the window ending exactly at the final frame) can be drawn.
        start = random.randint(0, motion.shape[0] - self.window_size)
        window = motion[start:start + self.window_size]

        # Z-normalization with the statistics held on the instance.
        window = (window - self.mean) / self.std

        return None, window, length, None, None, None, None
|