|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional |
|
|
|
import torch |
|
from torch import Tensor |
|
from einops import rearrange |
|
|
|
from mGPT.utils.easyconvert import rep_to_rep, nfeats_of, to_matrix |
|
import mGPT.utils.geometry_tools as geometry_tools |
|
|
|
from .base import Rots2Rfeats |
|
|
|
|
|
class Globalvelandy(Rots2Rfeats):
    """Convert joint rotations + root translation into flat rotation features.

    Per-frame feature layout along the last dimension:
        [0]      root height (vertical component of the translation),
        [1:3]    planar (ground-plane) root velocity,
        [3:]     flattened per-joint pose representation
                 (njoints * nfeats, where nfeats depends on ``pose_rep``).

    ``inverse`` undoes the conversion up to the absolute planar position of
    the first frame (velocities are integrated from zero).
    """

    def __init__(self,
                 path: Optional[str] = None,
                 normalization: bool = False,
                 pose_rep: str = "rot6d",
                 canonicalize: bool = False,
                 offset: bool = True,
                 **kwargs) -> None:
        """Configure the feature converter.

        Args:
            path: Location of normalization statistics (forwarded to the
                ``Rots2Rfeats`` base class).
            normalization: Whether to normalize/unnormalize features
                (forwarded to the base class).
            pose_rep: Target rotation representation for the pose features
                (e.g. ``"rot6d"``); determines ``self.nfeats``.
            canonicalize: If True, rotate each sequence so the first-frame
                root orientation is canonical (vertical-axis rotation removed).
            offset: If True, add an extra pi/2 vertical-axis rotation during
                canonicalization.
            **kwargs: Ignored; accepted for config-driven construction.
        """
        super().__init__(path=path, normalization=normalization)

        self.canonicalize = canonicalize
        self.pose_rep = pose_rep
        # Number of scalar features per joint for the chosen representation.
        self.nfeats = nfeats_of(pose_rep)
        self.offset = offset

    def forward(self, data, data_rep='matrix', first_frame=None) -> Tensor:
        """Encode a rotation/translation datastruct into a feature tensor.

        Args:
            data: Object exposing ``.rots`` (per-joint rotations in
                ``data_rep``) and ``.trans`` (root translations).
            data_rep: Representation of ``data.rots`` on input.
            first_frame: Optional velocity to prepend for the first frame;
                defaults to zeros so the output keeps the frame count.

        Returns:
            Tensor of shape ``(..., frames, 1 + 2 + njoints * nfeats)``.
        """
        poses, trans = data.rots, data.trans

        # Z-up convention presumed: index 2 is height, [0, 1] is the ground
        # plane. NOTE(review): the variable is named root_y but reads the Z
        # component — confirm the up-axis convention against callers.
        root_y = trans[..., 2]
        trajectory = trans[..., [0, 1]]

        # Frame-to-frame planar velocity (one fewer frame than the input).
        vel_trajectory = torch.diff(trajectory, dim=-2)

        # Prepend a zero (or caller-supplied) first-frame velocity so the
        # feature sequence keeps the original number of frames.
        if first_frame is None:
            first_frame = 0 * vel_trajectory[..., [0], :]

        vel_trajectory = torch.cat((first_frame, vel_trajectory), dim=-2)

        if self.canonicalize:
            matrix_poses = rep_to_rep(data_rep, 'matrix', poses)
            global_orient = matrix_poses[..., 0, :, :]

            # Canonicalize using the first-frame root orientation.
            # NOTE(review): poses[0, 0, ...] assumes the leading dims are
            # (frames, joints, ...) with no batch dim — confirm with callers.
            rot2d = rep_to_rep(data_rep, 'rotvec', poses[0, 0, ...])

            # Keep only the rotation about the vertical axis.
            rot2d[..., :2] = 0

            if self.offset:
                # Extra quarter turn about the vertical axis.
                rot2d[..., 2] += torch.pi / 2

            rot2d = rep_to_rep('rotvec', 'matrix', rot2d)

            # Apply the (transposed) canonicalizing rotation to the root
            # orientation of every frame.
            global_orient = torch.einsum("...kj,...kl->...jl", rot2d,
                                         global_orient)

            matrix_poses = torch.cat(
                (global_orient[..., None, :, :], matrix_poses[..., 1:, :, :]),
                dim=-3)

            poses = rep_to_rep('matrix', data_rep, matrix_poses)

            # Rotate the planar velocities consistently with the new frame.
            vel_trajectory = torch.einsum("...kj,...lk->...lj",
                                          rot2d[..., :2, :2], vel_trajectory)

        poses = rep_to_rep(data_rep, self.pose_rep, poses)
        features = torch.cat(
            (root_y[..., None], vel_trajectory,
             rearrange(poses, "... joints rot -> ... (joints rot)")),
            dim=-1)
        features = self.normalize(features)

        return features

    def extract(self, features):
        """Split a feature tensor into (root_y, vel_trajectory, poses).

        Inverse of the concatenation in ``forward``; ``poses`` is un-flattened
        to ``(..., joints, self.nfeats)``.
        """
        root_y = features[..., 0]
        vel_trajectory = features[..., 1:3]
        poses_features = features[..., 3:]
        poses = rearrange(poses_features,
                          "... (joints rot) -> ... joints rot",
                          rot=self.nfeats)
        return root_y, vel_trajectory, poses

    def inverse(self, features, last_frame=None):
        """Decode features back into a rotation/translation datastruct.

        Args:
            features: Feature tensor as produced by ``forward``.
            last_frame: Accepted for interface compatibility but currently
                unused — the original implementation contained a dead
                ``if last_frame is None: pass`` branch with no effect on
                either path, removed here.

        Returns:
            ``RotTransDatastruct`` with matrix rotations and translations;
            the planar trajectory is anchored at the origin in frame 0.
        """
        features = self.unnormalize(features)
        root_y, vel_trajectory, poses = self.extract(features)

        # Integrate planar velocities, then anchor the trajectory so the
        # first frame sits at the origin (absolute position is not encoded).
        trajectory = torch.cumsum(vel_trajectory, dim=-2)
        trajectory = trajectory - trajectory[..., [0], :]

        trans = torch.cat([trajectory, root_y[..., None]], dim=-1)
        matrix_poses = rep_to_rep(self.pose_rep, 'matrix', poses)

        # Local import avoids a circular dependency with the smpl module.
        from ..smpl import RotTransDatastruct
        return RotTransDatastruct(rots=matrix_poses, trans=trans)
|
|