"""This script is based on the release codes: "HybrIK: A Hybrid Analytical-Neural Inverse Kinematics Solution for 3D Human Pose and Shape Estimation. CVPR 2021" (https://github.com/Jeff-sjtu/HybrIK). """ from __future__ import absolute_import, division, print_function import torch from detrsmpl.utils.transforms import aa_to_rotmat def batch_inverse_kinematics_transform(pose_skeleton, global_orient, phis, rest_pose, children, parents, dtype=torch.float32, train=False, leaf_thetas=None): """Applies inverse kinematics transform to joints in a batch. Args: pose_skeleton (torch.tensor): Locations of estimated pose skeleton with shape (Bx29x3) global_orient (torch.tensor|none): Tensor of global rotation matrices with shape (Bx1x3x3) phis (torch.tensor): Rotation on bone axis parameters with shape (Bx23x2) rest_pose (torch.tensor): Locations of rest (Template) pose with shape (Bx29x3) children (List[int]): list of indexes of kinematic children with len 29 parents (List[int]): list of indexes of kinematic parents with len 29 dtype (torch.dtype, optional): Data type of the created tensors. Default: torch.float32 train (bool): Store True in train mode. Default: False leaf_thetas (torch.tensor, optional): Rotation matrixes for 5 leaf joints (Bx5x3x3). Default: None Returns: rot_mats (torch.tensor): Rotation matrics of all joints with shape (Bx29x3x3) rotate_rest_pose (torch.tensor): Locations of rotated rest/ template pose with shape (Bx29x3) """ batch_size = pose_skeleton.shape[0] device = pose_skeleton.device rel_rest_pose = rest_pose.clone() # vec_t_k = t_k - t_pa(k) rel_rest_pose[:, 1:] -= rest_pose[:, parents[1:]].clone() rel_rest_pose = torch.unsqueeze(rel_rest_pose, dim=-1) # rotate the T pose rotate_rest_pose = torch.zeros_like(rel_rest_pose) # set up the root rotate_rest_pose[:, 0] = rel_rest_pose[:, 0] rel_pose_skeleton = torch.unsqueeze(pose_skeleton.clone(), dim=-1).detach() rel_pose_skeleton[:, 1:] -= rel_pose_skeleton[:, parents[1:]].clone() rel_pose_skeleton[:, 0] = rel_rest_pose[:, 0] # the predicted final pose final_pose_skeleton = torch.unsqueeze(pose_skeleton.clone(), dim=-1) if train: final_pose_skeleton[:, 1:] -= \ final_pose_skeleton[:, parents[1:]].clone() final_pose_skeleton[:, 0] = rel_rest_pose[:, 0] else: final_pose_skeleton += \ rel_rest_pose[:, 0:1] - final_pose_skeleton[:, 0:1] rel_rest_pose = rel_rest_pose rel_pose_skeleton = rel_pose_skeleton final_pose_skeleton = final_pose_skeleton rotate_rest_pose = rotate_rest_pose assert phis.dim() == 3 phis = phis / (torch.norm(phis, dim=2, keepdim=True) + 1e-8) if train: global_orient_mat = batch_get_pelvis_orient(rel_pose_skeleton.clone(), rel_rest_pose.clone(), parents, children, dtype) else: global_orient_mat = batch_get_pelvis_orient_svd( rel_pose_skeleton.clone(), rel_rest_pose.clone(), parents, children, dtype) rot_mat_chain = [global_orient_mat] rot_mat_local = [global_orient_mat] # leaf nodes rot_mats if leaf_thetas is not None: leaf_cnt = 0 leaf_rot_mats = leaf_thetas.view([batch_size, 5, 3, 3]) for i in range(1, parents.shape[0]): if children[i] == -1: # leaf nodes if leaf_thetas is not None: rot_mat = leaf_rot_mats[:, leaf_cnt, :, :] leaf_cnt += 1 rotate_rest_pose[:, i] = rotate_rest_pose[:, parents[ i]] + torch.matmul(rot_mat_chain[parents[i]], rel_rest_pose[:, i]) rot_mat_chain.append( torch.matmul(rot_mat_chain[parents[i]], rot_mat)) rot_mat_local.append(rot_mat) elif children[i] == -3: # three children rotate_rest_pose[:, i] = rotate_rest_pose[:, parents[i]] + \ torch.matmul(rot_mat_chain[parents[i]], rel_rest_pose[:, i]) spine_child = [] for c in range(1, parents.shape[0]): if parents[c] == i and c not in spine_child: spine_child.append(c) # original spine_child = [] for c in range(1, parents.shape[0]): if parents[c] == i and c not in spine_child: spine_child.append(c) children_final_loc = [] children_rest_loc = [] for c in spine_child: temp = final_pose_skeleton[:, c] - rotate_rest_pose[:, i] children_final_loc.append(temp) children_rest_loc.append(rel_rest_pose[:, c].clone()) rot_mat = batch_get_3children_orient_svd(children_final_loc, children_rest_loc, rot_mat_chain[parents[i]], spine_child, dtype) rot_mat_chain.append( torch.matmul(rot_mat_chain[parents[i]], rot_mat)) rot_mat_local.append(rot_mat) else: # Naive Hybrik if train: # i: the index of k-th joint child_rest_loc = rel_rest_pose[:, i] child_final_loc = final_pose_skeleton[:, i] # q_pa(k) = q_pa^2(k) + R_pa(k)(t_pa(k) - t_pa^2(k)) rotate_rest_pose[:, i] = rotate_rest_pose[:, parents[i]] + \ torch.matmul(rot_mat_chain[parents[i]], rel_rest_pose[:, i]) # Adaptive HybrIK if not train: # children[i]: the index of k-th joint child_rest_loc = rel_rest_pose[:, children[i]] child_final_loc = final_pose_skeleton[:, children[ i]] - rotate_rest_pose[:, i] orig_vec = rel_pose_skeleton[:, children[i]] template_vec = rel_rest_pose[:, children[i]] norm_t = torch.norm(template_vec, dim=1, keepdim=True) orig_vec = orig_vec * norm_t / torch.norm( orig_vec, dim=1, keepdim=True) diff = torch.norm(child_final_loc - orig_vec, dim=1, keepdim=True) big_diff_idx = torch.where(diff > 15 / 1000)[0] child_final_loc[big_diff_idx] = orig_vec[big_diff_idx] # train: vec_p_k = R_pa(k).T * (p_k - p_pa(k)) # test: vec_p_k = R_pa(k).T * (p_k - q_pa(k)) child_final_loc = torch.matmul( rot_mat_chain[parents[i]].transpose(1, 2), child_final_loc) # (B, 1, 1) child_final_norm = torch.norm(child_final_loc, dim=1, keepdim=True) child_rest_norm = torch.norm(child_rest_loc, dim=1, keepdim=True) # vec_n axis = torch.cross(child_rest_loc, child_final_loc, dim=1) axis_norm = torch.norm(axis, dim=1, keepdim=True) # (B, 1, 1) cos = torch.sum( child_rest_loc * child_final_loc, dim=1, keepdim=True) / (child_rest_norm * child_final_norm + 1e-8) sin = axis_norm / (child_rest_norm * child_final_norm + 1e-8) # (B, 3, 1) axis = axis / (axis_norm + 1e-8) # Convert location revolve to rot_mat by rodrigues # (B, 1, 1) rx, ry, rz = torch.split(axis, 1, dim=1) zeros = torch.zeros((batch_size, 1, 1), dtype=dtype, device=device) K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view((batch_size, 3, 3)) ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) rot_mat_loc = ident + sin * K + (1 - cos) * torch.bmm(K, K) # Convert spin to rot_mat # (B, 3, 1) spin_axis = child_rest_loc / child_rest_norm # (B, 1, 1) rx, ry, rz = torch.split(spin_axis, 1, dim=1) zeros = torch.zeros((batch_size, 1, 1), dtype=dtype, device=device) K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view((batch_size, 3, 3)) ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) # (B, 1, 1) cos, sin = torch.split(phis[:, i - 1], 1, dim=1) cos = torch.unsqueeze(cos, dim=2) sin = torch.unsqueeze(sin, dim=2) rot_mat_spin = ident + sin * K + (1 - cos) * torch.bmm(K, K) rot_mat = torch.matmul(rot_mat_loc, rot_mat_spin) rot_mat_chain.append( torch.matmul(rot_mat_chain[parents[i]], rot_mat)) rot_mat_local.append(rot_mat) # (B, K + 1, 3, 3) rot_mats = torch.stack(rot_mat_local, dim=1) return rot_mats, rotate_rest_pose.squeeze(-1) def batch_get_pelvis_orient_svd(rel_pose_skeleton, rel_rest_pose, parents, children, dtype): """Get pelvis orientation svd for batch data. Args: rel_pose_skeleton (torch.tensor): Locations of root-normalized pose skeleton with shape (Bx29x3) rel_rest_pose (torch.tensor): Locations of rest/ template pose with shape (Bx29x3) parents (List[int]): list of indexes of kinematic parents with len 29 children (List[int]): list of indexes of kinematic children with len 29 dtype (torch.dtype, optional): Data type of the created tensors, the default is torch.float32 Returns: rot_mat (torch.tensor): Rotation matrix of pelvis with shape (Bx3x3) """ pelvis_child = [int(children[0])] for i in range(1, parents.shape[0]): if parents[i] == 0 and i not in pelvis_child: pelvis_child.append(i) rest_mat = [] target_mat = [] for child in pelvis_child: rest_mat.append(rel_rest_pose[:, child].clone()) target_mat.append(rel_pose_skeleton[:, child].clone()) rest_mat = torch.cat(rest_mat, dim=2) target_mat = torch.cat(target_mat, dim=2) S = rest_mat.bmm(target_mat.transpose(1, 2)) mask_zero = S.sum(dim=(1, 2)) S_non_zero = S[mask_zero != 0].reshape(-1, 3, 3) U, _, V = torch.svd(S_non_zero) rot_mat = torch.zeros_like(S) rot_mat[mask_zero == 0] = torch.eye(3, device=S.device) rot_mat_non_zero = torch.bmm(V, U.transpose(1, 2)) rot_mat[mask_zero != 0] = rot_mat_non_zero assert torch.sum(torch.isnan(rot_mat)) == 0, ('rot_mat', rot_mat) return rot_mat def batch_get_pelvis_orient(rel_pose_skeleton, rel_rest_pose, parents, children, dtype): """Get pelvis orientation for batch data. Args: rel_pose_skeleton (torch.tensor): Locations of root-normalized pose skeleton with shape (Bx29x3) rel_rest_pose (torch.tensor): Locations of rest/ template pose with shape (Bx29x3) parents (List[int]): list of indexes of kinematic parents with len 29 children (List[int]): list of indexes of kinematic children with len 29 dtype (torch.dtype, optional): Data type of the created tensors, the default is torch.float32 Returns: rot_mat (torch.tensor): Rotation matrix of pelvis with shape (Bx3x3) """ batch_size = rel_pose_skeleton.shape[0] device = rel_pose_skeleton.device assert children[0] == 3 pelvis_child = [int(children[0])] for i in range(1, parents.shape[0]): if parents[i] == 0 and i not in pelvis_child: pelvis_child.append(i) spine_final_loc = rel_pose_skeleton[:, int(children[0])].clone() spine_rest_loc = rel_rest_pose[:, int(children[0])].clone() # spine_norm = torch.norm(spine_final_loc, dim=1, keepdim=True) # spine_norm = spine_final_loc / (spine_norm + 1e-8) # rot_mat_spine = vectors2rotmat(spine_rest_loc, spine_final_loc, dtype) # (B, 1, 1) vec_final_norm = torch.norm(spine_final_loc, dim=1, keepdim=True) vec_rest_norm = torch.norm(spine_rest_loc, dim=1, keepdim=True) spine_norm = spine_final_loc / (vec_final_norm + 1e-8) # (B, 3, 1) axis = torch.cross(spine_rest_loc, spine_final_loc, dim=1) axis_norm = torch.norm(axis, dim=1, keepdim=True) axis = axis / (axis_norm + 1e-8) angle = torch.arccos( torch.sum(spine_rest_loc * spine_final_loc, dim=1, keepdim=True) / (vec_rest_norm * vec_final_norm + 1e-8)) axis_angle = (angle * axis).squeeze() # aa to rotmat rot_mat_spine = aa_to_rotmat(axis_angle) assert torch.sum(torch.isnan(rot_mat_spine)) == 0, ('rot_mat_spine', rot_mat_spine) center_final_loc = 0 center_rest_loc = 0 for child in pelvis_child: if child == int(children[0]): continue center_final_loc = center_final_loc + rel_pose_skeleton[:, child].clone() center_rest_loc = center_rest_loc + rel_rest_pose[:, child].clone() center_final_loc = center_final_loc / (len(pelvis_child) - 1) center_rest_loc = center_rest_loc / (len(pelvis_child) - 1) center_rest_loc = torch.matmul(rot_mat_spine, center_rest_loc) center_final_loc = center_final_loc - torch.sum( center_final_loc * spine_norm, dim=1, keepdim=True) * spine_norm center_rest_loc = center_rest_loc - torch.sum( center_rest_loc * spine_norm, dim=1, keepdim=True) * spine_norm center_final_loc_norm = torch.norm(center_final_loc, dim=1, keepdim=True) center_rest_loc_norm = torch.norm(center_rest_loc, dim=1, keepdim=True) # (B, 3, 1) axis = torch.cross(center_rest_loc, center_final_loc, dim=1) axis_norm = torch.norm(axis, dim=1, keepdim=True) # (B, 1, 1) cos = torch.sum( center_rest_loc * center_final_loc, dim=1, keepdim=True) / (center_rest_loc_norm * center_final_loc_norm + 1e-8) sin = axis_norm / (center_rest_loc_norm * center_final_loc_norm + 1e-8) assert torch.sum(torch.isnan(cos)) == 0, ('cos', cos) assert torch.sum(torch.isnan(sin)) == 0, ('sin', sin) # (B, 3, 1) axis = axis / (axis_norm + 1e-8) # Convert location revolve to rot_mat by rodrigues # (B, 1, 1) rx, ry, rz = torch.split(axis, 1, dim=1) zeros = torch.zeros((batch_size, 1, 1), dtype=dtype, device=device) K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ .view((batch_size, 3, 3)) ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) rot_mat_center = ident + sin * K + (1 - cos) * torch.bmm(K, K) rot_mat = torch.matmul(rot_mat_center, rot_mat_spine) return rot_mat def batch_get_3children_orient_svd(rel_pose_skeleton, rel_rest_pose, rot_mat_chain_parent, children_list, dtype): """Get pelvis orientation for batch data. Args: rel_pose_skeleton (torch.tensor): Locations of root-normalized pose skeleton with shape (Bx29x3) rel_rest_pose (torch.tensor): Locations of rest/ template pose with shape (Bx29x3) rot_mat_chain_parents (torch.tensor): parent's rotation matrix with shape (Bx3x3) children (List[int]): list of indexes of kinematic children with len 29 dtype (torch.dtype, optional): Data type of the created tensors, the default is torch.float32 Returns: rot_mat (torch.tensor): Child's rotation matrix with shape (Bx3x3) """ rest_mat = [] target_mat = [] for c, child in enumerate(children_list): if isinstance(rel_pose_skeleton, list): target = rel_pose_skeleton[c].clone() template = rel_rest_pose[c].clone() else: target = rel_pose_skeleton[:, child].clone() template = rel_rest_pose[:, child].clone() target = torch.matmul(rot_mat_chain_parent.transpose(1, 2), target) target_mat.append(target) rest_mat.append(template) rest_mat = torch.cat(rest_mat, dim=2) target_mat = torch.cat(target_mat, dim=2) S = rest_mat.bmm(target_mat.transpose(1, 2)) U, _, V = torch.svd(S) rot_mat = torch.bmm(V, U.transpose(1, 2)) assert torch.sum(torch.isnan(rot_mat)) == 0, ('3children rot_mat', rot_mat) return rot_mat