Spaces:
Running
Running
import os | |
from os.path import join as pjoin | |
import numpy as np | |
import copy | |
import torch | |
import torch.nn.functional as F | |
from utils.transforms import quat2repr6d, quat2euler, repr6d2quat | |
class TracksParser():
    """Parse a three.js-style animation `tracks` JSON list into numpy/torch data.

    The first track must be the root position (type 'vector'); every
    subsequent track must be a per-joint rotation (type 'quaternion').
    Positions are scaled by `scale` at load time.
    """

    def __init__(self, tracks_json, scale=1.0, requires_contact=False, joint_reduction=False):
        assert requires_contact == False, 'contact is not implemented for tracks data yet!!!'
        self.tracks_json = tracks_json
        self.scale = scale
        self.requires_contact = requires_contact
        self.joint_reduction = joint_reduction
        self.skeleton_names = []
        self.rotations = []
        for i, track in enumerate(self.tracks_json):
            self.skeleton_names.append(track['name'])
            if i == 0:
                # Root translation track: (n_frames, 3), scaled to target units.
                assert track['type'] == 'vector'
                self.position = np.array(track['values']).reshape(-1, 3) * self.scale
                self.num_frames = self.position.shape[0]
            else:
                assert track['type'] == 'quaternion'  # DEFAULT: quaternion
                rotation = np.array(track['values']).reshape(-1, 4)
                if rotation.shape[0] == 0:
                    # Empty track: fill with zeros so the joint count stays consistent.
                    rotation = np.zeros((self.num_frames, 4))
                elif rotation.shape[0] < self.num_frames:
                    # NOTE(review): np.repeat holds each keyframe consecutively, and
                    # the result is still shorter than num_frames when the lengths do
                    # not divide evenly -- confirm upstream data always divides exactly.
                    rotation = np.repeat(rotation, self.num_frames // rotation.shape[0], axis=0)
                elif rotation.shape[0] > self.num_frames:
                    rotation = rotation[:self.num_frames]
                self.rotations += [rotation]
        self.rotations = np.array(self.rotations, dtype=np.float32)  # (n_joints, n_frames, 4)

    def to_tensor(self, repr='euler', rot_only=False):
        """Return motion as a (n_frames, channels) tensor: flattened rotations,
        then (unless rot_only) the 3 root-position channels."""
        if repr not in ['euler', 'quat', 'quaternion', 'repr6d']:
            raise Exception('Unknown rotation representation')
        rotations = self.get_rotation(repr=repr)
        positions = self.get_position()
        if rot_only:
            return rotations.reshape(rotations.shape[0], -1)
        if self.requires_contact:
            # Unreachable: __init__ asserts requires_contact == False, and
            # self.skeleton / self.contact_label are never defined in this class.
            virtual_contact = torch.zeros_like(rotations[:, :len(self.skeleton.contact_id)])
            virtual_contact[..., 0] = self.contact_label
            rotations = torch.cat([rotations, virtual_contact], dim=1)
        rotations = rotations.reshape(rotations.shape[0], -1)
        return torch.cat((rotations, positions), dim=-1)

    def get_rotation(self, repr='quat'):
        """Return joint rotations as a (n_frames, n_joints, rot_dim) float tensor.

        Raises:
            Exception: if `repr` is not one of euler/quat/quaternion/repr6d.
        """
        if repr not in ['euler', 'quat', 'quaternion', 'repr6d']:
            raise Exception('Unknown rotation representation')
        # Stored as (n_joints, n_frames, 4); transpose to (n_frames, n_joints, 4).
        rotations = torch.tensor(self.rotations, dtype=torch.float).transpose(0, 1)
        if repr == 'repr6d':
            rotations = quat2repr6d(rotations)
        elif repr == 'euler':
            # Bug fix: previously the tensor was only built on the quat/repr6d
            # branch, so repr='euler' called quat2euler with `rotations` unbound
            # and raised NameError.
            rotations = quat2euler(rotations)
        return rotations

    def get_position(self):
        """Return the root position as a (n_frames, 3) float32 tensor."""
        return torch.tensor(self.position, dtype=torch.float32)
class TracksMotion:
    """Hold a parsed tracks animation as a (1, n_channels, n_frames) tensor.

    Channels are the flattened joint rotations followed by the 3 global
    position channels (optionally zero-padded to the rotation width, and
    optionally converted to per-frame velocities). `parse` converts a motion
    tensor back into tracks JSON.
    """

    def __init__(self, tracks_json, scale=1.0, repr='repr6d', padding=False,
                 use_velo=True, contact=False, keep_y_pos=True, joint_reduction=False):
        self.scale = scale
        self.tracks = TracksParser(tracks_json, scale, requires_contact=contact, joint_reduction=joint_reduction)
        self.raw_motion = self.tracks.to_tensor(repr=repr)  # (n_frames, n_channels)
        self.extra = {}
        self.repr = repr
        if repr == 'quat':
            self.n_rot = 4
        elif repr == 'repr6d':
            self.n_rot = 6
        elif repr == 'euler':
            self.n_rot = 3
        self.padding = padding
        self.use_velo = use_velo
        self.contact = contact
        self.keep_y_pos = keep_y_pos
        self.joint_reduction = joint_reduction
        self.raw_motion = self.raw_motion.permute(1, 0).unsqueeze_(0)  # Shape = (1, n_channel, n_frames)
        self.extra['global_pos'] = self.raw_motion[:, -3:, :]
        if padding:
            # Pad so the 3 position channels line up with one rotation's width.
            self.n_pad = self.n_rot - 3  # pad position channels
            paddings = torch.zeros_like(self.raw_motion[:, :self.n_pad])
            self.raw_motion = torch.cat((self.raw_motion, paddings), dim=1)
        else:
            self.n_pad = 0
        # NOTE(review): this split-and-concat re-joins the exact same channel
        # ranges -- a no-op copy, kept for parity with the original code.
        self.raw_motion = torch.cat((self.raw_motion[:, :-3 - self.n_pad],
                                     self.raw_motion[:, -3 - self.n_pad:]), dim=1)
        if self.use_velo:
            # Position channels to differentiate; the y channel (index -2)
            # stays absolute when keep_y_pos is set. Offsets are relative to
            # the unpadded channel end (pos2velo shifts them by n_pad).
            self.msk = [-3, -2, -1] if not keep_y_pos else [-3, -1]
            self.raw_motion = self.pos2velo(self.raw_motion)
        self.n_contact = len(self.tracks.skeleton.contact_id) if contact else 0

    def n_channels(self):
        """Number of channels (dim 1) of the raw motion tensor."""
        return self.raw_motion.shape[1]

    def __len__(self):
        """Number of frames (last dim) of the raw motion tensor."""
        return self.raw_motion.shape[-1]

    def pos2velo(self, pos):
        """Replace the masked position channels by per-frame deltas.

        Records `self.begin_pos` as a side effect so `velo2pos` can invert.
        """
        msk = [i - self.n_pad for i in self.msk]
        velo = pos.detach().clone().to(pos.device)
        velo[:, msk, 1:] = pos[:, msk, 1:] - pos[:, msk, :-1]
        self.begin_pos = pos[:, msk, 0].clone()
        # NOTE(review): frame 0 stores pos[..., 1] (not pos[..., 0]), so the
        # cumsum in velo2pos reconstructs positions offset by pos[1] - pos[0];
        # confirm this is intended.
        velo[:, msk, 0] = pos[:, msk, 1]
        return velo

    def velo2pos(self, velo):
        """Invert pos2velo via a cumulative sum over the masked channels."""
        msk = [i - self.n_pad for i in self.msk]
        pos = velo.detach().clone().to(velo.device)
        # NOTE(review): this write is immediately overwritten by the cumsum
        # below, so begin_pos never affects the result -- confirm intended.
        pos[:, msk, 0] = self.begin_pos.to(velo.device)
        pos[:, msk] = torch.cumsum(velo[:, msk], dim=-1)
        return pos

    def motion2pos(self, motion):
        """Return `motion` with velocity channels converted back to positions;
        a pass-through when velocities are not in use."""
        if not self.use_velo:
            return motion
        # Bug fix: the original computed the conversion but never returned it,
        # so this branch always yielded None.
        return self.velo2pos(motion.clone())

    def sample(self, size=None, slerp=False, align_corners=False):
        """Return the raw motion, optionally resampled to `size` frames.

        With size=None returns {'motion': ..., 'extra': ...}; otherwise
        returns only the linearly interpolated motion tensor.
        """
        if size is None:
            return {'motion': self.raw_motion, 'extra': self.extra}
        if slerp:
            raise NotImplementedError('slerp is not implemented yet!!!')
        motion = F.interpolate(self.raw_motion, size=size, mode='linear', align_corners=align_corners)
        extra = {}
        if 'global_pos' in self.extra.keys():
            extra['global_pos'] = F.interpolate(self.extra['global_pos'], size=size,
                                                mode='linear', align_corners=align_corners)
        # NOTE(review): `extra` is computed but not returned (see the
        # commented-out dict return below); callers only receive the motion.
        return motion
        # return {'motion': motion, 'extra': extra}

    def parse(self, motion, keep_velo=False,):
        """
        Convert a motion tensor back into tracks JSON.
        No batch support here!!!
        :returns tracks_json
        """
        motion = motion.clone()
        if self.use_velo and not keep_velo:
            motion = self.velo2pos(motion)
        if self.n_pad:
            motion = motion[:, :-self.n_pad]
        if self.contact:
            raise NotImplementedError('contact is not implemented yet!!!')
        motion = motion.squeeze().permute(1, 0)  # (n_frames, n_channels)
        pos = motion[..., -3:] / self.scale  # undo the import-time scaling
        rot = motion[..., :-3].reshape(motion.shape[0], -1, self.n_rot)
        if self.repr == 'repr6d':
            rot = repr6d2quat(rot)
        elif self.repr == 'euler':
            raise NotImplementedError('parse "euler" is not implemented yet!!!')
        times = []
        out_tracks_json = copy.deepcopy(self.tracks.tracks_json)
        for i, _track in enumerate(out_tracks_json):
            if i == 0:
                # Assumes uniformly spaced keyframes: times[1] is the frame dt.
                times = [j * out_tracks_json[i]['times'][1] for j in range(motion.shape[0])]
                out_tracks_json[i]['values'] = pos.flatten().detach().cpu().numpy().tolist()
            else:
                out_tracks_json[i]['values'] = rot[:, i - 1, :].flatten().detach().cpu().numpy().tolist()
            out_tracks_json[i]['times'] = times
        return out_tracks_json