ttxskk
update
d7e58f0
raw
history blame
15.6 kB
from typing import Optional
import numpy as np
import torch
from smplx import SMPLX as _SMPLX
from smplx import SMPLXLayer as _SMPLXLayer
from smplx.lbs import vertices2joints
from detrsmpl.core.conventions.keypoints_mapping import (
convert_kps,
get_keypoint_num,
)
from detrsmpl.core.conventions.segmentation import body_segmentation
class SMPLX(_SMPLX):
    """Extension of the official SMPL-X implementation.

    Adds optional custom joint regressors, conversion of the regressed
    joints into a configurable keypoint convention, and helpers that convert
    between a flat full-pose tensor and a per-part pose dict.
    """

    # Keys that must be present in a pose dict passed to ``dict2tensor``.
    body_pose_keys = {'global_orient', 'body_pose'}
    # Every pose key that ``tensor2dict`` / ``dict2tensor`` deal with.
    full_pose_keys = {
        'global_orient', 'body_pose', 'left_hand_pose', 'right_hand_pose',
        'jaw_pose', 'leye_pose', 'reye_pose'
    }
    NUM_VERTS = 10475
    NUM_FACES = 20908

    def __init__(self,
                 *args,
                 keypoint_src: str = 'smplx',
                 keypoint_dst: str = 'human_data',
                 keypoint_approximate: bool = False,
                 joints_regressor: Optional[str] = None,
                 extra_joints_regressor: Optional[str] = None,
                 **kwargs):
        """
        Args:
            *args: extra arguments forwarded to the parent SMPL-X model.
            keypoint_src: source convention of keypoints. This convention
                is used for keypoints obtained from joint regressors.
                Keypoints then undergo conversion into keypoint_dst
                convention.
            keypoint_dst: destination convention of keypoints. This
                convention is used for keypoints in the output.
            keypoint_approximate: whether to use approximate matching in
                convention conversion for keypoints.
            joints_regressor: path to joint regressor. Should be a .npy
                file. If provided, the joints in the output are regressed
                from the vertices with it instead of being taken from the
                parent model output.
            extra_joints_regressor: path to extra joint regressor. Should
                be a .npy file. If provided, extra joints are regressed and
                concatenated after the other joints.
            **kwargs: extra keyword arguments forwarded to the parent
                SMPL-X model.

        Returns:
            None
        """
        super().__init__(*args, **kwargs)
        self.keypoint_src = keypoint_src
        self.keypoint_dst = keypoint_dst
        self.keypoint_approximate = keypoint_approximate

        # The buffers are only registered when a path is given; ``forward``
        # relies on their presence (``hasattr``) to decide which joints to
        # use.
        if joints_regressor is not None:
            regressor = torch.tensor(np.load(joints_regressor),
                                     dtype=torch.float)
            self.register_buffer('joints_regressor', regressor)

        if extra_joints_regressor is not None:
            regressor_extra = torch.tensor(np.load(extra_joints_regressor),
                                           dtype=torch.float)
            self.register_buffer('joints_regressor_extra', regressor_extra)

        self.num_verts = self.get_num_verts()
        self.num_joints = get_keypoint_num(convention=self.keypoint_dst)
        self.body_part_segmentation = body_segmentation('smplx')

    def forward(self,
                *args,
                return_verts: bool = True,
                return_full_pose: bool = False,
                **kwargs) -> dict:
        """Forward function.

        Args:
            *args: extra arguments forwarded to the parent forward.
            return_verts: whether to include ``vertices`` in the output.
            return_full_pose: whether to include ``full_pose`` in the
                output.
            **kwargs: extra keyword arguments forwarded to the parent
                forward.

        Returns:
            dict: contains ``global_orient``, ``body_pose``, ``joints``,
                ``joint_mask``, ``keypoints`` (joints with the mask appended
                as the last channel) and ``betas``; plus ``vertices`` and
                ``full_pose`` when requested.
        """
        kwargs['get_skin'] = True
        # Vertices are always requested from the parent because the custom
        # regressors below need them; ``return_verts`` only controls whether
        # they are exposed. ``return_full_pose`` has to be forwarded,
        # otherwise the parent leaves ``full_pose`` unset.
        smplx_output = super().forward(*args,
                                       return_verts=True,
                                       return_full_pose=return_full_pose,
                                       **kwargs)

        if hasattr(self, 'joints_regressor'):
            joints = vertices2joints(self.joints_regressor,
                                     smplx_output.vertices)
        else:
            joints = smplx_output.joints

        if hasattr(self, 'joints_regressor_extra'):
            extra_joints = vertices2joints(self.joints_regressor_extra,
                                           smplx_output.vertices)
            joints = torch.cat([joints, extra_joints], dim=1)

        joints, joint_mask = convert_kps(joints,
                                         src=self.keypoint_src,
                                         dst=self.keypoint_dst,
                                         approximate=self.keypoint_approximate)
        if isinstance(joint_mask, np.ndarray):
            joint_mask = torch.tensor(joint_mask,
                                      dtype=torch.uint8,
                                      device=joints.device)

        # Share one mask row across the whole batch.
        batch_size = joints.shape[0]
        joint_mask = joint_mask.reshape(1, -1).expand(batch_size, -1)

        output = dict(global_orient=smplx_output.global_orient,
                      body_pose=smplx_output.body_pose,
                      joints=joints,
                      joint_mask=joint_mask,
                      keypoints=torch.cat([joints, joint_mask[:, :, None]],
                                          dim=-1),
                      betas=smplx_output.betas)

        if return_verts:
            output['vertices'] = smplx_output.vertices
        if return_full_pose:
            output['full_pose'] = smplx_output.full_pose

        return output

    @classmethod
    def tensor2dict(cls,
                    full_pose: torch.Tensor,
                    betas: Optional[torch.Tensor] = None,
                    transl: Optional[torch.Tensor] = None,
                    expression: Optional[torch.Tensor] = None) -> dict:
        """Convert full pose tensor to pose dict.

        The joints in ``full_pose`` are read in the order: global orient,
        body, jaw, left eye, right eye, left hand, right hand.

        Args:
            full_pose (torch.Tensor): shape should be (..., 165) or
                (..., 55, 3). All zeros for T-pose.
            betas (Optional[torch.Tensor], optional): shape should be
                (..., 10). The batch num should be 1 or corresponds with
                full_pose.
                Defaults to None.
            transl (Optional[torch.Tensor], optional): shape should be
                (..., 3). The batch num should be 1 or corresponds with
                full_pose.
                Defaults to None.
            expression (Optional[torch.Tensor], optional): shape should
                be (..., 10). The batch num should be 1 or corresponds with
                full_pose.
                Defaults to None, in which case zeros of shape
                (batch, 10) are used.

        Returns:
            dict: dict of smplx pose containing transl & betas.
        """
        num_body = cls.NUM_BODY_JOINTS
        num_hand = cls.NUM_HAND_JOINTS
        num_face = cls.NUM_FACE_JOINTS
        num_joints = num_body + 2 * num_hand + num_face

        # ``reshape`` (unlike ``view``) also accepts non-contiguous input.
        full_pose = full_pose.reshape(-1, num_joints + 1, 3)
        batch_size = full_pose.shape[0]

        # Slice boundaries along the joint axis, derived from the class
        # joint counts rather than hard-coded hand sizes.
        body_end = 1 + num_body
        jaw_end = body_end + 1
        leye_end = jaw_end + 1
        reye_end = leye_end + 1
        lhand_end = reye_end + num_hand
        rhand_end = lhand_end + num_hand

        global_orient = full_pose[:, :1]
        body_pose = full_pose[:, 1:body_end]
        jaw_pose = full_pose[:, body_end:jaw_end]
        leye_pose = full_pose[:, jaw_end:leye_end]
        reye_pose = full_pose[:, leye_end:reye_end]
        left_hand_pose = full_pose[:, reye_end:lhand_end]
        right_hand_pose = full_pose[:, lhand_end:rhand_end]

        if betas is not None:
            # Flatten betas to 2 dims and broadcast a single row over the
            # batch.
            betas = betas.reshape(-1, betas.shape[-1])
            if betas.shape[0] == 1:
                betas = betas.repeat(batch_size, 1)

        if transl is not None:
            transl = transl.reshape(batch_size, -1)

        if expression is not None:
            expression = expression.reshape(batch_size, -1)
        else:
            expression = torch.zeros(batch_size, 10, device=full_pose.device)

        return {
            'betas': betas,
            'global_orient': global_orient.reshape(batch_size, 3),
            'body_pose': body_pose.reshape(batch_size, num_body * 3),
            'left_hand_pose': left_hand_pose.reshape(batch_size,
                                                     num_hand * 3),
            'right_hand_pose': right_hand_pose.reshape(
                batch_size, num_hand * 3),
            'transl': transl,
            'expression': expression,
            'jaw_pose': jaw_pose.reshape(batch_size, 3),
            'leye_pose': leye_pose.reshape(batch_size, 3),
            'reye_pose': reye_pose.reshape(batch_size, 3),
        }

    @classmethod
    def dict2tensor(cls, smplx_dict: dict) -> torch.Tensor:
        """Convert smplx pose dict to full pose tensor.

        Missing optional parts (jaw, eyes, hands) are filled with zeros that
        match the device and dtype of ``global_orient``. The input dict is
        not modified.

        Args:
            smplx_dict (dict): smplx pose dict. Must contain
                ``global_orient`` and ``body_pose``; values may be
                torch tensors or numpy arrays.

        Returns:
            torch.Tensor: full pose tensor of shape (batch, 165), ordered
                as global orient, body, jaw, left eye, right eye,
                left hand, right hand.
        """
        assert cls.body_pose_keys.issubset(smplx_dict), (
            f'smplx_dict must contain {sorted(cls.body_pose_keys)}')

        num_body = cls.NUM_BODY_JOINTS
        num_hand = cls.NUM_HAND_JOINTS
        num_face = cls.NUM_FACE_JOINTS
        num_joints = num_body + 2 * num_hand + num_face

        def _as_tensor(value):
            # Convert locally instead of writing back into the caller's
            # dict.
            if isinstance(value, np.ndarray):
                return torch.tensor(value, dtype=torch.float32)
            return value

        global_orient = _as_tensor(smplx_dict['global_orient']).reshape(
            -1, 1, 3)
        body_pose = _as_tensor(smplx_dict['body_pose']).reshape(
            -1, num_body, 3)
        batch_size = global_orient.shape[0]

        def _part(key, part_joints):
            value = smplx_dict.get(key)
            if value is None:
                # Same device/dtype as the given poses so that the
                # concatenation below cannot fail on a mismatch.
                return global_orient.new_zeros((batch_size, part_joints, 3))
            return _as_tensor(value).reshape(-1, part_joints, 3)

        full_pose = torch.cat([
            global_orient,
            body_pose,
            _part('jaw_pose', 1),
            _part('leye_pose', 1),
            _part('reye_pose', 1),
            _part('left_hand_pose', num_hand),
            _part('right_hand_pose', num_hand),
        ],
                              dim=1)
        return full_pose.reshape(-1, (num_joints + 1) * 3)
class SMPLXLayer(_SMPLXLayer):
    """Extension of the official SMPL-X layer implementation.

    Adds optional custom joint regressors and conversion of the regressed
    joints into a configurable keypoint convention.
    """

    # Pose keys that describe the body only.
    body_pose_keys = {'global_orient', 'body_pose'}
    # Pose keys that describe body, hands and face.
    full_pose_keys = {
        'global_orient', 'body_pose', 'left_hand_pose', 'right_hand_pose',
        'jaw_pose', 'leye_pose', 'reye_pose'
    }
    NUM_VERTS = 10475
    NUM_FACES = 20908

    def __init__(self,
                 *args,
                 keypoint_src: str = 'smplx',
                 keypoint_dst: str = 'human_data',
                 keypoint_approximate: bool = False,
                 joints_regressor: Optional[str] = None,
                 extra_joints_regressor: Optional[str] = None,
                 **kwargs):
        """
        Args:
            *args: extra arguments forwarded to the parent SMPL-X layer.
            keypoint_src: source convention of keypoints. This convention
                is used for keypoints obtained from joint regressors.
                Keypoints then undergo conversion into keypoint_dst
                convention.
            keypoint_dst: destination convention of keypoints. This
                convention is used for keypoints in the output.
            keypoint_approximate: whether to use approximate matching in
                convention conversion for keypoints.
            joints_regressor: path to joint regressor. Should be a .npy
                file. If provided, the joints in the output are regressed
                from the vertices with it instead of being taken from the
                parent layer output.
            extra_joints_regressor: path to extra joint regressor. Should
                be a .npy file. If provided, extra joints are regressed and
                concatenated after the other joints.
            **kwargs: extra keyword arguments forwarded to the parent
                SMPL-X layer.

        Returns:
            None
        """
        super().__init__(*args, **kwargs)
        self.keypoint_src = keypoint_src
        self.keypoint_dst = keypoint_dst
        self.keypoint_approximate = keypoint_approximate

        # The buffers are only registered when a path is given; ``forward``
        # relies on their presence (``hasattr``) to decide which joints to
        # use.
        if joints_regressor is not None:
            regressor = torch.tensor(np.load(joints_regressor),
                                     dtype=torch.float)
            self.register_buffer('joints_regressor', regressor)

        if extra_joints_regressor is not None:
            regressor_extra = torch.tensor(np.load(extra_joints_regressor),
                                           dtype=torch.float)
            self.register_buffer('joints_regressor_extra', regressor_extra)

        self.num_verts = self.get_num_verts()
        self.num_joints = get_keypoint_num(convention=self.keypoint_dst)
        self.body_part_segmentation = body_segmentation('smplx')

    def forward(self,
                *args,
                return_verts: bool = True,
                return_full_pose: bool = False,
                **kwargs) -> dict:
        """Forward function.

        Args:
            *args: extra arguments forwarded to the parent forward.
            return_verts: whether to include ``vertices`` in the output.
            return_full_pose: whether to include ``full_pose`` in the
                output.
            **kwargs: extra keyword arguments forwarded to the parent
                forward.

        Returns:
            dict: contains ``global_orient``, ``body_pose``, ``joints``,
                ``joint_mask``, ``keypoints`` (joints with the mask appended
                as the last channel) and ``betas``; plus ``vertices`` and
                ``full_pose`` when requested.
        """
        kwargs['get_skin'] = True
        # Vertices are always requested from the parent because the custom
        # regressors below need them; ``return_verts`` only controls whether
        # they are exposed. ``return_full_pose`` has to be forwarded,
        # otherwise the parent leaves ``full_pose`` unset.
        smplx_output = super().forward(*args,
                                       return_verts=True,
                                       return_full_pose=return_full_pose,
                                       **kwargs)

        if hasattr(self, 'joints_regressor'):
            joints = vertices2joints(self.joints_regressor,
                                     smplx_output.vertices)
        else:
            joints = smplx_output.joints

        if hasattr(self, 'joints_regressor_extra'):
            extra_joints = vertices2joints(self.joints_regressor_extra,
                                           smplx_output.vertices)
            joints = torch.cat([joints, extra_joints], dim=1)

        joints, joint_mask = convert_kps(joints,
                                         src=self.keypoint_src,
                                         dst=self.keypoint_dst,
                                         approximate=self.keypoint_approximate)
        if isinstance(joint_mask, np.ndarray):
            joint_mask = torch.tensor(joint_mask,
                                      dtype=torch.uint8,
                                      device=joints.device)

        # Share one mask row across the whole batch.
        batch_size = joints.shape[0]
        joint_mask = joint_mask.reshape(1, -1).expand(batch_size, -1)

        output = dict(global_orient=smplx_output.global_orient,
                      body_pose=smplx_output.body_pose,
                      joints=joints,
                      joint_mask=joint_mask,
                      keypoints=torch.cat([joints, joint_mask[:, :, None]],
                                          dim=-1),
                      betas=smplx_output.betas)

        if return_verts:
            output['vertices'] = smplx_output.vertices
        if return_full_pose:
            output['full_pose'] = smplx_output.full_pose

        return output