GenMM / dataset /tracks_motion.py
wyysf's picture
Duplicate from radames/GenMM-demo
27763e5
import os
from os.path import join as pjoin
import numpy as np
import copy
import torch
import torch.nn.functional as F
from utils.transforms import quat2repr6d, quat2euler, repr6d2quat
class TracksParser():
def __init__(self, tracks_json, scale=1.0, requires_contact=False, joint_reduction=False):
assert requires_contact==False, 'contact is not implemented for tracks data yet!!!'
self.tracks_json = tracks_json
self.scale = scale
self.requires_contact = requires_contact
self.joint_reduction = joint_reduction
self.skeleton_names = []
self.rotations = []
for i, track in enumerate(self.tracks_json):
# print(i, track['name'])
self.skeleton_names.append(track['name'])
if i == 0:
assert track['type'] == 'vector'
self.position = np.array(track['values']).reshape(-1, 3) * self.scale
self.num_frames = self.position.shape[0]
else:
assert track['type'] == 'quaternion' # DEAFULT: quaternion
rotation = np.array(track['values']).reshape(-1, 4)
if rotation.shape[0] == 0:
rotation = np.zeros((self.num_frames, 4))
elif rotation.shape[0] < self.num_frames:
rotation = np.repeat(rotation, self.num_frames // rotation.shape[0], axis=0)
elif rotation.shape[0] > self.num_frames:
rotation = rotation[:self.num_frames]
self.rotations += [rotation]
self.rotations = np.array(self.rotations, dtype=np.float32)
def to_tensor(self, repr='euler', rot_only=False):
if repr not in ['euler', 'quat', 'quaternion', 'repr6d']:
raise Exception('Unknown rotation representation')
rotations = self.get_rotation(repr=repr)
positions = self.get_position()
if rot_only:
return rotations.reshape(rotations.shape[0], -1)
if self.requires_contact:
virtual_contact = torch.zeros_like(rotations[:, :len(self.skeleton.contact_id)])
virtual_contact[..., 0] = self.contact_label
rotations = torch.cat([rotations, virtual_contact], dim=1)
rotations = rotations.reshape(rotations.shape[0], -1)
return torch.cat((rotations, positions), dim=-1)
def get_rotation(self, repr='quat'):
if repr == 'quaternion' or repr == 'quat' or repr == 'repr6d':
rotations = torch.tensor(self.rotations, dtype=torch.float).transpose(0, 1)
if repr == 'repr6d':
rotations = quat2repr6d(rotations)
if repr == 'euler':
rotations = quat2euler(rotations)
return rotations
def get_position(self):
return torch.tensor(self.position, dtype=torch.float32)
class TracksMotion:
def __init__(self, tracks_json, scale=1.0, repr='repr6d', padding=False,
use_velo=True, contact=False, keep_y_pos=True, joint_reduction=False):
self.scale = scale
self.tracks = TracksParser(tracks_json, scale, requires_contact=contact, joint_reduction=joint_reduction)
self.raw_motion = self.tracks.to_tensor(repr=repr)
self.extra = {
}
self.repr = repr
if repr == 'quat':
self.n_rot = 4
elif repr == 'repr6d':
self.n_rot = 6
elif repr == 'euler':
self.n_rot = 3
self.padding = padding
self.use_velo = use_velo
self.contact = contact
self.keep_y_pos = keep_y_pos
self.joint_reduction = joint_reduction
self.raw_motion = self.raw_motion.permute(1, 0).unsqueeze_(0) # Shape = (1, n_channel, n_frames)
self.extra['global_pos'] = self.raw_motion[:, -3:, :]
if padding:
self.n_pad = self.n_rot - 3 # pad position channels
paddings = torch.zeros_like(self.raw_motion[:, :self.n_pad])
self.raw_motion = torch.cat((self.raw_motion, paddings), dim=1)
else:
self.n_pad = 0
self.raw_motion = torch.cat((self.raw_motion[:, :-3-self.n_pad], self.raw_motion[:, -3-self.n_pad:]), dim=1)
if self.use_velo:
self.msk = [-3, -2, -1] if not keep_y_pos else [-3, -1]
self.raw_motion = self.pos2velo(self.raw_motion)
self.n_contact = len(self.tracks.skeleton.contact_id) if contact else 0
@property
def n_channels(self):
return self.raw_motion.shape[1]
def __len__(self):
return self.raw_motion.shape[-1]
def pos2velo(self, pos):
msk = [i - self.n_pad for i in self.msk]
velo = pos.detach().clone().to(pos.device)
velo[:, msk, 1:] = pos[:, msk, 1:] - pos[:, msk, :-1]
self.begin_pos = pos[:, msk, 0].clone()
velo[:, msk, 0] = pos[:, msk, 1]
return velo
def velo2pos(self, velo):
msk = [i - self.n_pad for i in self.msk]
pos = velo.detach().clone().to(velo.device)
pos[:, msk, 0] = self.begin_pos.to(velo.device)
pos[:, msk] = torch.cumsum(velo[:, msk], dim=-1)
return pos
def motion2pos(self, motion):
if not self.use_velo:
return motion
else:
self.velo2pos(motion.clone())
def sample(self, size=None, slerp=False, align_corners=False):
if size is None:
return {'motion': self.raw_motion, 'extra': self.extra}
else:
if slerp:
raise NotImplementedError('slerp is not not implemented yet!!!')
else:
motion = F.interpolate(self.raw_motion, size=size, mode='linear', align_corners=align_corners)
extra = {}
if 'global_pos' in self.extra.keys():
extra['global_pos'] = F.interpolate(self.extra['global_pos'], size=size, mode='linear', align_corners=align_corners)
return motion
# return {'motion': motion, 'extra': extra}
def parse(self, motion, keep_velo=False,):
"""
No batch support here!!!
:returns tracks_json
"""
motion = motion.clone()
if self.use_velo and not keep_velo:
motion = self.velo2pos(motion)
if self.n_pad:
motion = motion[:, :-self.n_pad]
if self.contact:
raise NotImplementedError('contact is not implemented yet!!!')
motion = motion.squeeze().permute(1, 0)
pos = motion[..., -3:] / self.scale
rot = motion[..., :-3].reshape(motion.shape[0], -1, self.n_rot)
if self.repr == 'repr6d':
rot = repr6d2quat(rot)
elif self.repr == 'euler':
raise NotImplementedError('parse "euler is not implemented yet!!!')
times = []
out_tracks_json = copy.deepcopy(self.tracks.tracks_json)
for i, _track in enumerate(out_tracks_json):
if i == 0:
times = [ j * out_tracks_json[i]['times'][1] for j in range(motion.shape[0])]
out_tracks_json[i]['values'] = pos.flatten().detach().cpu().numpy().tolist()
else:
out_tracks_json[i]['values'] = rot[:, i-1, :].flatten().detach().cpu().numpy().tolist()
out_tracks_json[i]['times'] = times
return out_tracks_json