Spaces:
Running
on
L40S
Running
on
L40S
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Optional | |
import numpy as np | |
import torch | |
from smplx import SMPL as _SMPL | |
from smplx.lbs import batch_rigid_transform, blend_shapes, vertices2joints | |
from detrsmpl.core.conventions.keypoints_mapping import ( | |
convert_kps, | |
get_keypoint_num, | |
) | |
from detrsmpl.core.conventions.segmentation import body_segmentation | |
from detrsmpl.models.utils import batch_inverse_kinematics_transform | |
from detrsmpl.utils.transforms import quat_to_rotmat | |
class SMPL(_SMPL):
    """Extension of the official SMPL implementation.

    On top of the parent ``smplx`` model this class:

    * optionally replaces the joint regressor and/or appends extra regressed
      joints,
    * converts the regressed joints from ``keypoint_src`` to ``keypoint_dst``
      convention and returns them together with a validity mask,
    * offers helpers converting between a flat full-pose tensor and a pose
      dict.
    """

    # Keys a pose dict must contain to be accepted by ``dict2tensor``.
    body_pose_keys = {
        'global_orient',
        'body_pose',
    }
    full_pose_keys = {
        'global_orient',
        'body_pose',
    }
    NUM_VERTS = 6890
    NUM_FACES = 13776

    def __init__(self,
                 *args,
                 keypoint_src: str = 'smpl_45',
                 keypoint_dst: str = 'human_data',
                 keypoint_approximate: bool = False,
                 joints_regressor: Optional[str] = None,
                 extra_joints_regressor: Optional[str] = None,
                 **kwargs) -> None:
        """
        Args:
            *args: extra arguments for SMPL initialization.
            keypoint_src: source convention of keypoints. This convention
                is used for keypoints obtained from joint regressors.
                Keypoints then undergo conversion into keypoint_dst
                convention.
            keypoint_dst: destination convention of keypoints. This
                convention is used for keypoints in the output.
            keypoint_approximate: whether to use approximate matching in
                convention conversion for keypoints.
            joints_regressor: path to joint regressor. Should be a .npy
                file. If provided, it is used in ``forward`` instead of the
                joints returned by the parent model.
            extra_joints_regressor: path to extra joint regressor. Should be
                a .npy file. If provided, extra joints are regressed and
                concatenated after the joints regressed with the official
                J_regressor or joints_regressor.
            **kwargs: extra keyword arguments for SMPL initialization.

        Returns:
            None
        """
        super(SMPL, self).__init__(*args, **kwargs)
        self.keypoint_src = keypoint_src
        self.keypoint_dst = keypoint_dst
        self.keypoint_approximate = keypoint_approximate

        # Override the default SMPL joint regressor if available. The buffer
        # only exists when a path was given; ``forward`` checks for it with
        # ``hasattr``.
        if joints_regressor is not None:
            joints_regressor = torch.tensor(np.load(joints_regressor),
                                            dtype=torch.float)
            self.register_buffer('joints_regressor', joints_regressor)

        # Allow for extra joints to be regressed if available.
        if extra_joints_regressor is not None:
            joints_regressor_extra = torch.tensor(
                np.load(extra_joints_regressor), dtype=torch.float)
            self.register_buffer('joints_regressor_extra',
                                 joints_regressor_extra)

        self.num_verts = self.get_num_verts()
        self.num_joints = get_keypoint_num(convention=self.keypoint_dst)
        self.body_part_segmentation = body_segmentation('smpl')

    def forward(self,
                *args,
                return_verts: bool = True,
                return_full_pose: bool = False,
                **kwargs) -> dict:
        """Forward function.

        Args:
            *args: extra arguments for SMPL
            return_verts: whether to return vertices
            return_full_pose: whether to return full pose parameters
            **kwargs: extra arguments for SMPL

        Returns:
            output: dict with keys ``global_orient``, ``body_pose``,
                ``joints``, ``joint_mask``, ``keypoints`` (joints with the
                mask appended as an extra last channel) and ``betas``;
                plus ``vertices`` if ``return_verts`` and ``full_pose`` if
                ``return_full_pose``.
        """
        # Vertices are always needed below for the custom regressors.
        kwargs['get_skin'] = True
        smpl_output = super(SMPL, self).forward(*args, **kwargs)

        if hasattr(self, 'joints_regressor'):
            joints = vertices2joints(self.joints_regressor,
                                     smpl_output.vertices)
        else:
            joints = smpl_output.joints

        if hasattr(self, 'joints_regressor_extra'):
            extra_joints = vertices2joints(self.joints_regressor_extra,
                                           smpl_output.vertices)
            joints = torch.cat([joints, extra_joints], dim=1)

        joints, joint_mask = convert_kps(joints,
                                         src=self.keypoint_src,
                                         dst=self.keypoint_dst,
                                         approximate=self.keypoint_approximate)
        if isinstance(joint_mask, np.ndarray):
            joint_mask = torch.tensor(joint_mask,
                                      dtype=torch.uint8,
                                      device=joints.device)

        # Broadcast the per-keypoint mask over the batch.
        batch_size = joints.shape[0]
        joint_mask = joint_mask.reshape(1, -1).expand(batch_size, -1)

        output = dict(global_orient=smpl_output.global_orient,
                      body_pose=smpl_output.body_pose,
                      joints=joints,
                      joint_mask=joint_mask,
                      keypoints=torch.cat([joints, joint_mask[:, :, None]],
                                          dim=-1),
                      betas=smpl_output.betas)

        if return_verts:
            output['vertices'] = smpl_output.vertices
        if return_full_pose:
            output['full_pose'] = smpl_output.full_pose

        return output

    @classmethod
    def tensor2dict(cls,
                    full_pose: torch.Tensor,
                    betas: Optional[torch.Tensor] = None,
                    transl: Optional[torch.Tensor] = None):
        """Convert full pose tensor to pose dict.

        Args:
            full_pose (torch.Tensor): any shape that can be viewed as
                (-1, (NUM_BODY_JOINTS + 1) * 3); the first 3 values of each
                row are the global orientation, the rest the body pose.
            betas (Optional[torch.Tensor], optional): flattened to
                (-1, betas.shape[-1]). A single row is repeated to match
                the batch of full_pose.
                Defaults to None.
            transl (Optional[torch.Tensor], optional): viewed as
                (batch_size, -1), where batch_size is that of full_pose.
                Defaults to None.

        Returns:
            dict: dict with keys ``betas``, ``body_pose``,
                ``global_orient`` and ``transl`` (``betas`` / ``transl``
                are None if not given).
        """
        full_pose = full_pose.view(-1, (cls.NUM_BODY_JOINTS + 1) * 3)
        batch_size = full_pose.shape[0]
        global_orient = full_pose[:, :3]
        body_pose = full_pose[:, 3:]

        if betas is not None:
            # squeeze or unsqueeze betas to 2 dims
            betas = betas.view(-1, betas.shape[-1])
            if betas.shape[0] == 1:
                betas = betas.repeat(batch_size, 1)

        if transl is not None:
            transl = transl.view(batch_size, -1)

        return {
            'betas': betas,
            'body_pose': body_pose,
            'global_orient': global_orient,
            'transl': transl,
        }

    @classmethod
    def dict2tensor(cls, smpl_dict: dict) -> torch.Tensor:
        """Convert smpl pose dict to full pose tensor.

        Note:
            Every ``np.ndarray`` value of ``smpl_dict`` is replaced in place
            by a ``torch.Tensor``; the caller's dict is modified.

        Args:
            smpl_dict (dict): smpl pose dict. Must contain the keys in
                ``body_pose_keys``.

        Returns:
            torch.Tensor: full pose tensor, global orientation followed by
                body pose along dim 1.
        """
        assert cls.body_pose_keys.issubset(smpl_dict)
        for key, value in smpl_dict.items():
            if isinstance(value, np.ndarray):
                # Re-assigning an existing key does not change the dict
                # size, so this is safe while iterating.
                smpl_dict[key] = torch.Tensor(value)
        global_orient = smpl_dict['global_orient'].view(-1, 3)
        body_pose = smpl_dict['body_pose'].view(-1, 3 * cls.NUM_BODY_JOINTS)
        return torch.cat([global_orient, body_pose], dim=1)
class GenderedSMPL(torch.nn.Module):
    """A wrapper of SMPL to handle gendered inputs.

    Holds one neutral, one male and one female SMPL model and dispatches each
    sample of a batch to the model selected by its gender label.
    """

    def __init__(self,
                 *args,
                 keypoint_src: str = 'smpl_45',
                 keypoint_dst: str = 'human_data',
                 keypoint_approximate: bool = False,
                 joints_regressor: str = None,
                 extra_joints_regressor: str = None,
                 **kwargs) -> None:
        """
        Args:
            *args: extra arguments for SMPL initialization.
            keypoint_src: source convention of keypoints. This convention
                is used for keypoints obtained from joint regressors.
                Keypoints then undergo conversion into keypoint_dst
                convention.
            keypoint_dst: destination convention of keypoints. This
                convention is used for keypoints in the output.
            keypoint_approximate: whether to use approximate matching in
                convention conversion for keypoints.
            joints_regressor: path to joint regressor. Should be a .npy
                file. If provided, replaces the official J_regressor of SMPL.
            extra_joints_regressor: path to extra joint regressor. Should be
                a .npy file. If provided, extra joints are regressed and
                concatenated after the joints regressed with the official
                J_regressor or joints_regressor.
            **kwargs: extra keyword arguments for SMPL initialization.
                Must not contain 'gender'.

        Returns:
            None
        """
        super(GenderedSMPL, self).__init__()

        assert 'gender' not in kwargs, \
            self.__class__.__name__ + \
            ' does not need \'gender\' for initialization.'

        # Options shared by the three gender-specific models.
        shared_kwargs = dict(keypoint_src=keypoint_src,
                             keypoint_dst=keypoint_dst,
                             keypoint_approximate=keypoint_approximate,
                             joints_regressor=joints_regressor,
                             extra_joints_regressor=extra_joints_regressor)

        self.smpl_neutral = SMPL(*args,
                                 gender='neutral',
                                 **shared_kwargs,
                                 **kwargs)
        self.smpl_male = SMPL(*args,
                              gender='male',
                              **shared_kwargs,
                              **kwargs)
        self.smpl_female = SMPL(*args,
                                gender='female',
                                **shared_kwargs,
                                **kwargs)

        self.num_verts = self.smpl_neutral.num_verts
        self.num_joints = self.smpl_neutral.num_joints
        self.faces = self.smpl_neutral.faces

    def forward(self,
                *args,
                betas: torch.Tensor = None,
                body_pose: torch.Tensor = None,
                global_orient: torch.Tensor = None,
                transl: torch.Tensor = None,
                return_verts: bool = True,
                return_full_pose: bool = False,
                gender: torch.Tensor = None,
                device=None,
                **kwargs):
        """Forward function.

        Note:
            B: batch size
            J: number of joints of model, J = 23 (SMPL)
            K: number of keypoints

        Args:
            *args: extra arguments (accepted but not forwarded to the
                wrapped models)
            betas: Tensor([B, 10]), human body shape parameters of SMPL
                model.
            body_pose: Tensor([B, J*3] or [B, J, 3, 3]), human body pose
                parameters of SMPL model. It should be axis-angle vector
                ([B, J*3]) or rotation matrix ([B, J, 3, 3)].
            global_orient: Tensor([B, 3] or [B, 1, 3, 3]), global orientation
                of human body. It should be axis-angle vector ([B, 3]) or
                rotation matrix ([B, 1, 3, 3)].
            transl: Tensor([B, 3]), global translation of human body.
            return_verts: whether to return mesh vertices. With ``gender``
                given, the 'vertices' entry is always present and stays
                zero-filled when this is False.
            return_full_pose: whether to return the full pose parameters
                under 'full_pose'.
            gender: Tensor([B]), gender parameters of human body. -1 for
                neutral, 0 for male , 1 for female. If None, the neutral
                model is used for the whole batch.
            device: the device of the output buffers. Defaults to the device
                of the first non-None pose/shape input.
            **kwargs: extra keyword arguments forwarded to the wrapped
                models

        Returns:
            outputs (dict): Dict with mesh vertices and joints.
                - vertices: Tensor([B, V, 3]), mesh vertices
                - joints: Tensor([B, K, 3]), 3d keypoints regressed from
                    mesh vertices.
                - joint_mask: Tensor([B, K]), keypoint mask (when ``gender``
                    is given; otherwise whatever the neutral model returns).
                - full_pose: only if ``return_full_pose``.
        """
        # Infer batch size (and, if needed, device) from the given inputs and
        # make sure they all agree on the batch dimension.
        batch_size = None
        for attr in (betas, body_pose, global_orient, transl):
            if attr is None:
                continue
            if device is None:
                device = attr.device
            if batch_size is None:
                batch_size = attr.shape[0]
            else:
                assert batch_size == attr.shape[0]

        if gender is None:
            return self.smpl_neutral(betas=betas,
                                     body_pose=body_pose,
                                     global_orient=global_orient,
                                     transl=transl,
                                     return_verts=return_verts,
                                     return_full_pose=return_full_pose,
                                     **kwargs)

        output = {
            'vertices':
            torch.zeros([batch_size, self.num_verts, 3], device=device),
            'joints':
            torch.zeros([batch_size, self.num_joints, 3], device=device),
            'joint_mask':
            torch.zeros([batch_size, self.num_joints],
                        dtype=torch.uint8,
                        device=device)
        }

        def _select(tensor, idxs):
            """Pick the rows of ``tensor`` selected by ``idxs`` (None-safe)."""
            return tensor[idxs] if tensor is not None else None

        gendered_models = ((self.smpl_neutral, -1), (self.smpl_male, 0),
                           (self.smpl_female, 1))
        for body_model, gender_label in gendered_models:
            gender_idxs = gender == gender_label

            # skip if no such gender is present
            if gender_idxs.sum() == 0:
                continue

            output_model = body_model(
                betas=_select(betas, gender_idxs),
                body_pose=_select(body_pose, gender_idxs),
                global_orient=_select(global_orient, gender_idxs),
                transl=_select(transl, gender_idxs),
                return_verts=return_verts,
                return_full_pose=return_full_pose,
                **kwargs)

            output['joints'][gender_idxs] = output_model['joints']

            # TODO: quick fix
            if 'joint_mask' in output_model:
                output['joint_mask'][gender_idxs] = output_model['joint_mask']

            if return_verts:
                output['vertices'][gender_idxs] = output_model['vertices']

            if return_full_pose:
                full_pose = output_model['full_pose']
                if 'full_pose' not in output:
                    # The trailing shape is only known once a model has
                    # produced a full pose, so allocate on first use.
                    output['full_pose'] = torch.zeros(
                        [batch_size, *full_pose.shape[1:]],
                        dtype=full_pose.dtype,
                        device=device)
                output['full_pose'][gender_idxs] = full_pose

        return output
def to_tensor(array, dtype=torch.float32):
    """Convert ``array`` to a ``torch.Tensor`` of the given dtype.

    Args:
        array: a tensor or anything accepted by ``torch.tensor`` (list,
            scalar, ``np.ndarray``, ...).
        dtype: dtype of the returned tensor.

    Returns:
        torch.Tensor: a new tensor holding the values of ``array``. For a
            tensor input the result is a detached copy, which is what
            ``torch.tensor(tensor)`` would produce, but without the
            copy-construct ``UserWarning``.
    """
    if torch.is_tensor(array):
        return array.detach().clone().to(dtype=dtype)
    return torch.tensor(array, dtype=dtype)
def to_np(array, dtype=np.float32):
    """Convert ``array`` to a ``np.ndarray`` of the given dtype.

    Objects whose type name contains ``scipy.sparse`` are densified with
    their ``todense()`` method before conversion.
    """
    is_sparse = 'scipy.sparse' in str(type(array))
    dense = array.todense() if is_sparse else array
    return np.array(dense, dtype=dtype)
class HybrIKSMPL(SMPL):
    """Extension of the SMPL for HybrIK.

    The kinematic tree of the parent model (24 joints) is extended with five
    leaf joints (see ``LEAF_NAMES``) to 29 joints, and ``forward`` runs an
    inverse pass: joint locations plus twist parameters are turned into
    rotation matrices, which then drive the usual blend-shape / skinning
    steps.
    """

    NUM_JOINTS = 23
    NUM_BODY_JOINTS = 23
    NUM_BETAS = 10
    JOINT_NAMES = [
        'pelvis',
        'left_hip',
        'right_hip',  # 2
        'spine1',
        'left_knee',
        'right_knee',  # 5
        'spine2',
        'left_ankle',
        'right_ankle',  # 8
        'spine3',
        'left_foot',
        'right_foot',  # 11
        'neck',
        'left_collar',
        'right_collar',  # 14
        'jaw',  # 15
        'left_shoulder',
        'right_shoulder',  # 17
        'left_elbow',
        'right_elbow',  # 19
        'left_wrist',
        'right_wrist',  # 21
        'left_thumb',
        'right_thumb',  # 23
        'head',
        'left_middle',
        'right_middle',  # 26
        'left_bigtoe',
        'right_bigtoe'  # 28
    ]
    LEAF_NAMES = [
        'head', 'left_middle', 'right_middle', 'left_bigtoe', 'right_bigtoe'
    ]
    # Root indices used to root-align ``joints_from_verts`` / ``joints``
    # when no translation is given.
    root_idx_17 = 0
    root_idx_smpl = 0

    def __init__(self, *args, extra_joints_regressor=None, **kwargs):
        """
        Args:
            *args: extra arguments for SMPL initialization.
            extra_joints_regressor: path to extra joint regressor. Should be
                a .npy file. ``forward`` reads the resulting
                ``joints_regressor_extra`` buffer, so it needs this to be
                provided.
            **kwargs: extra keyword arguments for SMPL initialization.

        Returns:
            None
        """
        super(HybrIKSMPL,
              self).__init__(*args,
                             extra_joints_regressor=extra_joints_regressor,
                             create_betas=False,
                             create_global_orient=False,
                             create_body_pose=False,
                             create_transl=False,
                             **kwargs)

        self.dtype = torch.float32
        self.num_joints = 29

        self.ROOT_IDX = self.JOINT_NAMES.index('pelvis')
        self.LEAF_IDX = [
            self.JOINT_NAMES.index(name) for name in self.LEAF_NAMES
        ]
        self.SPINE3_IDX = 9

        # Indices of parents for each joint: the first 24 entries come from
        # the parent model, the five leaf joints are attached by hand.
        parents = torch.zeros(len(self.JOINT_NAMES), dtype=torch.long)
        parents[:24] = self.parents
        parents[24] = 15  # head         <- jaw
        parents[25] = 22  # left_middle  <- left_thumb
        parents[26] = 23  # right_middle <- right_thumb
        parents[27] = 10  # left_bigtoe  <- left_foot
        parents[28] = 11  # right_bigtoe <- right_foot

        # ``_parents_to_children`` relies on num_joints / LEAF_IDX /
        # SPINE3_IDX set above.
        self.register_buffer('children_map',
                             self._parents_to_children(parents))
        self.parents = parents

    def _parents_to_children(self, parents):
        """Build a child-index map from a parent-index tensor.

        Args:
            parents: 1-D integer tensor, ``parents[i]`` is the parent index
                of joint ``i``.

        Returns:
            Tensor of the same shape as ``parents``. Entry ``j`` is the
            first joint (in index order) whose parent is ``j``, or -1 if
            none was recorded / ``j`` is a leaf joint. The pelvis and
            spine3 entries are set explicitly.
        """
        children = torch.ones_like(parents) * -1
        for i in range(self.num_joints):
            # Only the first child found is recorded for each parent.
            if children[parents[i]] < 0:
                children[parents[i]] = i
        # Leaf joints have no children.
        for i in self.LEAF_IDX:
            if i < children.shape[0]:
                children[i] = -1

        # Explicit choices for joints with several children: the pelvis
        # continues to index 3 ('spine1'), spine3 to the neck.
        children[0] = 3
        children[self.SPINE3_IDX] = self.JOINT_NAMES.index('neck')

        return children

    def forward(self,
                pose_skeleton,
                betas,
                phis,
                global_orient,
                transl=None,
                return_verts=True,
                leaf_thetas=None):
        """Inverse pass for the SMPL model.

        Args:
            pose_skeleton: torch.tensor, shape Bx(J*3)
                Joint locations in (X, Y, Z) format.
            betas: torch.tensor, shape Bx10
                Shape parameters, e.g. predicted by some external model.
            phis: torch.tensor, shape Bx23x2
                Rotation on bone axis parameters
            global_orient: torch.tensor, shape Bx3
                Global Orientations.
            transl: torch.tensor, optional, shape Bx3
                Global Translations. If None, ``joints`` and
                ``joints_from_verts`` are root-aligned instead.
            return_verts: bool, optional
                Currently unused; vertices are always returned.
                (default=True)
            leaf_thetas: torch.tensor, optional, shape Bx5x4
                Quaternions of 5 leaf joints. (default=None)

        Returns:
            outputs: dict with keys ``vertices``, ``joints`` (joints after
                the rigid transform), ``poses`` (rotation matrices) and
                ``joints_from_verts`` (joints regressed from the posed
                vertices with ``joints_regressor_extra``).
        """
        if leaf_thetas is not None:
            # Flatten the 5 leaf quaternions per sample before conversion.
            leaf_thetas = leaf_thetas.reshape(pose_skeleton.shape[0] * 5, 4)
            leaf_thetas = quat_to_rotmat(leaf_thetas)

        batch_size = max(betas.shape[0], pose_skeleton.shape[0])
        device = betas.device

        # 1. Add shape contribution
        v_shaped = self.v_template + blend_shapes(betas, self.shapedirs)

        # 2. Get the rest joints
        if leaf_thetas is not None:
            rest_J = vertices2joints(self.J_regressor, v_shaped)
        else:
            # Without leaf rotations, the rest skeleton is extended with
            # five extra joints taken from fixed mesh vertices.
            rest_J = torch.zeros((v_shaped.shape[0], 29, 3),
                                 dtype=self.dtype,
                                 device=device)
            rest_J[:, :24] = vertices2joints(self.J_regressor, v_shaped)

            leaf_number = [411, 2445, 5905, 3216, 6617]
            leaf_vertices = v_shaped[:, leaf_number].clone()
            rest_J[:, 24:] = leaf_vertices

        # 3. Get the rotation matrices
        rot_mats, _ = batch_inverse_kinematics_transform(
            pose_skeleton,
            global_orient,
            phis,
            rest_J.clone(),
            self.children_map,
            self.parents,
            dtype=self.dtype,
            train=self.training,
            leaf_thetas=leaf_thetas)

        # Forward kinematics on the first 24 joints; ``A`` holds the
        # per-joint transforms used for skinning below.
        new_joints, A = batch_rigid_transform(rot_mats,
                                              rest_J[:, :24].clone(),
                                              self.parents[:24],
                                              dtype=self.dtype)

        # 4. Add pose blend shapes
        # rot_mats: N x (J + 1) x 3 x 3
        ident = torch.eye(3, dtype=self.dtype, device=device)
        pose_feature = (rot_mats[:, 1:] - ident).view([batch_size, -1])
        pose_offsets = torch.matmul(pose_feature, self.posedirs) \
            .view(batch_size, -1, 3)

        v_posed = pose_offsets + v_shaped

        # 5. Do skinning:
        # W is N x V x (J + 1)
        W = self.lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
        # (N x V x (J + 1)) x (N x (J + 1) x 16)
        num_joints = self.J_regressor.shape[0]
        T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \
            .view(batch_size, -1, 4, 4)

        homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1],
                                   dtype=self.dtype,
                                   device=device)
        v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2)
        v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))

        vertices = v_homo[:, :, :3, 0]
        joints_from_verts = vertices2joints(self.joints_regressor_extra,
                                            vertices)

        if transl is not None:
            new_joints += transl.unsqueeze(dim=1)
            vertices += transl.unsqueeze(dim=1)
            joints_from_verts += transl.unsqueeze(dim=1)
        else:
            # No translation: express joints relative to their (detached)
            # root joint. Vertices are left as they are.
            new_joints = new_joints - \
                new_joints[:, self.root_idx_smpl, :].unsqueeze(1).detach()
            joints_from_verts = joints_from_verts - \
                joints_from_verts[:, self.root_idx_17, :].unsqueeze(1).detach()

        output = {
            'vertices': vertices,
            'joints': new_joints,
            'poses': rot_mats,
            'joints_from_verts': joints_from_verts,
        }
        return output