"""This script is based on the released code of:
"HybrIK: A Hybrid Analytical-Neural Inverse Kinematics Solution for 3D Human
Pose and Shape Estimation. CVPR 2021"
(https://github.com/Jeff-sjtu/HybrIK).
"""
from __future__ import absolute_import, division, print_function | |
import torch | |
from detrsmpl.utils.transforms import aa_to_rotmat | |
def _rodrigues_rot_mat(axis, cos, sin, batch_size, dtype, device):
    """Build rotation matrices with Rodrigues' formula.

    Computes ``I + sin * K + (1 - cos) * K @ K`` where ``K`` is the
    skew-symmetric cross-product matrix of ``axis``.

    Args:
        axis (torch.Tensor): Rotation axes with shape (B, 3, 1).
        cos (torch.Tensor): Cosine of the rotation angle, shape (B, 1, 1).
        sin (torch.Tensor): Sine of the rotation angle, shape (B, 1, 1).
        batch_size (int): Batch size B.
        dtype (torch.dtype): Data type of the tensors created here.
        device (torch.device): Device of the tensors created here.

    Returns:
        torch.Tensor: Rotation matrices with shape (B, 3, 3).
    """
    rx, ry, rz = torch.split(axis, 1, dim=1)
    zeros = torch.zeros((batch_size, 1, 1), dtype=dtype, device=device)
    # Row-major entries of the skew-symmetric matrix of `axis`.
    K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros],
                  dim=1).view((batch_size, 3, 3))
    ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
    return ident + sin * K + (1 - cos) * torch.bmm(K, K)


def batch_inverse_kinematics_transform(pose_skeleton,
                                       global_orient,
                                       phis,
                                       rest_pose,
                                       children,
                                       parents,
                                       dtype=torch.float32,
                                       train=False,
                                       leaf_thetas=None):
    """Applies inverse kinematics transform to joints in a batch.

    Args:
        pose_skeleton (torch.Tensor):
            Locations of estimated pose skeleton with shape (B, J, 3).
        global_orient (torch.Tensor | None):
            Accepted for interface compatibility only. It is never read:
            the root rotation is always re-estimated from the skeleton.
        phis (torch.Tensor):
            Rotation on bone axis parameters, a 3-dim tensor whose last
            dimension holds (cos, sin). Row ``i - 1`` is used for joint
            ``i``. The pairs are L2-normalised here.
        rest_pose (torch.Tensor):
            Locations of rest (template) pose with shape (B, J, 3).
        children (torch.Tensor):
            Per-joint kinematic child index. ``-1`` marks a leaf joint and
            ``-3`` a joint with three children.
        parents (torch.Tensor):
            Per-joint kinematic parent index. Must expose ``.shape[0]``
            (a plain list does not work).
        dtype (torch.dtype, optional):
            Data type of the created tensors. Default: torch.float32
        train (bool):
            If True, the root rotation comes from
            ``batch_get_pelvis_orient`` and every bone is solved from the
            predicted parent-relative offsets. If False, the SVD variant
            is used and bones are solved against the already rotated
            template, with an outlier fallback. Default: False
        leaf_thetas (torch.Tensor, optional):
            Rotation matrices for the 5 leaf joints, viewable as
            (B, 5, 3, 3). When None, leaf joints get no rotation.
            Default: None

    Returns:
        rot_mats (torch.Tensor):
            Local rotation matrices stacked on dim 1, shape (B, K, 3, 3).
            K counts the root plus every joint that received a rotation
            (leaf joints only when ``leaf_thetas`` is given).
        rotate_rest_pose (torch.Tensor):
            Locations of the rotated rest/template pose, shape (B, J, 3).
            Entries of skipped leaf joints stay zero.
    """
    batch_size = pose_skeleton.shape[0]
    device = pose_skeleton.device
    num_joints = parents.shape[0]
    # Threshold on the gap between the predicted and the rescaled bone
    # vector. Presumably 15 mm with coordinates in metres - TODO confirm.
    big_diff_threshold = 15 / 1000

    # vec_t_k = t_k - t_pa(k)
    rel_rest_pose = rest_pose.clone()
    rel_rest_pose[:, 1:] -= rest_pose[:, parents[1:]].clone()
    rel_rest_pose = torch.unsqueeze(rel_rest_pose, dim=-1)

    # rotate the T pose; the root is copied unchanged
    rotate_rest_pose = torch.zeros_like(rel_rest_pose)
    rotate_rest_pose[:, 0] = rel_rest_pose[:, 0]

    # Parent-relative predicted offsets, detached from the autograd graph.
    rel_pose_skeleton = torch.unsqueeze(pose_skeleton.clone(),
                                        dim=-1).detach()
    rel_pose_skeleton[:, 1:] -= rel_pose_skeleton[:, parents[1:]].clone()
    rel_pose_skeleton[:, 0] = rel_rest_pose[:, 0]

    # the predicted final pose
    final_pose_skeleton = torch.unsqueeze(pose_skeleton.clone(), dim=-1)
    if train:
        # parent-relative offsets, root replaced by the template root
        final_pose_skeleton[:, 1:] -= \
            final_pose_skeleton[:, parents[1:]].clone()
        final_pose_skeleton[:, 0] = rel_rest_pose[:, 0]
    else:
        # absolute positions, shifted so the root matches the template root
        final_pose_skeleton += \
            rel_rest_pose[:, 0:1] - final_pose_skeleton[:, 0:1]

    assert phis.dim() == 3
    phis = phis / (torch.norm(phis, dim=2, keepdim=True) + 1e-8)

    if train:
        global_orient_mat = batch_get_pelvis_orient(rel_pose_skeleton.clone(),
                                                    rel_rest_pose.clone(),
                                                    parents, children, dtype)
    else:
        global_orient_mat = batch_get_pelvis_orient_svd(
            rel_pose_skeleton.clone(), rel_rest_pose.clone(), parents,
            children, dtype)

    # NOTE: rot_mat_chain is indexed by joint index below, so any leaf joint
    # skipped because leaf_thetas is None must come after all other joints.
    rot_mat_chain = [global_orient_mat]
    rot_mat_local = [global_orient_mat]

    # leaf nodes rot_mats
    if leaf_thetas is not None:
        leaf_cnt = 0
        leaf_rot_mats = leaf_thetas.view([batch_size, 5, 3, 3])

    for i in range(1, num_joints):
        parent = parents[i]
        if children[i] == -1:
            # leaf nodes: rotation is supplied, not solved
            if leaf_thetas is not None:
                rot_mat = leaf_rot_mats[:, leaf_cnt, :, :]
                leaf_cnt += 1

                rotate_rest_pose[:, i] = rotate_rest_pose[:, parent] + \
                    torch.matmul(rot_mat_chain[parent], rel_rest_pose[:, i])

                rot_mat_chain.append(
                    torch.matmul(rot_mat_chain[parent], rot_mat))
                rot_mat_local.append(rot_mat)
        elif children[i] == -3:
            # three children: solve the rotation from all of them via SVD
            rotate_rest_pose[:, i] = rotate_rest_pose[:, parent] + \
                torch.matmul(rot_mat_chain[parent], rel_rest_pose[:, i])

            spine_child = [
                c for c in range(1, num_joints) if parents[c] == i
            ]

            children_final_loc = []
            children_rest_loc = []
            for c in spine_child:
                children_final_loc.append(final_pose_skeleton[:, c] -
                                          rotate_rest_pose[:, i])
                children_rest_loc.append(rel_rest_pose[:, c].clone())

            rot_mat = batch_get_3children_orient_svd(children_final_loc,
                                                     children_rest_loc,
                                                     rot_mat_chain[parent],
                                                     spine_child, dtype)

            rot_mat_chain.append(torch.matmul(rot_mat_chain[parent], rot_mat))
            rot_mat_local.append(rot_mat)
        else:
            # q_k = q_pa(k) + R_pa(k) * (t_k - t_pa(k))
            rotate_rest_pose[:, i] = rotate_rest_pose[:, parent] + \
                torch.matmul(rot_mat_chain[parent], rel_rest_pose[:, i])

            if train:
                # Naive HybrIK: bone vectors are taken at joint i itself
                child_rest_loc = rel_rest_pose[:, i]
                child_final_loc = final_pose_skeleton[:, i]
            else:
                # Adaptive HybrIK: bone from joint i to its child
                child = children[i]
                child_rest_loc = rel_rest_pose[:, child]
                child_final_loc = final_pose_skeleton[:, child] - \
                    rotate_rest_pose[:, i]

                # Predicted bone direction rescaled to the template length.
                orig_vec = rel_pose_skeleton[:, child]
                template_vec = rel_rest_pose[:, child]
                norm_t = torch.norm(template_vec, dim=1, keepdim=True)
                orig_vec = orig_vec * norm_t / torch.norm(
                    orig_vec, dim=1, keepdim=True)

                # Fall back to the rescaled prediction for samples where
                # the accumulated target drifted too far from it.
                diff = torch.norm(child_final_loc - orig_vec,
                                  dim=1,
                                  keepdim=True)
                big_diff_idx = torch.where(diff > big_diff_threshold)[0]
                child_final_loc[big_diff_idx] = orig_vec[big_diff_idx]

            # train: vec_p_k = R_pa(k).T * (p_k - p_pa(k))
            # test: vec_p_k = R_pa(k).T * (p_k - q_pa(k))
            child_final_loc = torch.matmul(
                rot_mat_chain[parent].transpose(1, 2), child_final_loc)

            # (B, 1, 1)
            child_final_norm = torch.norm(child_final_loc,
                                          dim=1,
                                          keepdim=True)
            child_rest_norm = torch.norm(child_rest_loc, dim=1, keepdim=True)

            # vec_n: swing axis, perpendicular to both bone vectors
            axis = torch.cross(child_rest_loc, child_final_loc, dim=1)
            axis_norm = torch.norm(axis, dim=1, keepdim=True)

            # (B, 1, 1)
            norm_prod = child_rest_norm * child_final_norm + 1e-8
            cos = torch.sum(child_rest_loc * child_final_loc,
                            dim=1,
                            keepdim=True) / norm_prod
            sin = axis_norm / norm_prod
            # (B, 3, 1)
            axis = axis / (axis_norm + 1e-8)

            # Swing: rotate the template bone onto the target direction.
            rot_mat_loc = _rodrigues_rot_mat(axis, cos, sin, batch_size,
                                             dtype, device)

            # Twist: spin around the template bone by the predicted phi.
            # NOTE: no epsilon here, a zero-length template bone gives NaN.
            spin_axis = child_rest_loc / child_rest_norm
            # (B, 1, 1)
            cos, sin = torch.split(phis[:, i - 1], 1, dim=1)
            cos = torch.unsqueeze(cos, dim=2)
            sin = torch.unsqueeze(sin, dim=2)
            rot_mat_spin = _rodrigues_rot_mat(spin_axis, cos, sin,
                                              batch_size, dtype, device)

            rot_mat = torch.matmul(rot_mat_loc, rot_mat_spin)

            rot_mat_chain.append(torch.matmul(rot_mat_chain[parent], rot_mat))
            rot_mat_local.append(rot_mat)

    # (B, K, 3, 3)
    rot_mats = torch.stack(rot_mat_local, dim=1)

    return rot_mats, rotate_rest_pose.squeeze(-1)
def batch_get_pelvis_orient_svd(rel_pose_skeleton, rel_rest_pose, parents,
                                children, dtype):
    """Get pelvis orientation via SVD for batch data.

    Args:
        rel_pose_skeleton (torch.Tensor):
            Parent-relative locations of the pose skeleton with shape
            (B, J, 3, 1).
        rel_rest_pose (torch.Tensor):
            Parent-relative locations of the rest/template pose with shape
            (B, J, 3, 1).
        parents (torch.Tensor):
            Per-joint kinematic parent index; must expose ``.shape[0]``.
        children (torch.Tensor):
            Per-joint kinematic child index; only ``children[0]`` is read.
        dtype (torch.dtype):
            Unused, kept for interface compatibility. The result follows
            the dtype of the inputs.

    Returns:
        rot_mat (torch.Tensor):
            Rotation matrix of pelvis with shape (B, 3, 3). Samples whose
            correlation matrix sums to zero get the identity.
    """
    # children[0] first, then every other joint whose parent is the root.
    first_child = int(children[0])
    pelvis_child = [first_child] + [
        i for i in range(1, parents.shape[0])
        if parents[i] == 0 and i != first_child
    ]

    # Stack the child bone vectors as columns: (B, 3, num_children).
    rest_mat = torch.cat([rel_rest_pose[:, c] for c in pelvis_child], dim=2)
    target_mat = torch.cat([rel_pose_skeleton[:, c] for c in pelvis_child],
                           dim=2)

    # Correlation matrix between template and target bone vectors.
    S = rest_mat.bmm(target_mat.transpose(1, 2))

    # Samples with a zero-sum S (e.g. all-zero inputs) skip the SVD.
    # NOTE(review): this is a sum, not an abs-sum, so a non-zero S whose
    # entries cancel exactly is also treated as empty - confirm intent.
    non_zero = S.sum(dim=(1, 2)) != 0

    # torch.svd returns V (not V^H), so V @ U^T is the fitted matrix.
    U, _, V = torch.svd(S[non_zero].reshape(-1, 3, 3))

    rot_mat = torch.zeros_like(S)
    # Match S.dtype so the masked assignment also works for non-float32.
    rot_mat[~non_zero] = torch.eye(3, dtype=S.dtype, device=S.device)
    rot_mat[non_zero] = torch.bmm(V, U.transpose(1, 2))

    assert torch.sum(torch.isnan(rot_mat)) == 0, ('rot_mat', rot_mat)

    return rot_mat
def batch_get_pelvis_orient(rel_pose_skeleton, rel_rest_pose, parents,
                            children, dtype):
    """Get pelvis orientation for batch data.

    The rotation is built in two steps: first the template spine bone
    (``children[0]``) is rotated onto the target spine bone, then a second
    rotation aligns the mean of the remaining pelvis children, after both
    means are projected onto the plane orthogonal to the target spine.

    Args:
        rel_pose_skeleton (torch.Tensor):
            Parent-relative locations of the pose skeleton with shape
            (B, J, 3, 1).
        rel_rest_pose (torch.Tensor):
            Parent-relative locations of the rest/template pose with shape
            (B, J, 3, 1).
        parents (torch.Tensor):
            Per-joint kinematic parent index; must expose ``.shape[0]``.
        children (torch.Tensor):
            Per-joint kinematic child index; ``children[0]`` must be 3.
        dtype (torch.dtype):
            Data type of the tensors created here.

    Returns:
        rot_mat (torch.Tensor):
            Rotation matrix of pelvis with shape (B, 3, 3)
    """
    batch_size = rel_pose_skeleton.shape[0]
    device = rel_pose_skeleton.device

    assert children[0] == 3
    spine_idx = int(children[0])
    pelvis_child = [spine_idx]
    for i in range(1, parents.shape[0]):
        if parents[i] == 0 and i not in pelvis_child:
            pelvis_child.append(i)

    spine_final_loc = rel_pose_skeleton[:, spine_idx].clone()
    spine_rest_loc = rel_rest_pose[:, spine_idx].clone()

    # (B, 1, 1)
    vec_final_norm = torch.norm(spine_final_loc, dim=1, keepdim=True)
    vec_rest_norm = torch.norm(spine_rest_loc, dim=1, keepdim=True)

    # unit direction of the target spine bone
    spine_norm = spine_final_loc / (vec_final_norm + 1e-8)

    # (B, 3, 1)
    axis = torch.cross(spine_rest_loc, spine_final_loc, dim=1)
    axis_norm = torch.norm(axis, dim=1, keepdim=True)
    axis = axis / (axis_norm + 1e-8)

    cos_angle = torch.sum(
        spine_rest_loc * spine_final_loc, dim=1,
        keepdim=True) / (vec_rest_norm * vec_final_norm + 1e-8)
    # Clamp so rounding error on (anti)parallel bones cannot push the
    # argument outside [-1, 1], where arccos returns NaN.
    angle = torch.arccos(torch.clamp(cos_angle, min=-1.0, max=1.0))

    # Drop only the trailing singleton; a bare squeeze() would also drop
    # the batch dimension when B == 1.
    axis_angle = (angle * axis).squeeze(-1)

    # aa to rotmat
    # assumes aa_to_rotmat maps (B, 3) to (B, 3, 3) - TODO confirm
    rot_mat_spine = aa_to_rotmat(axis_angle)
    assert torch.sum(torch.isnan(rot_mat_spine)) == 0, ('rot_mat_spine',
                                                        rot_mat_spine)

    # Mean offset of the pelvis children other than the spine.
    # NOTE: raises ZeroDivisionError if the spine is the only pelvis child.
    center_final_loc = 0
    center_rest_loc = 0
    for child in pelvis_child:
        if child == spine_idx:
            continue
        center_final_loc = center_final_loc + \
            rel_pose_skeleton[:, child].clone()
        center_rest_loc = center_rest_loc + rel_rest_pose[:, child].clone()
    center_final_loc = center_final_loc / (len(pelvis_child) - 1)
    center_rest_loc = center_rest_loc / (len(pelvis_child) - 1)

    # Apply the spine alignment to the template centre first.
    center_rest_loc = torch.matmul(rot_mat_spine, center_rest_loc)

    # Remove the component along the spine from both centres.
    center_final_loc = center_final_loc - torch.sum(
        center_final_loc * spine_norm, dim=1, keepdim=True) * spine_norm
    center_rest_loc = center_rest_loc - torch.sum(
        center_rest_loc * spine_norm, dim=1, keepdim=True) * spine_norm

    center_final_loc_norm = torch.norm(center_final_loc, dim=1, keepdim=True)
    center_rest_loc_norm = torch.norm(center_rest_loc, dim=1, keepdim=True)

    # (B, 3, 1)
    axis = torch.cross(center_rest_loc, center_final_loc, dim=1)
    axis_norm = torch.norm(axis, dim=1, keepdim=True)

    # (B, 1, 1)
    norm_prod = center_rest_loc_norm * center_final_loc_norm + 1e-8
    cos = torch.sum(center_rest_loc * center_final_loc, dim=1,
                    keepdim=True) / norm_prod
    sin = axis_norm / norm_prod

    assert torch.sum(torch.isnan(cos)) == 0, ('cos', cos)
    assert torch.sum(torch.isnan(sin)) == 0, ('sin', sin)
    # (B, 3, 1)
    axis = axis / (axis_norm + 1e-8)

    # Convert location revolve to rot_mat by rodrigues
    # (B, 1, 1)
    rx, ry, rz = torch.split(axis, 1, dim=1)
    zeros = torch.zeros((batch_size, 1, 1), dtype=dtype, device=device)

    K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \
        .view((batch_size, 3, 3))
    ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
    rot_mat_center = ident + sin * K + (1 - cos) * torch.bmm(K, K)

    # Spine alignment first, then the in-plane correction.
    rot_mat = torch.matmul(rot_mat_center, rot_mat_spine)

    return rot_mat
def batch_get_3children_orient_svd(rel_pose_skeleton, rel_rest_pose,
                                   rot_mat_chain_parent, children_list, dtype):
    """Get the rotation of a joint with several children via SVD.

    Args:
        rel_pose_skeleton (list[torch.Tensor] | torch.Tensor):
            Target child offsets. Either a list with one (B, 3, 1) tensor
            per entry of ``children_list`` (same order), or a tensor that
            is indexed as ``[:, child]``.
        rel_rest_pose (list[torch.Tensor] | torch.Tensor):
            Template child offsets, in the same layout as
            ``rel_pose_skeleton``.
        rot_mat_chain_parent (torch.Tensor):
            Accumulated rotation of the parent with shape (B, 3, 3). Its
            transpose is applied to every target offset.
        children_list (list[int]):
            Joint indices of the children.
        dtype (torch.dtype):
            Unused, kept for interface compatibility.

    Returns:
        rot_mat (torch.Tensor):
            Rotation matrix of the joint with shape (B, 3, 3)
    """
    given_as_lists = isinstance(rel_pose_skeleton, list)
    parent_rot_t = rot_mat_chain_parent.transpose(1, 2)

    templates = []
    targets = []
    for position, joint in enumerate(children_list):
        if given_as_lists:
            # list layout: entries are aligned with children_list
            observed = rel_pose_skeleton[position].clone()
            reference = rel_rest_pose[position].clone()
        else:
            # tensor layout: pick the joint on dim 1
            observed = rel_pose_skeleton[:, joint].clone()
            reference = rel_rest_pose[:, joint].clone()

        # express the target offset in the parent's frame
        targets.append(torch.matmul(parent_rot_t, observed))
        templates.append(reference)

    # columns are the child offsets: (B, 3, num_children)
    template_cols = torch.cat(templates, dim=2)
    target_cols = torch.cat(targets, dim=2)

    correlation = template_cols.bmm(target_cols.transpose(1, 2))
    left, _, right = torch.svd(correlation)
    rot_mat = torch.bmm(right, left.transpose(1, 2))

    assert torch.sum(torch.isnan(rot_mat)) == 0, ('3children rot_mat', rot_mat)
    return rot_mat