Spaces:
Sleeping
Sleeping
from typing import Optional | |
import numpy as np | |
import torch | |
from smplx import SMPLX as _SMPLX | |
from smplx import SMPLXLayer as _SMPLXLayer | |
from smplx.lbs import vertices2joints | |
from detrsmpl.core.conventions.keypoints_mapping import ( | |
convert_kps, | |
get_keypoint_num, | |
) | |
from detrsmpl.core.conventions.segmentation import body_segmentation | |
class SMPLX(_SMPLX):
    """Extension of the official SMPL-X implementation.

    Adds optional replacement / extra joint regressors, conversion of the
    regressed joints into another keypoint convention, and helpers that
    convert between a flat full-pose tensor and a per-part pose dict.
    """

    body_pose_keys = {'global_orient', 'body_pose'}
    full_pose_keys = {
        'global_orient', 'body_pose', 'left_hand_pose', 'right_hand_pose',
        'jaw_pose', 'leye_pose', 'reye_pose'
    }
    NUM_VERTS = 10475
    NUM_FACES = 20908

    def __init__(self,
                 *args,
                 keypoint_src: str = 'smplx',
                 keypoint_dst: str = 'human_data',
                 keypoint_approximate: bool = False,
                 joints_regressor: Optional[str] = None,
                 extra_joints_regressor: Optional[str] = None,
                 **kwargs) -> None:
        """
        Args:
            *args: extra arguments for SMPL-X initialization.
            keypoint_src: source convention of keypoints. This convention
                is used for keypoints obtained from joint regressors.
                Keypoints then undergo conversion into keypoint_dst
                convention.
            keypoint_dst: destination convention of keypoints. This convention
                is used for keypoints in the output.
            keypoint_approximate: whether to use approximate matching in
                convention conversion for keypoints.
            joints_regressor: path to joint regressor. Should be a .npy
                file. If provided, replaces the official J_regressor of
                SMPL-X.
            extra_joints_regressor: path to extra joint regressor. Should be
                a .npy file. If provided, extra joints are regressed and
                concatenated after the joints regressed with the official
                J_regressor or joints_regressor.
            **kwargs: extra keyword arguments for SMPL-X initialization.
        Returns:
            None
        """
        super(SMPLX, self).__init__(*args, **kwargs)
        self.keypoint_src = keypoint_src
        self.keypoint_dst = keypoint_dst
        self.keypoint_approximate = keypoint_approximate

        # override the default SMPL-X joint regressor if available
        if joints_regressor is not None:
            regressor = torch.tensor(np.load(joints_regressor),
                                     dtype=torch.float)
            self.register_buffer('joints_regressor', regressor)

        # allow for extra joints to be regressed if available
        if extra_joints_regressor is not None:
            regressor_extra = torch.tensor(np.load(extra_joints_regressor),
                                           dtype=torch.float)
            self.register_buffer('joints_regressor_extra', regressor_extra)

        self.num_verts = self.get_num_verts()
        # keypoint count of the destination (output) convention
        self.num_joints = get_keypoint_num(convention=self.keypoint_dst)
        self.body_part_segmentation = body_segmentation('smplx')

    def forward(self,
                *args,
                return_verts: bool = True,
                return_full_pose: bool = False,
                **kwargs) -> dict:
        """Forward function.

        Args:
            *args: extra arguments for SMPL-X forward.
            return_verts: whether to return vertices.
            return_full_pose: whether to return full pose parameters.
            **kwargs: extra arguments for SMPL-X forward.
        Returns:
            dict: ``global_orient``, ``body_pose``, ``betas``, ``joints``
            (in ``keypoint_dst`` convention), ``joint_mask`` (batch_size,
            num_keypoints), ``keypoints`` (joints with the mask appended as
            the last channel), plus ``vertices`` if ``return_verts`` and
            ``full_pose`` if ``return_full_pose``.
        """
        # NOTE(review): `get_skin` is forwarded to the parent forward;
        # confirm it is still consumed there.
        kwargs['get_skin'] = True
        # The parent only fills `full_pose` when asked for it, so the flag
        # must be forwarded; otherwise output['full_pose'] would be None.
        smplx_output = super(SMPLX, self).forward(
            *args, return_full_pose=return_full_pose, **kwargs)

        if hasattr(self, 'joints_regressor'):
            joints = vertices2joints(self.joints_regressor,
                                     smplx_output.vertices)
        else:
            joints = smplx_output.joints
        if hasattr(self, 'joints_regressor_extra'):
            extra_joints = vertices2joints(self.joints_regressor_extra,
                                           smplx_output.vertices)
            joints = torch.cat([joints, extra_joints], dim=1)

        joints, joint_mask = convert_kps(joints,
                                         src=self.keypoint_src,
                                         dst=self.keypoint_dst,
                                         approximate=self.keypoint_approximate)
        if isinstance(joint_mask, np.ndarray):
            joint_mask = torch.tensor(joint_mask,
                                      dtype=torch.uint8,
                                      device=joints.device)
        batch_size = joints.shape[0]
        # the same mask is shared by every sample in the batch
        joint_mask = joint_mask.reshape(1, -1).expand(batch_size, -1)
        # cast the mask to the joints' dtype/device so the concatenation
        # does not depend on implicit type promotion
        keypoints = torch.cat([joints, joint_mask[:, :, None].to(joints)],
                              dim=-1)

        output = dict(global_orient=smplx_output.global_orient,
                      body_pose=smplx_output.body_pose,
                      joints=joints,
                      joint_mask=joint_mask,
                      keypoints=keypoints,
                      betas=smplx_output.betas)
        if return_verts:
            output['vertices'] = smplx_output.vertices
        if return_full_pose:
            output['full_pose'] = smplx_output.full_pose
        return output

    @staticmethod
    def _repeat_to_batch(tensor: torch.Tensor,
                         batch_size: int) -> torch.Tensor:
        """Flatten ``tensor`` to (rows, last_dim) and tile a single row to
        ``batch_size`` rows; any other row count is returned unchanged."""
        tensor = tensor.reshape(-1, tensor.shape[-1])
        if tensor.shape[0] == 1:
            tensor = tensor.repeat(batch_size, 1)
        return tensor

    @classmethod
    def tensor2dict(cls,
                    full_pose: torch.Tensor,
                    betas: Optional[torch.Tensor] = None,
                    transl: Optional[torch.Tensor] = None,
                    expression: Optional[torch.Tensor] = None) -> dict:
        """Convert full pose tensor to pose dict.

        Args:
            full_pose (torch.Tensor): shape should be (..., 165) or
                (..., 55, 3). All zeros for T-pose. Joint order is
                global_orient, body, jaw, left eye, right eye, left hand,
                right hand.
            betas (Optional[torch.Tensor], optional): shape should be
                (..., 10). The batch num should be 1 (then it is repeated)
                or correspond with full_pose.
                Defaults to None.
            transl (Optional[torch.Tensor], optional): shape should be
                (..., 3). The batch num should be 1 (then it is repeated)
                or correspond with full_pose.
                Defaults to None.
            expression (Optional[torch.Tensor], optional): shape should
                be (..., 10). The batch num should be 1 (then it is
                repeated) or correspond with full_pose.
                Defaults to None, in which case zeros of shape
                (batch_size, 10) with full_pose's dtype/device are used.
        Returns:
            dict: dict of smplx pose containing transl & betas. Pose
            entries are flattened to (batch_size, num_part_joints * 3);
            ``betas``/``transl`` stay None when not given.
        """
        num_body_joints = cls.NUM_BODY_JOINTS
        num_hand_joints = cls.NUM_HAND_JOINTS
        num_face_joints = cls.NUM_FACE_JOINTS
        num_joints = num_body_joints + 2 * num_hand_joints + num_face_joints

        # reshape (rather than view) also accepts non-contiguous input
        full_pose = full_pose.reshape(-1, num_joints + 1, 3)
        batch_size = full_pose.shape[0]

        # layout: [global_orient | body | jaw, leye, reye | lhand | rhand]
        face_start = 1 + num_body_joints
        lhand_start = face_start + num_face_joints
        rhand_start = lhand_start + num_hand_joints
        rhand_end = rhand_start + num_hand_joints

        def _flat(part: torch.Tensor) -> torch.Tensor:
            # explicit sizes so an empty batch still reshapes correctly
            return part.reshape(batch_size, part.shape[1] * 3)

        if betas is not None:
            betas = cls._repeat_to_batch(betas, batch_size)
        if transl is not None:
            transl = cls._repeat_to_batch(transl.reshape(-1, 3), batch_size)
        if expression is not None:
            expression = cls._repeat_to_batch(expression, batch_size)
        else:
            expression = full_pose.new_zeros(batch_size, 10)

        return {
            'betas': betas,
            'global_orient': _flat(full_pose[:, :1]),
            'body_pose': _flat(full_pose[:, 1:face_start]),
            'left_hand_pose': _flat(full_pose[:, lhand_start:rhand_start]),
            'right_hand_pose': _flat(full_pose[:, rhand_start:rhand_end]),
            'transl': transl,
            'expression': expression,
            'jaw_pose': _flat(full_pose[:, face_start:face_start + 1]),
            'leye_pose': _flat(full_pose[:, face_start + 1:face_start + 2]),
            'reye_pose': _flat(full_pose[:, face_start + 2:face_start + 3]),
        }

    @classmethod
    def dict2tensor(cls, smplx_dict: dict) -> torch.Tensor:
        """Convert smplx pose dict to full pose tensor.

        Args:
            smplx_dict (dict): smplx pose dict. Must contain
                ``global_orient`` and ``body_pose``; ``jaw_pose``,
                ``leye_pose``, ``reye_pose``, ``left_hand_pose`` and
                ``right_hand_pose`` default to zeros when missing or None.
                numpy arrays are converted with ``torch.Tensor`` (float32).
                The input dict is not modified.
        Returns:
            torch.Tensor: full pose tensor of shape
            (batch_size, (NUM_JOINTS + 1) * 3), ordered as global_orient,
            body, jaw, left eye, right eye, left hand, right hand.
        """
        assert cls.body_pose_keys.issubset(smplx_dict), (
            f'smplx_dict must contain the keys {sorted(cls.body_pose_keys)}')
        # convert numpy arrays into a new dict, leaving the caller's intact
        pose_dict = {
            key: torch.Tensor(value) if isinstance(value, np.ndarray) else value
            for key, value in smplx_dict.items()
        }
        num_body_joints = cls.NUM_BODY_JOINTS
        num_hand_joints = cls.NUM_HAND_JOINTS
        num_joints = (num_body_joints + 2 * num_hand_joints +
                      cls.NUM_FACE_JOINTS)

        global_orient = pose_dict['global_orient'].reshape(-1, 1, 3)
        body_pose = pose_dict['body_pose'].reshape(-1, num_body_joints, 3)
        batch_size = global_orient.shape[0]

        def _part(key: str, num_part_joints: int) -> torch.Tensor:
            # missing parts default to zeros on global_orient's device and
            # dtype so the concatenation below works for GPU inputs too
            value = pose_dict.get(key)
            if value is None:
                return global_orient.new_zeros(
                    (batch_size, num_part_joints, 3))
            return value.reshape(-1, num_part_joints, 3)

        full_pose = torch.cat([
            global_orient,
            body_pose,
            _part('jaw_pose', 1),
            _part('leye_pose', 1),
            _part('reye_pose', 1),
            _part('left_hand_pose', num_hand_joints),
            _part('right_hand_pose', num_hand_joints),
        ],
                              dim=1)
        return full_pose.reshape(-1, (num_joints + 1) * 3)
class SMPLXLayer(_SMPLXLayer):
    """Extension of the official SMPL-X layer implementation.

    Adds optional replacement / extra joint regressors and conversion of
    the regressed joints into another keypoint convention.
    """

    body_pose_keys = {'global_orient', 'body_pose'}
    full_pose_keys = {
        'global_orient', 'body_pose', 'left_hand_pose', 'right_hand_pose',
        'jaw_pose', 'leye_pose', 'reye_pose'
    }
    NUM_VERTS = 10475
    NUM_FACES = 20908

    def __init__(self,
                 *args,
                 keypoint_src: str = 'smplx',
                 keypoint_dst: str = 'human_data',
                 keypoint_approximate: bool = False,
                 joints_regressor: Optional[str] = None,
                 extra_joints_regressor: Optional[str] = None,
                 **kwargs) -> None:
        """
        Args:
            *args: extra arguments for SMPL-X layer initialization.
            keypoint_src: source convention of keypoints. This convention
                is used for keypoints obtained from joint regressors.
                Keypoints then undergo conversion into keypoint_dst
                convention.
            keypoint_dst: destination convention of keypoints. This convention
                is used for keypoints in the output.
            keypoint_approximate: whether to use approximate matching in
                convention conversion for keypoints.
            joints_regressor: path to joint regressor. Should be a .npy
                file. If provided, replaces the official J_regressor of
                SMPL-X.
            extra_joints_regressor: path to extra joint regressor. Should be
                a .npy file. If provided, extra joints are regressed and
                concatenated after the joints regressed with the official
                J_regressor or joints_regressor.
            **kwargs: extra keyword arguments for SMPL-X layer
                initialization.
        Returns:
            None
        """
        super(SMPLXLayer, self).__init__(*args, **kwargs)
        self.keypoint_src = keypoint_src
        self.keypoint_dst = keypoint_dst
        self.keypoint_approximate = keypoint_approximate

        # override the default SMPL-X joint regressor if available
        if joints_regressor is not None:
            regressor = torch.tensor(np.load(joints_regressor),
                                     dtype=torch.float)
            self.register_buffer('joints_regressor', regressor)

        # allow for extra joints to be regressed if available
        if extra_joints_regressor is not None:
            regressor_extra = torch.tensor(np.load(extra_joints_regressor),
                                           dtype=torch.float)
            self.register_buffer('joints_regressor_extra', regressor_extra)

        self.num_verts = self.get_num_verts()
        # keypoint count of the destination (output) convention
        self.num_joints = get_keypoint_num(convention=self.keypoint_dst)
        self.body_part_segmentation = body_segmentation('smplx')

    def forward(self,
                *args,
                return_verts: bool = True,
                return_full_pose: bool = False,
                **kwargs) -> dict:
        """Forward function.

        Args:
            *args: extra arguments for SMPL-X layer forward.
            return_verts: whether to return vertices.
            return_full_pose: whether to return full pose parameters.
            **kwargs: extra arguments for SMPL-X layer forward.
        Returns:
            dict: ``global_orient``, ``body_pose``, ``betas``, ``joints``
            (in ``keypoint_dst`` convention), ``joint_mask`` (batch_size,
            num_keypoints), ``keypoints`` (joints with the mask appended as
            the last channel), plus ``vertices`` if ``return_verts`` and
            ``full_pose`` if ``return_full_pose``.
        """
        # NOTE(review): `get_skin` is forwarded to the parent forward;
        # confirm it is still consumed there.
        kwargs['get_skin'] = True
        # The parent only fills `full_pose` when asked for it, so the flag
        # must be forwarded; otherwise output['full_pose'] would be None.
        smplx_output = super(SMPLXLayer, self).forward(
            *args, return_full_pose=return_full_pose, **kwargs)

        if hasattr(self, 'joints_regressor'):
            joints = vertices2joints(self.joints_regressor,
                                     smplx_output.vertices)
        else:
            joints = smplx_output.joints
        if hasattr(self, 'joints_regressor_extra'):
            extra_joints = vertices2joints(self.joints_regressor_extra,
                                           smplx_output.vertices)
            joints = torch.cat([joints, extra_joints], dim=1)

        joints, joint_mask = convert_kps(joints,
                                         src=self.keypoint_src,
                                         dst=self.keypoint_dst,
                                         approximate=self.keypoint_approximate)
        if isinstance(joint_mask, np.ndarray):
            joint_mask = torch.tensor(joint_mask,
                                      dtype=torch.uint8,
                                      device=joints.device)
        batch_size = joints.shape[0]
        # the same mask is shared by every sample in the batch
        joint_mask = joint_mask.reshape(1, -1).expand(batch_size, -1)
        # cast the mask to the joints' dtype/device so the concatenation
        # does not depend on implicit type promotion
        keypoints = torch.cat([joints, joint_mask[:, :, None].to(joints)],
                              dim=-1)

        output = dict(global_orient=smplx_output.global_orient,
                      body_pose=smplx_output.body_pose,
                      joints=joints,
                      joint_mask=joint_mask,
                      keypoints=keypoints,
                      betas=smplx_output.betas)
        if return_verts:
            output['vertices'] = smplx_output.vertices
        if return_full_pose:
            output['full_pose'] = smplx_output.full_pose
        return output