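# NOTE(review): the three lines below are web-page header residue that was
# captured together with this file; they are kept only as a comment.
# Spaces:
# Runtime error
# Runtime error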
import torch | |
from mmcv.runner import build_optimizer | |
from detrsmpl.core.conventions.keypoints_mapping import ( | |
get_keypoint_idx, | |
get_keypoint_idxs_by_part, | |
) | |
from .smplify import OptimizableParameters, SMPLify | |
class SMPLifyX(SMPLify):
    """Re-implementation of SMPLify-X with extended features.

    - video input
    - 3D keypoints

    In addition to the body parameters fitted by :class:`SMPLify`, this
    registrant also fits ``left_hand_pose``, ``right_hand_pose``,
    ``expression``, ``jaw_pose``, ``leye_pose`` and ``reye_pose``.
    """

    def __call__(self,
                 keypoints2d: torch.Tensor = None,
                 keypoints2d_conf: torch.Tensor = None,
                 keypoints3d: torch.Tensor = None,
                 keypoints3d_conf: torch.Tensor = None,
                 init_global_orient: torch.Tensor = None,
                 init_transl: torch.Tensor = None,
                 init_body_pose: torch.Tensor = None,
                 init_betas: torch.Tensor = None,
                 init_left_hand_pose: torch.Tensor = None,
                 init_right_hand_pose: torch.Tensor = None,
                 init_expression: torch.Tensor = None,
                 init_jaw_pose: torch.Tensor = None,
                 init_leye_pose: torch.Tensor = None,
                 init_reye_pose: torch.Tensor = None,
                 return_verts: bool = False,
                 return_joints: bool = False,
                 return_full_pose: bool = False,
                 return_losses: bool = False) -> dict:
        """Run registration.

        Notes:
            B: batch size
            K: number of keypoints
            D: body shape dimension
            D_B: body pose dimension (must match ``self.body_model.body_pose``;
                typically 63 for SMPL-X -- confirm against the body model)
            D_H: hand pose dimension
            D_E: expression dimension

        Provide only keypoints2d or keypoints3d, not both.

        Args:
            keypoints2d: 2D keypoints of shape (B, K, 2)
            keypoints2d_conf: 2D keypoint confidence of shape (B, K)
            keypoints3d: 3D keypoints of shape (B, K, 3).
            keypoints3d_conf: 3D keypoint confidence of shape (B, K)
            init_global_orient: initial global_orient of shape (B, 3)
            init_transl: initial transl of shape (B, 3)
            init_body_pose: initial body_pose of shape (B, D_B)
            init_betas: initial betas of shape (B, D)
            init_left_hand_pose: initial left hand pose of shape (B, D_H)
            init_right_hand_pose: initial right hand pose of shape (B, D_H)
            init_expression: initial expression of shape (B, D_E)
            init_jaw_pose: initial jaw pose of shape (B, 3)
            init_leye_pose: initial left eye pose of shape (B, 3)
            init_reye_pose: initial right eye pose of shape (B, 3)
            return_verts: whether to return vertices
            return_joints: whether to return joints
            return_full_pose: whether to return full pose
            return_losses: whether to return the loss terms

        Any ``init_*`` left as None falls back to the corresponding
        parameter of ``self.body_model``, matched to the batch size.
        When ``init_betas`` is None and ``self.use_one_betas_per_video`` is
        set, a single zero betas of shape (1, D) is shared by all frames.

        Returns:
            ret: a dictionary with the fitted body model parameters
                (``global_orient``, ``transl``, ``body_pose``, ``betas``,
                ``left_hand_pose``, ``right_hand_pose``, ``expression``,
                ``jaw_pose``, ``leye_pose``, ``reye_pose``). If any
                ``return_*`` flag is set, the fitted parameters are
                evaluated once more without gradients and the requested
                ``vertices`` / ``joints`` / ``full_pose`` entries and, for
                ``return_losses``, every loss entry produced by
                :meth:`evaluate` are added.
        """
        assert keypoints2d is not None or keypoints3d is not None, \
            'Neither of 2D nor 3D keypoints are provided.'
        assert not (keypoints2d is not None and keypoints3d is not None), \
            'Do not provide both 2D and 3D keypoints.'
        batch_size = keypoints2d.shape[0] if keypoints2d is not None \
            else keypoints3d.shape[0]

        global_orient = self._match_init_batch_size(
            init_global_orient, self.body_model.global_orient, batch_size)
        transl = self._match_init_batch_size(
            init_transl, self.body_model.transl, batch_size)
        body_pose = self._match_init_batch_size(
            init_body_pose, self.body_model.body_pose, batch_size)
        left_hand_pose = self._match_init_batch_size(
            init_left_hand_pose, self.body_model.left_hand_pose, batch_size)
        right_hand_pose = self._match_init_batch_size(
            init_right_hand_pose, self.body_model.right_hand_pose, batch_size)
        expression = self._match_init_batch_size(
            init_expression, self.body_model.expression, batch_size)
        jaw_pose = self._match_init_batch_size(
            init_jaw_pose, self.body_model.jaw_pose, batch_size)
        leye_pose = self._match_init_batch_size(
            init_leye_pose, self.body_model.leye_pose, batch_size)
        reye_pose = self._match_init_batch_size(
            init_reye_pose, self.body_model.reye_pose, batch_size)

        if init_betas is None and self.use_one_betas_per_video:
            # One shape vector for the whole sequence; it is expanded to
            # every frame inside the loss closure.
            betas = torch.zeros(1, self.body_model.betas.shape[-1]).to(
                self.device)
        else:
            betas = self._match_init_batch_size(init_betas,
                                                self.body_model.betas,
                                                batch_size)

        # _optimize_stage returns nothing: it updates the tensors above in
        # place through the optimizer it builds over them.
        for _ in range(self.num_epochs):
            for stage_config in self.stage_config:
                self._optimize_stage(
                    global_orient=global_orient,
                    transl=transl,
                    body_pose=body_pose,
                    betas=betas,
                    left_hand_pose=left_hand_pose,
                    right_hand_pose=right_hand_pose,
                    expression=expression,
                    jaw_pose=jaw_pose,
                    leye_pose=leye_pose,
                    reye_pose=reye_pose,
                    keypoints2d=keypoints2d,
                    keypoints2d_conf=keypoints2d_conf,
                    keypoints3d=keypoints3d,
                    keypoints3d_conf=keypoints3d_conf,
                    **stage_config,
                )

        ret = {
            'global_orient': global_orient,
            'transl': transl,
            'body_pose': body_pose,
            'betas': betas,
            'left_hand_pose': left_hand_pose,
            'right_hand_pose': right_hand_pose,
            'expression': expression,
            'jaw_pose': jaw_pose,
            'leye_pose': leye_pose,
            'reye_pose': reye_pose
        }

        if return_verts or return_joints or return_full_pose or return_losses:
            # Final read-only evaluation of the fitted parameters; no graph
            # is needed since nothing is back-propagated here.
            with torch.no_grad():
                eval_ret = self.evaluate(
                    betas=self._expand_betas(body_pose.shape[0], betas),
                    body_pose=body_pose,
                    global_orient=global_orient,
                    transl=transl,
                    left_hand_pose=left_hand_pose,
                    right_hand_pose=right_hand_pose,
                    expression=expression,
                    jaw_pose=jaw_pose,
                    leye_pose=leye_pose,
                    reye_pose=reye_pose,
                    keypoints2d=keypoints2d,
                    keypoints2d_conf=keypoints2d_conf,
                    keypoints3d=keypoints3d,
                    keypoints3d_conf=keypoints3d_conf,
                    return_verts=return_verts,
                    return_full_pose=return_full_pose,
                    return_joints=return_joints)
            output_keys = ('vertices', 'joints', 'full_pose')
            if return_verts:
                ret['vertices'] = eval_ret['vertices']
            if return_joints:
                ret['joints'] = eval_ret['joints']
            if return_full_pose:
                ret['full_pose'] = eval_ret['full_pose']
            if return_losses:
                # Everything evaluate() returns besides the model outputs
                # comes from the loss dict (e.g. 'total_loss').
                ret.update({
                    key: value
                    for key, value in eval_ret.items()
                    if key not in output_keys
                })

        return ret

    def _optimize_stage(self,
                        betas: torch.Tensor,
                        body_pose: torch.Tensor,
                        global_orient: torch.Tensor,
                        transl: torch.Tensor,
                        left_hand_pose: torch.Tensor,
                        right_hand_pose: torch.Tensor,
                        expression: torch.Tensor,
                        jaw_pose: torch.Tensor,
                        leye_pose: torch.Tensor,
                        reye_pose: torch.Tensor,
                        fit_global_orient: bool = True,
                        fit_transl: bool = True,
                        fit_body_pose: bool = True,
                        fit_betas: bool = True,
                        fit_left_hand_pose: bool = True,
                        fit_right_hand_pose: bool = True,
                        fit_expression: bool = True,
                        fit_jaw_pose: bool = True,
                        fit_leye_pose: bool = True,
                        fit_reye_pose: bool = True,
                        keypoints2d: torch.Tensor = None,
                        keypoints2d_conf: torch.Tensor = None,
                        keypoints2d_weight: float = None,
                        keypoints3d: torch.Tensor = None,
                        keypoints3d_conf: torch.Tensor = None,
                        keypoints3d_weight: float = None,
                        shape_prior_weight: float = None,
                        joint_prior_weight: float = None,
                        smooth_loss_weight: float = None,
                        pose_prior_weight: float = None,
                        pose_reg_weight: float = None,
                        limb_length_weight: float = None,
                        joint_weights: dict = None,
                        ftol: float = 1e-4,
                        num_iter: int = 1) -> None:
        """Optimize a stage of body model parameters according to
        configuration.

        Notes:
            B: batch size
            K: number of keypoints
            D: shape dimension
            D_B: body pose dimension
            D_H: hand pose dimension
            D_E: expression dimension

        Args:
            betas: shape (B, D), or (1, D) when shared across a video
            body_pose: shape (B, D_B)
            global_orient: shape (B, 3)
            transl: shape (B, 3)
            left_hand_pose: shape (B, D_H)
            right_hand_pose: shape (B, D_H)
            expression: shape (B, D_E)
            jaw_pose: shape (B, 3)
            leye_pose: shape (B, 3)
            reye_pose: shape (B, 3)
            fit_global_orient: whether to optimize global_orient
            fit_transl: whether to optimize transl
            fit_body_pose: whether to optimize body_pose
            fit_betas: whether to optimize betas
            fit_left_hand_pose: whether to optimize left hand pose
            fit_right_hand_pose: whether to optimize right hand pose
            fit_expression: whether to optimize expression
            fit_jaw_pose: whether to optimize jaw pose
            fit_leye_pose: whether to optimize left eye pose
            fit_reye_pose: whether to optimize right eye pose
            keypoints2d: 2D keypoints of shape (B, K, 2)
            keypoints2d_conf: 2D keypoint confidence of shape (B, K)
            keypoints2d_weight: weight of 2D keypoint loss
            keypoints3d: 3D keypoints of shape (B, K, 3).
            keypoints3d_conf: 3D keypoint confidence of shape (B, K)
            keypoints3d_weight: weight of 3D keypoint loss
            shape_prior_weight: weight of shape prior loss
            joint_prior_weight: weight of joint prior loss
            smooth_loss_weight: weight of smooth loss
            pose_prior_weight: weight of pose prior loss
            pose_reg_weight: weight of pose regularization loss
            limb_length_weight: weight of limb length loss
            joint_weights: per-part keypoint weight options forwarded to
                the loss computation; None is treated as ``{}``
            ftol: early stop tolerance for relative change in loss; a
                non-positive value disables early stopping
            num_iter: number of iterations

        Returns:
            None. The passed tensors are updated in place.
        """
        if joint_weights is None:
            joint_weights = {}

        parameters = OptimizableParameters()
        for fit_flag, param in (
                (fit_global_orient, global_orient),
                (fit_transl, transl),
                (fit_body_pose, body_pose),
                (fit_betas, betas),
                (fit_left_hand_pose, left_hand_pose),
                (fit_right_hand_pose, right_hand_pose),
                (fit_expression, expression),
                (fit_jaw_pose, jaw_pose),
                (fit_leye_pose, leye_pose),
                (fit_reye_pose, reye_pose),
        ):
            parameters.set_param(fit_flag, param)

        optimizer = build_optimizer(parameters, self.optimizer)

        def closure():
            # Re-evaluated by the optimizer (possibly several times per
            # step, e.g. for line-search optimizers).
            optimizer.zero_grad()
            betas_video = self._expand_betas(body_pose.shape[0], betas)
            loss_dict = self.evaluate(
                global_orient=global_orient,
                body_pose=body_pose,
                betas=betas_video,
                transl=transl,
                left_hand_pose=left_hand_pose,
                right_hand_pose=right_hand_pose,
                expression=expression,
                jaw_pose=jaw_pose,
                leye_pose=leye_pose,
                reye_pose=reye_pose,
                keypoints2d=keypoints2d,
                keypoints2d_conf=keypoints2d_conf,
                keypoints2d_weight=keypoints2d_weight,
                keypoints3d=keypoints3d,
                keypoints3d_conf=keypoints3d_conf,
                keypoints3d_weight=keypoints3d_weight,
                joint_prior_weight=joint_prior_weight,
                shape_prior_weight=shape_prior_weight,
                smooth_loss_weight=smooth_loss_weight,
                pose_prior_weight=pose_prior_weight,
                pose_reg_weight=pose_reg_weight,
                limb_length_weight=limb_length_weight,
                joint_weights=joint_weights)
            loss = loss_dict['total_loss']
            loss.backward()
            return loss

        pre_loss = None
        for iter_idx in range(num_iter):
            cur_loss = optimizer.step(closure).item()
            # pre_loss is only set after the first iteration, so this
            # compares consecutive iterations.
            if pre_loss is not None and ftol > 0:
                loss_rel_change = self._compute_relative_change(
                    pre_loss, cur_loss)
                if loss_rel_change < ftol:
                    print(f'[ftol={ftol}] Early stop at {iter_idx} iter!')
                    break
            pre_loss = cur_loss

    def evaluate(
        self,
        betas: torch.Tensor = None,
        body_pose: torch.Tensor = None,
        global_orient: torch.Tensor = None,
        transl: torch.Tensor = None,
        left_hand_pose: torch.Tensor = None,
        right_hand_pose: torch.Tensor = None,
        expression: torch.Tensor = None,
        jaw_pose: torch.Tensor = None,
        leye_pose: torch.Tensor = None,
        reye_pose: torch.Tensor = None,
        keypoints2d: torch.Tensor = None,
        keypoints2d_conf: torch.Tensor = None,
        keypoints2d_weight: float = None,
        keypoints3d: torch.Tensor = None,
        keypoints3d_conf: torch.Tensor = None,
        keypoints3d_weight: float = None,
        shape_prior_weight: float = None,
        joint_prior_weight: float = None,
        smooth_loss_weight: float = None,
        pose_prior_weight: float = None,
        pose_reg_weight: float = None,
        limb_length_weight: float = None,
        joint_weights: dict = None,
        return_verts: bool = False,
        return_full_pose: bool = False,
        return_joints: bool = False,
        reduction_override: str = None,
    ):
        """Evaluate fitted parameters through loss computation. This function
        serves two purposes: 1) internally, for loss backpropagation 2)
        externally, for fitting quality evaluation.

        Notes:
            B: batch size
            K: number of keypoints
            D: body shape dimension
            D_B: body pose dimension
            D_H: hand pose dimension
            D_E: expression dimension

        Args:
            betas: shape (B, D)
            body_pose: shape (B, D_B)
            global_orient: shape (B, 3)
            transl: shape (B, 3)
            left_hand_pose: shape (B, D_H)
            right_hand_pose: shape (B, D_H)
            expression: shape (B, D_E)
            jaw_pose: shape (B, 3)
            leye_pose: shape (B, 3)
            reye_pose: shape (B, 3)
            keypoints2d: 2D keypoints of shape (B, K, 2)
            keypoints2d_conf: 2D keypoint confidence of shape (B, K)
            keypoints2d_weight: weight of 2D keypoint loss
            keypoints3d: 3D keypoints of shape (B, K, 3).
            keypoints3d_conf: 3D keypoint confidence of shape (B, K)
            keypoints3d_weight: weight of 3D keypoint loss
            shape_prior_weight: weight of shape prior loss
            joint_prior_weight: weight of joint prior loss
            smooth_loss_weight: weight of smooth loss
            pose_prior_weight: weight of pose prior loss
            pose_reg_weight: weight of pose regularization loss
            limb_length_weight: weight of limb length loss
            joint_weights: per-part keypoint weight options forwarded to
                the loss computation; None is treated as ``{}``
            return_verts: whether to return vertices
            return_joints: whether to return joints
            return_full_pose: whether to return full pose
            reduction_override: reduction method, e.g., 'none', 'sum', 'mean'

        Returns:
            ret: every entry of the loss dict from ``_compute_loss``
                (including ``'total_loss'``), plus ``'vertices'``,
                ``'full_pose'`` and ``'joints'`` when requested.
        """
        if joint_weights is None:
            joint_weights = {}

        body_model_output = self.body_model(global_orient=global_orient,
                                            body_pose=body_pose,
                                            betas=betas,
                                            transl=transl,
                                            left_hand_pose=left_hand_pose,
                                            right_hand_pose=right_hand_pose,
                                            expression=expression,
                                            jaw_pose=jaw_pose,
                                            leye_pose=leye_pose,
                                            reye_pose=reye_pose,
                                            return_verts=return_verts,
                                            return_full_pose=return_full_pose)
        model_joints = body_model_output['joints']
        model_joint_mask = body_model_output['joint_mask']

        loss_dict = self._compute_loss(model_joints,
                                       model_joint_mask,
                                       keypoints2d=keypoints2d,
                                       keypoints2d_conf=keypoints2d_conf,
                                       keypoints2d_weight=keypoints2d_weight,
                                       keypoints3d=keypoints3d,
                                       keypoints3d_conf=keypoints3d_conf,
                                       keypoints3d_weight=keypoints3d_weight,
                                       joint_prior_weight=joint_prior_weight,
                                       shape_prior_weight=shape_prior_weight,
                                       smooth_loss_weight=smooth_loss_weight,
                                       pose_prior_weight=pose_prior_weight,
                                       pose_reg_weight=pose_reg_weight,
                                       limb_length_weight=limb_length_weight,
                                       joint_weights=joint_weights,
                                       reduction_override=reduction_override,
                                       body_pose=body_pose,
                                       betas=betas)

        ret = dict(loss_dict)
        if return_verts:
            ret['vertices'] = body_model_output['vertices']
        if return_full_pose:
            ret['full_pose'] = body_model_output['full_pose']
        if return_joints:
            ret['joints'] = model_joints
        return ret

    def _set_keypoint_idxs(self):
        """Set keypoint indices to 1) body parts to be assigned different
        weights 2) be ignored for keypoint loss computation.

        Indices are resolved in the body model's ``keypoint_dst``
        convention. Ignored keypoint names that the convention does not
        contain (index -1) are skipped. ``ignore_keypoint_idxs`` is only
        (re)assigned when ``self.ignore_keypoints`` is not None.

        Returns:
            None
        """
        convention = self.body_model.keypoint_dst

        # obtain ignore keypoint indices
        if self.ignore_keypoints is not None:
            resolved = (get_keypoint_idx(name, convention=convention)
                        for name in self.ignore_keypoints)
            self.ignore_keypoint_idxs = [idx for idx in resolved if idx != -1]

        # obtain body part keypoint indices
        shoulder_keypoint_idxs = get_keypoint_idxs_by_part(
            'shoulder', convention=convention)
        hip_keypoint_idxs = get_keypoint_idxs_by_part('hip',
                                                      convention=convention)
        self.shoulder_hip_keypoint_idxs = [
            *shoulder_keypoint_idxs, *hip_keypoint_idxs
        ]

        # head keypoints include all facial landmarks
        self.face_keypoint_idxs = get_keypoint_idxs_by_part(
            'head', convention=convention)

        left_hand_keypoint_idxs = get_keypoint_idxs_by_part(
            'left_hand', convention=convention)
        right_hand_keypoint_idxs = get_keypoint_idxs_by_part(
            'right_hand', convention=convention)
        self.hand_keypoint_idxs = [
            *left_hand_keypoint_idxs, *right_hand_keypoint_idxs
        ]

        self.body_keypoint_idxs = get_keypoint_idxs_by_part(
            'body', convention=convention)

    def _get_weight(self,
                    use_shoulder_hip_only: bool = False,
                    body_weight: float = 1.0,
                    hand_weight: float = 1.0,
                    face_weight: float = 1.0):
        """Get per keypoint weight.

        Notes:
            K: number of keypoints

        Args:
            use_shoulder_hip_only: whether to use only shoulder and hip
                keypoints for loss computation. This is useful in the
                warming-up stage to find a reasonably good initialization.
            body_weight: weight of body keypoints. Body part segmentation
                definition is included in the HumanData convention.
            hand_weight: weight of hand keypoints.
            face_weight: weight of face keypoints.

        Requires the index lists set by ``_set_keypoint_idxs``. Keypoints
        in ``ignore_keypoint_idxs`` (if set) always end with weight 0.

        Returns:
            weight: per keypoint weight tensor of shape (K)
        """
        num_keypoint = self.body_model.num_joints

        if use_shoulder_hip_only:
            weight = torch.zeros(num_keypoint, device=self.device)
            weight[self.shoulder_hip_keypoint_idxs] = 1.0
        else:
            weight = torch.ones(num_keypoint, device=self.device)
            weight[self.body_keypoint_idxs] *= body_weight
            weight[self.hand_keypoint_idxs] *= hand_weight
            weight[self.face_keypoint_idxs] *= face_weight

        if hasattr(self, 'ignore_keypoint_idxs'):
            weight[self.ignore_keypoint_idxs] = 0.0

        return weight