|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional |
|
|
|
import torch |
|
from einops import rearrange |
|
from torch import Tensor |
|
from .tools import get_forward_direction, get_floor, gaussian_filter1d |
|
from mGPT.utils.geometry_tools import matrix_of_angles |
|
from .base import Joints2Jfeats |
|
|
|
|
|
class Rifke(Joints2Jfeats):
    """Rotation-invariant joint features built from 3-D joint positions.

    ``forward`` maps joint positions laid out as ``(..., frames, joints, 3)``
    to per-frame feature vectors, and ``inverse`` maps them back. Joint 0 is
    treated as the root and coordinate index 1 as the vertical axis (only the
    x/z coordinates, indices 0 and 2, are translated and rotated). Each
    feature vector is laid out as::

        [root_y (1) | local joints (3 * (joints - 1)) | heading vel (1) | local root vel (2)]

    Non-root joints are stored relative to the root's horizontal position and
    rotated by the inverse of the per-frame heading. The heading and the
    horizontal root motion are kept only as frame-to-frame velocities, which
    ``inverse`` integrates back.
    """

    def __init__(self,
                 jointstype: str = "mmm",
                 path: Optional[str] = None,
                 normalization: bool = False,
                 forward_filter: bool = False,
                 **kwargs) -> None:
        """Configure the transform.

        Args:
            jointstype: Skeleton identifier passed to ``get_floor`` and
                ``get_forward_direction``.
            path: Forwarded to the ``Joints2Jfeats`` base class (presumably
                the location of normalization statistics -- TODO confirm).
            normalization: Forwarded to the base class; presumably toggles
                what ``self.normalize`` / ``self.unnormalize`` do.
            forward_filter: If True, smooth the forward direction with
                ``gaussian_filter1d`` before deriving headings.
            **kwargs: Accepted and ignored.
        """
        super().__init__(path=path, normalization=normalization)
        self.jointstype = jointstype
        self.forward_filter = forward_filter

    def forward(self, joints: Tensor) -> Tensor:
        """Convert joint positions into rotation-invariant features.

        Args:
            joints: Positions laid out as ``(..., frames, joints, 3)``. Joint 0
                is the root; coordinate 1 is vertical. The tensor is cloned and
                never modified in place.

        Returns:
            Tensor of shape ``(..., frames, 3 * (joints - 1) + 4)``, passed
            through ``self.normalize``.
        """
        poses = joints.clone()
        # Measure heights relative to the floor computed by get_floor.
        poses[..., 1] -= get_floor(poses, jointstype=self.jointstype)

        translation = poses[..., 0, :].clone()
        # The root height is stored directly ...
        root_y = translation[..., 1]
        # ... while its horizontal (x, z) motion becomes the trajectory.
        trajectory = translation[..., [0, 2]]

        # Drop the root and express the remaining joints relative to the
        # root's horizontal position.
        poses = poses[..., 1:, :]
        poses[..., [0, 2]] -= trajectory[..., None, :]

        # Frame-to-frame root displacement. Prepending the first frame makes
        # the first velocity exactly zero and keeps the frame count, including
        # for single-frame inputs (where torch.diff alone returns nothing).
        vel_trajectory = torch.diff(trajectory, dim=-2,
                                    prepend=trajectory[..., :1, :])

        forward = get_forward_direction(poses, jointstype=self.jointstype)
        if self.forward_filter:
            # Presumably smooths along time -- depends on gaussian_filter1d.
            forward = gaussian_filter1d(forward, 2)
        # Keep (sin, cos) a unit pair: filtering breaks unit length, and this
        # is a no-op for vectors that are already normalized.
        forward = torch.nn.functional.normalize(forward, dim=-1)

        # Index the last axis directly; a transpose(0, -1) round trip would
        # permute the batch axes for inputs with several leading dimensions.
        angles = torch.atan2(forward[..., 0], forward[..., 1])
        # NOTE: differences are not wrapped to (-pi, pi], so a heading that
        # crosses +/-pi produces a ~2*pi jump. inverse() only uses cos/sin of
        # the integrated angle, so reconstruction is unaffected.
        vel_angles = torch.diff(angles, dim=-1, prepend=angles[..., :1])

        # Per-frame inverse heading rotation built from the forward direction
        # (assumed to be (..., frames, 2, 2) from matrix_of_angles).
        sin, cos = forward[..., 0], forward[..., 1]
        rotations_inv = matrix_of_angles(cos, sin, inv=True)

        # Rotate only the horizontal components; the vertical one is kept.
        poses_local = torch.einsum("...lj,...jk->...lk", poses[..., [0, 2]],
                                   rotations_inv)
        poses_local = torch.stack(
            (poses_local[..., 0], poses[..., 1], poses_local[..., 1]), dim=-1)
        poses_features = rearrange(poses_local,
                                   "... joints xyz -> ... (joints xyz)")

        vel_trajectory_local = torch.einsum("...j,...jk->...k",
                                            vel_trajectory, rotations_inv)

        # This layout must stay in sync with extract().
        features = torch.cat((root_y[..., None], poses_features,
                              vel_angles[..., None], vel_trajectory_local), -1)
        return self.normalize(features)

    def inverse(self, features: Tensor) -> Tensor:
        """Reconstruct joint positions from features produced by ``forward``.

        The features hold no absolute heading or horizontal position. The
        reconstruction therefore starts with a zero heading angle and the root
        at the horizontal origin in the first frame.

        Args:
            features: Tensor of shape ``(..., frames, 3 * (joints - 1) + 4)``,
                normalized the way ``forward`` returns it.

        Returns:
            Joint positions of shape ``(..., frames, joints, 3)`` with the root
            at joint index 0.
        """
        features = self.unnormalize(features)
        root_y, poses_features, vel_angles, vel_trajectory_local = self.extract(
            features)

        # Integrate heading velocity. forward() emits a zero first velocity;
        # re-anchoring at frame 0 also covers features from other sources.
        angles = torch.cumsum(vel_angles, dim=-1)
        angles = angles - angles[..., [0]]

        cos, sin = torch.cos(angles), torch.sin(angles)
        rotations = matrix_of_angles(cos, sin, inv=False)

        poses_local = rearrange(poses_features,
                                "... (joints xyz) -> ... joints xyz",
                                xyz=3)

        # Undo the heading rotation on the horizontal (x, z) components only.
        poses = torch.einsum("...lj,...jk->...lk", poses_local[..., [0, 2]],
                             rotations)
        poses = torch.stack(
            (poses[..., 0], poses_local[..., 1], poses[..., 1]), dim=-1)

        # Rotate root velocities back to the global frame and integrate them,
        # anchoring the first frame at the origin.
        vel_trajectory = torch.einsum("...j,...jk->...k", vel_trajectory_local,
                                      rotations)
        trajectory = torch.cumsum(vel_trajectory, dim=-2)
        trajectory = trajectory - trajectory[..., [0], :]

        # Re-insert the root as joint 0 (zero offset), then restore its height.
        poses = torch.cat((torch.zeros_like(poses[..., :1, :]), poses), -2)
        poses[..., 0, 1] = root_y

        # Translate every joint, root included, by the horizontal trajectory.
        poses[..., [0, 2]] += trajectory[..., None, :]
        return poses

    def extract(self, features: Tensor):
        """Split feature vectors into the components written by ``forward``.

        Args:
            features: Tensor of shape ``(..., D)`` with ``D >= 4``.

        Returns:
            Tuple ``(root_y, poses_features, vel_angles, vel_trajectory_local)``
            with shapes ``(...)``, ``(..., D - 4)``, ``(...)`` and ``(..., 2)``.
        """
        root_y = features[..., 0]
        poses_features = features[..., 1:-3]
        vel_angles = features[..., -3]
        vel_trajectory_local = features[..., -2:]
        return root_y, poses_features, vel_angles, vel_trajectory_local
|
|