Spaces:
Sleeping
Sleeping
import numpy as np | |
import torch | |
from smplx import MANO as _MANO | |
from smplx import MANOLayer as _MANOLayer | |
from detrsmpl.core.conventions.keypoints_mapping import ( | |
convert_kps, | |
get_keypoint_num, | |
) | |
class MANO(_MANO):
    """Extension of the official MANO implementation.

    On top of the parent model this wrapper:

    * completes the 16 regressed joints with 5 fingertip positions read
      directly from mesh vertices, giving a 21-keypoint hand, and
    * converts those keypoints from ``keypoint_src`` to ``keypoint_dst``
      with ``convert_kps`` and returns everything as a plain ``dict``.
    """

    full_pose_keys = {'global_orient', 'hand_pose'}

    # NOTE(review): the MANO template is commonly described as having 778
    # vertices and 1538 faces; these constants are not read anywhere in this
    # class (``get_num_verts``/``get_num_faces`` are used instead) -- confirm
    # the values before relying on them.
    NUM_VERTS = 776
    NUM_FACES = 9976

    # Index into the parent's regressed joints -> index in the 21-keypoint
    # layout assembled by ``get_keypoints_from_mesh``.
    KpId2manokps = {
        # Wrist
        0: 0,
        # Index
        1: 5,
        2: 6,
        3: 7,
        # Middle
        4: 9,
        5: 10,
        6: 11,
        # Pinky
        7: 17,
        8: 18,
        9: 19,
        # Ring
        10: 13,
        11: 14,
        12: 15,
        # Thumb
        13: 1,
        14: 2,
        15: 3,
    }

    # Index in the 21-keypoint layout -> mesh vertex used for that keypoint
    # (the slots that ``KpId2manokps`` leaves empty).
    kpId2vertices = {
        4: 744,  # Thumb
        8: 320,  # Index
        12: 443,  # Middle
        16: 555,  # Ring
        20: 672,  # Pinky
    }

    def __init__(self,
                 *args,
                 keypoint_src: str = 'mano',
                 keypoint_dst: str = 'human_data',
                 keypoint_approximate: bool = False,
                 **kwargs):
        """
        Args:
            *args: extra arguments for MANO initialization.
            keypoint_src: source convention of keypoints. This convention
                is used for keypoints obtained from joint regressors.
                Keypoints then undergo conversion into keypoint_dst
                convention.
            keypoint_dst: destination convention of keypoints. This
                convention is used for keypoints in the output.
            keypoint_approximate: whether to use approximate matching in
                convention conversion for keypoints.
            **kwargs: extra keyword arguments for MANO initialization.

        Returns:
            None
        """
        super().__init__(*args, **kwargs)
        self.keypoint_src = keypoint_src
        self.keypoint_dst = keypoint_dst
        self.keypoint_approximate = keypoint_approximate

        self.num_verts = self.get_num_verts()
        self.num_faces = self.get_num_faces()
        # Joint count is dictated by the *output* convention.
        self.num_joints = get_keypoint_num(convention=self.keypoint_dst)

    def forward(self,
                *args,
                return_verts: bool = True,
                return_full_pose: bool = False,
                **kwargs) -> dict:
        """Forward function.

        Args:
            *args: extra arguments for MANO
            return_verts: whether to return vertices
            return_full_pose: whether to return full pose parameters
            **kwargs: extra arguments for MANO. ``right_hand_pose``, when
                given, is used as ``hand_pose``.

        Returns:
            output: contains output parameters and attributes
                (``global_orient``, ``hand_pose``, ``joints``,
                ``joint_mask``, ``keypoints``, ``betas`` and optionally
                ``vertices`` / ``full_pose``).
        """
        if 'right_hand_pose' in kwargs:
            kwargs['hand_pose'] = kwargs['right_hand_pose']

        # ``return_full_pose`` is captured by this signature, so it must be
        # forwarded explicitly: smplx only fills ``full_pose`` when the flag
        # is set, otherwise ``output['full_pose']`` below would be None.
        # ``return_verts`` is deliberately NOT forwarded -- the vertices are
        # always needed to place the fingertip keypoints.
        mano_output = super().forward(*args,
                                      return_full_pose=return_full_pose,
                                      **kwargs)

        joints = self.get_keypoints_from_mesh(mano_output.vertices,
                                              mano_output.joints)
        joints, joint_mask = convert_kps(
            joints,
            src=self.keypoint_src,
            dst=self.keypoint_dst,
            approximate=self.keypoint_approximate)
        if isinstance(joint_mask, np.ndarray):
            joint_mask = torch.tensor(joint_mask,
                                      dtype=torch.uint8,
                                      device=joints.device)

        # Share the single mask across the whole batch.
        batch_size = joints.shape[0]
        joint_mask = joint_mask.reshape(1, -1).expand(batch_size, -1)

        output = dict(
            global_orient=mano_output.global_orient,
            hand_pose=mano_output.hand_pose,
            joints=joints,
            joint_mask=joint_mask,
            # Coordinates with the mask appended as an extra last channel.
            keypoints=torch.cat([joints, joint_mask[:, :, None]], dim=-1),
            betas=mano_output.betas,
        )
        if return_verts:
            output['vertices'] = mano_output.vertices
        if return_full_pose:
            output['full_pose'] = mano_output.full_pose
        return output

    def get_keypoints_from_mesh(self, mesh_vertices, keypoints_regressed):
        """Assembles the full 21 keypoint set from the 16 Mano Keypoints and
        5 mesh vertices for the fingers.

        Args:
            mesh_vertices: tensor indexed as ``[batch, vertex, xyz]``.
            keypoints_regressed: tensor indexed as ``[batch, joint, xyz]``
                holding the joints regressed by the parent model.

        Returns:
            Tensor of shape ``(batch, 21, 3)`` created on the same device
            and with the same dtype as ``keypoints_regressed``.
        """
        batch_size = keypoints_regressed.shape[0]
        # Allocate next to the inputs instead of hard-coding ``.cuda()`` so
        # the model also runs on CPU and on non-default GPUs.
        keypoints = keypoints_regressed.new_zeros((batch_size, 21, 3))

        # fill keypoints which are regressed
        for mano_id, kp_id in self.KpId2manokps.items():
            keypoints[:, kp_id, :] = keypoints_regressed[:, mano_id, :]

        # get other keypoints from mesh
        for kp_id, vertex_id in self.kpId2vertices.items():
            keypoints[:, kp_id, :] = mesh_vertices[:, vertex_id, :]

        return keypoints
class MANOLayer(_MANOLayer):
    """Extension of the official MANOLayer implementation.

    On top of the parent layer this wrapper:

    * completes the 16 regressed joints with 5 fingertip positions read
      directly from mesh vertices, giving a 21-keypoint hand, and
    * converts those keypoints from ``keypoint_src`` to ``keypoint_dst``
      with ``convert_kps`` and returns everything as a plain ``dict``.
    """

    full_pose_keys = {'global_orient', 'hand_pose'}

    # NOTE(review): the MANO template is commonly described as having 778
    # vertices and 1538 faces; these constants are not read anywhere in this
    # class (``get_num_verts``/``get_num_faces`` are used instead) -- confirm
    # the values before relying on them.
    NUM_VERTS = 776
    NUM_FACES = 9976

    # Index into the parent's regressed joints -> index in the 21-keypoint
    # layout assembled by ``get_keypoints_from_mesh``.
    KpId2manokps = {
        # Wrist
        0: 0,
        # Index
        1: 5,
        2: 6,
        3: 7,
        # Middle
        4: 9,
        5: 10,
        6: 11,
        # Pinky
        7: 17,
        8: 18,
        9: 19,
        # Ring
        10: 13,
        11: 14,
        12: 15,
        # Thumb
        13: 1,
        14: 2,
        15: 3,
    }

    # Index in the 21-keypoint layout -> mesh vertex used for that keypoint
    # (the slots that ``KpId2manokps`` leaves empty).
    kpId2vertices = {
        4: 744,  # Thumb
        8: 320,  # Index
        12: 443,  # Middle
        16: 555,  # Ring
        20: 672,  # Pinky
    }

    def __init__(self,
                 *args,
                 keypoint_src: str = 'mano',
                 keypoint_dst: str = 'human_data',
                 keypoint_approximate: bool = False,
                 **kwargs):
        """
        Args:
            *args: extra arguments for MANO initialization.
            keypoint_src: source convention of keypoints. This convention
                is used for keypoints obtained from joint regressors.
                Keypoints then undergo conversion into keypoint_dst
                convention.
            keypoint_dst: destination convention of keypoints. This
                convention is used for keypoints in the output.
            keypoint_approximate: whether to use approximate matching in
                convention conversion for keypoints.
            **kwargs: extra keyword arguments for MANO initialization.

        Returns:
            None
        """
        super().__init__(*args, **kwargs)
        self.keypoint_src = keypoint_src
        self.keypoint_dst = keypoint_dst
        self.keypoint_approximate = keypoint_approximate

        self.num_verts = self.get_num_verts()
        self.num_faces = self.get_num_faces()
        # Joint count is dictated by the *output* convention.
        self.num_joints = get_keypoint_num(convention=self.keypoint_dst)

    def forward(self,
                *args,
                return_verts: bool = True,
                return_full_pose: bool = False,
                **kwargs) -> dict:
        """Forward function.

        Args:
            *args: extra arguments for MANO
            return_verts: whether to return vertices
            return_full_pose: whether to return full pose parameters
            **kwargs: extra arguments for MANO. ``right_hand_pose``, when
                given, is used as ``hand_pose``.

        Returns:
            output: contains output parameters and attributes
                (``global_orient``, ``hand_pose``, ``joints``,
                ``joint_mask``, ``keypoints``, ``betas`` and optionally
                ``vertices`` / ``full_pose``).
        """
        if 'right_hand_pose' in kwargs:
            kwargs['hand_pose'] = kwargs['right_hand_pose']

        # ``return_full_pose`` is captured by this signature, so it must be
        # forwarded explicitly: smplx only fills ``full_pose`` when the flag
        # is set, otherwise ``output['full_pose']`` below would be None.
        # ``return_verts`` is deliberately NOT forwarded -- the vertices are
        # always needed to place the fingertip keypoints.
        mano_output = super().forward(*args,
                                      return_full_pose=return_full_pose,
                                      **kwargs)

        joints = self.get_keypoints_from_mesh(mano_output.vertices,
                                              mano_output.joints)
        joints, joint_mask = convert_kps(
            joints,
            src=self.keypoint_src,
            dst=self.keypoint_dst,
            approximate=self.keypoint_approximate)
        if isinstance(joint_mask, np.ndarray):
            joint_mask = torch.tensor(joint_mask,
                                      dtype=torch.uint8,
                                      device=joints.device)

        # Share the single mask across the whole batch.
        batch_size = joints.shape[0]
        joint_mask = joint_mask.reshape(1, -1).expand(batch_size, -1)

        output = dict(
            global_orient=mano_output.global_orient,
            hand_pose=mano_output.hand_pose,
            joints=joints,
            joint_mask=joint_mask,
            # Coordinates with the mask appended as an extra last channel.
            keypoints=torch.cat([joints, joint_mask[:, :, None]], dim=-1),
            betas=mano_output.betas,
        )
        if return_verts:
            output['vertices'] = mano_output.vertices
        if return_full_pose:
            output['full_pose'] = mano_output.full_pose
        return output

    def get_keypoints_from_mesh(self, mesh_vertices, keypoints_regressed):
        """Assembles the full 21 keypoint set from the 16 Mano Keypoints and
        5 mesh vertices for the fingers.

        Args:
            mesh_vertices: tensor indexed as ``[batch, vertex, xyz]``.
            keypoints_regressed: tensor indexed as ``[batch, joint, xyz]``
                holding the joints regressed by the parent layer.

        Returns:
            Tensor of shape ``(batch, 21, 3)`` created on the same device
            and with the same dtype as ``keypoints_regressed``.
        """
        batch_size = keypoints_regressed.shape[0]
        # Allocate next to the inputs instead of hard-coding ``.cuda()`` so
        # the layer also runs on CPU and on non-default GPUs.
        keypoints = keypoints_regressed.new_zeros((batch_size, 21, 3))

        # fill keypoints which are regressed
        for mano_id, kp_id in self.KpId2manokps.items():
            keypoints[:, kp_id, :] = keypoints_regressed[:, mano_id, :]

        # get other keypoints from mesh
        for kp_id, vertex_id in self.kpId2vertices.items():
            keypoints[:, kp_id, :] = mesh_vertices[:, vertex_id, :]

        return keypoints