import os
from os.path import join as pjoin
import numpy as np
import copy
import torch
import torch.nn.functional as F
from utils.transforms import quat2repr6d, quat2euler, repr6d2quat


class TracksParser():
    """Parse a list of animation tracks into position and rotation arrays.

    ``tracks_json`` is an iterable of dicts, each with at least the keys
    ``'name'``, ``'type'`` and ``'values'``. The first track must be of type
    ``'vector'`` (flat list of xyz triples, used as the position track); every
    following track must be of type ``'quaternion'`` (flat list of 4-tuples).

    After construction:
        * ``self.position``  -- ``np.ndarray`` of shape ``(num_frames, 3)``,
          already multiplied by ``scale``.
        * ``self.rotations`` -- ``np.float32`` array of shape
          ``(n_rotation_tracks, num_frames, 4)``.
        * ``self.skeleton_names`` -- track names in input order.
    """

    def __init__(self, tracks_json, scale=1.0, requires_contact=False, joint_reduction=False):
        """
        :param tracks_json: iterable of track dicts (see class docstring)
        :param scale: factor applied to the position track
        :param requires_contact: must be False; contact labels are not supported
        :param joint_reduction: stored on the instance, not used by this class
        """
        assert requires_contact==False, 'contact is not implemented for tracks data yet!!!'

        self.tracks_json = tracks_json
        self.scale = scale
        self.requires_contact = requires_contact
        self.joint_reduction = joint_reduction

        self.skeleton_names = []
        self.rotations = []
        for i, track in enumerate(self.tracks_json):
            self.skeleton_names.append(track['name'])
            if i == 0:
                assert track['type'] == 'vector'
                self.position = np.array(track['values']).reshape(-1, 3) * self.scale
                # The position track defines the clip length for all other tracks.
                self.num_frames = self.position.shape[0]
            else:
                assert track['type'] == 'quaternion'  # DEFAULT: quaternion
                rotation = np.array(track['values']).reshape(-1, 4)
                self.rotations.append(self._fit_to_num_frames(rotation))
        self.rotations = np.array(self.rotations, dtype=np.float32)

    def _fit_to_num_frames(self, rotation):
        """Return ``rotation`` resized along axis 0 to exactly ``self.num_frames`` rows."""
        n_keys = rotation.shape[0]
        if n_keys == 0:
            # NOTE(review): an all-zero quaternion is not a valid rotation; an
            # identity quaternion may be intended, but the component order
            # expected by utils.transforms is not visible here -- confirm.
            return np.zeros((self.num_frames, 4))
        if n_keys < self.num_frames:
            rotation = np.repeat(rotation, self.num_frames // n_keys, axis=0)
            missing = self.num_frames - rotation.shape[0]
            if missing > 0:
                # num_frames is not a multiple of n_keys: hold the last key so
                # every track has the same length and can be stacked.
                tail = np.repeat(rotation[-1:], missing, axis=0)
                rotation = np.concatenate([rotation, tail], axis=0)
            return rotation
        return rotation[:self.num_frames]

    def to_tensor(self, repr='euler', rot_only=False):
        """Return the clip as a 2D float tensor with one row per frame.

        :param repr: one of 'euler', 'quat', 'quaternion', 'repr6d'
        :param rot_only: if True, return only the flattened rotation channels
        :returns: ``(num_frames, n_joints * rot_dim)`` if ``rot_only`` else the
                  same with the 3 position channels appended as the last columns
        :raises Exception: if ``repr`` is not a known representation
        """
        if repr not in ['euler', 'quat', 'quaternion', 'repr6d']:
            raise Exception('Unknown rotation representation')
        rotations = self.get_rotation(repr=repr)
        positions = self.get_position()

        if rot_only:
            return rotations.reshape(rotations.shape[0], -1)

        if self.requires_contact:
            # Unreachable in practice: __init__ asserts requires_contact is False,
            # and neither self.skeleton nor self.contact_label is set by this class.
            virtual_contact = torch.zeros_like(rotations[:, :len(self.skeleton.contact_id)])
            virtual_contact[..., 0] = self.contact_label
            rotations = torch.cat([rotations, virtual_contact], dim=1)

        rotations = rotations.reshape(rotations.shape[0], -1)
        return torch.cat((rotations, positions), dim=-1)

    def get_rotation(self, repr='quat'):
        """Return the rotations as a tensor laid out ``(num_frames, n_joints, rot_dim)``.

        :param repr: 'quat' / 'quaternion' (raw values), 'repr6d' or 'euler'
        :raises Exception: if ``repr`` is not a known representation
        """
        if repr not in ('euler', 'quat', 'quaternion', 'repr6d'):
            raise Exception('Unknown rotation representation')
        # Always start from the quaternion tensor; previously it was not built
        # for 'euler', which made that branch fail with UnboundLocalError.
        rotations = torch.tensor(self.rotations, dtype=torch.float).transpose(0, 1)
        if repr == 'repr6d':
            rotations = quat2repr6d(rotations)
        elif repr == 'euler':
            rotations = quat2euler(rotations)
        return rotations

    def get_position(self):
        """Return the scaled position track as a float32 tensor of shape ``(num_frames, 3)``."""
        return torch.tensor(self.position, dtype=torch.float32)


class TracksMotion:
    """Channel-first motion tensor built from animation tracks.

    ``self.raw_motion`` has shape ``(1, n_channels, n_frames)``: flattened
    rotation channels first, then the 3 position channels, then ``n_pad`` zero
    channels when ``padding`` is enabled. With ``use_velo`` the masked position
    channels hold frame-to-frame differences instead of absolute positions.
    """

    def __init__(self, tracks_json, scale=1.0, repr='repr6d', padding=False,
                 use_velo=True, contact=False, keep_y_pos=True, joint_reduction=False):
        """
        :param tracks_json: track list forwarded to ``TracksParser``
        :param scale: position scale forwarded to ``TracksParser``
        :param repr: rotation representation ('quat', 'quaternion', 'repr6d', 'euler')
        :param padding: append ``n_rot - 3`` zero channels after the position channels
        :param use_velo: store masked position channels as velocities
        :param contact: forwarded as ``requires_contact`` (only False is supported)
        :param keep_y_pos: if True the middle position channel stays absolute
        :param joint_reduction: forwarded to ``TracksParser`` and stored
        """
        self.scale = scale
        self.tracks = TracksParser(tracks_json, scale, requires_contact=contact, joint_reduction=joint_reduction)
        # to_tensor raises for unknown representations, so `repr` is valid below.
        self.raw_motion = self.tracks.to_tensor(repr=repr)
        self.extra = {}

        self.repr = repr
        if repr in ('quat', 'quaternion'):
            # 'quaternion' is accepted by the parser as an alias of 'quat'.
            self.n_rot = 4
        elif repr == 'repr6d':
            self.n_rot = 6
        elif repr == 'euler':
            self.n_rot = 3

        self.padding = padding
        self.use_velo = use_velo
        self.contact = contact
        self.keep_y_pos = keep_y_pos
        self.joint_reduction = joint_reduction

        self.raw_motion = self.raw_motion.permute(1, 0).unsqueeze_(0)  # Shape = (1, n_channel, n_frames)
        self.extra['global_pos'] = self.raw_motion[:, -3:, :]

        if padding:
            self.n_pad = self.n_rot - 3  # pad position channels
            paddings = torch.zeros_like(self.raw_motion[:, :self.n_pad])
            self.raw_motion = torch.cat((self.raw_motion, paddings), dim=1)
        else:
            self.n_pad = 0
        # Channel order is unchanged by this split-and-concatenate; it yields a
        # new tensor rather than the one extra['global_pos'] was sliced from.
        self.raw_motion = torch.cat((self.raw_motion[:, :-3-self.n_pad], self.raw_motion[:, -3-self.n_pad:]), dim=1)

        if self.use_velo:
            # Indices are relative to the end of the un-padded channel axis.
            self.msk = [-3, -2, -1] if not keep_y_pos else [-3, -1]
            self.raw_motion = self.pos2velo(self.raw_motion)

        self.n_contact = len(self.tracks.skeleton.contact_id) if contact else 0

    @property
    def n_channels(self):
        """Number of channels (size of dim 1) of ``raw_motion``."""
        return self.raw_motion.shape[1]

    def __len__(self):
        """Number of frames (size of the last dim) of ``raw_motion``."""
        return self.raw_motion.shape[-1]

    def pos2velo(self, pos):
        """Replace the masked position channels by frame-to-frame differences.

        Stores the first-frame values of those channels in ``self.begin_pos``
        so that ``velo2pos`` can restore absolute positions.

        :param pos: tensor of shape ``(batch, n_channels, n_frames)``
        :returns: a new tensor of the same shape
        """
        msk = [i - self.n_pad for i in self.msk]
        velo = pos.detach().clone().to(pos.device)
        velo[:, msk, 1:] = pos[:, msk, 1:] - pos[:, msk, :-1]
        self.begin_pos = pos[:, msk, 0].clone()
        # NOTE(review): frame 0 receives the absolute position of frame 1, not
        # a velocity; `velo[:, msk, 1]` may have been intended. Left unchanged
        # because it affects the stored features -- confirm.
        velo[:, msk, 0] = pos[:, msk, 1]
        return velo

    def velo2pos(self, velo):
        """Integrate the masked velocity channels back to absolute positions.

        Frame 0 of the masked channels is replaced by ``self.begin_pos`` and
        the remaining frames are accumulated on top of it, so
        ``velo2pos(pos2velo(x))`` reproduces ``x`` up to float rounding.

        :param velo: tensor of shape ``(batch, n_channels, n_frames)``
        :returns: a new tensor of the same shape
        """
        msk = [i - self.n_pad for i in self.msk]
        pos = velo.detach().clone().to(velo.device)
        pos[:, msk, 0] = self.begin_pos.to(velo.device)
        # Accumulate `pos` (which now starts at begin_pos), not `velo`;
        # summing `velo` discarded the start position written just above.
        pos[:, msk] = torch.cumsum(pos[:, msk], dim=-1)
        return pos

    def motion2pos(self, motion):
        """Return ``motion`` with absolute positions, whatever ``use_velo`` is."""
        if not self.use_velo:
            return motion
        return self.velo2pos(motion.clone())

    def sample(self, size=None, slerp=False, align_corners=False):
        """Return the motion, optionally resampled to ``size`` frames.

        :param size: target frame count, or None for the stored motion
        :param slerp: must be False; spherical interpolation is not supported
        :param align_corners: forwarded to ``F.interpolate``
        :returns: ``{'motion': raw_motion, 'extra': extra}`` when ``size`` is
                  None, otherwise only the linearly interpolated motion tensor
        :raises NotImplementedError: if ``slerp`` is True and ``size`` is given
        """
        if size is None:
            return {'motion': self.raw_motion, 'extra': self.extra}
        if slerp:
            raise NotImplementedError('slerp is not not implemented yet!!!')
        return F.interpolate(self.raw_motion, size=size, mode='linear', align_corners=align_corners)

    def parse(self, motion, keep_velo=False,):
        """
        Convert a motion tensor back into a tracks list.

        No batch support here!!!

        :param motion: tensor of shape ``(1, n_channels, n_frames)``
        :param keep_velo: if True, skip the velocity -> position integration
        :returns: tracks_json -- a deep copy of the input tracks with
                  ``'values'`` and ``'times'`` replaced
        :raises NotImplementedError: for contact data or the 'euler' representation
        """
        motion = motion.clone()

        if self.use_velo and not keep_velo:
            motion = self.velo2pos(motion)
        if self.n_pad:
            motion = motion[:, :-self.n_pad]
        if self.contact:
            raise NotImplementedError('contact is not implemented yet!!!')

        # Drop only the batch axis; a bare squeeze() would also remove the
        # frame axis of a single-frame clip.
        motion = motion.squeeze(0).permute(1, 0)
        pos = motion[..., -3:] / self.scale
        rot = motion[..., :-3].reshape(motion.shape[0], -1, self.n_rot)
        if self.repr == 'repr6d':
            rot = repr6d2quat(rot)
        elif self.repr == 'euler':
            raise NotImplementedError('parse "euler is not implemented yet!!!')

        times = []
        out_tracks_json = copy.deepcopy(self.tracks.tracks_json)
        for i, track in enumerate(out_tracks_json):
            if i == 0:
                # Timestamps are rebuilt as multiples of the position track's
                # second timestamp.
                # NOTE(review): assumes uniform sampling starting at 0 -- confirm.
                frame_time = track['times'][1]
                times = [j * frame_time for j in range(motion.shape[0])]
                track['values'] = pos.flatten().detach().cpu().numpy().tolist()
            else:
                track['values'] = rot[:, i - 1, :].flatten().detach().cpu().numpy().tolist()
            track['times'] = times

        return out_tracks_json