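# Page-header residue from the file's web view, kept as a comment so the
# module parses: ttxskk / update / d7e58f0 (raw, history, blame), 21.9 kB.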
import torch
from mmcv.runner import build_optimizer
from detrsmpl.core.conventions.keypoints_mapping import (
get_keypoint_idx,
get_keypoint_idxs_by_part,
)
from .smplify import OptimizableParameters, SMPLify
class SMPLifyX(SMPLify):
    """Re-implementation of SMPLify-X with extended features.

    - video input
    - 3D keypoints
    """

    def __call__(self,
                 keypoints2d: torch.Tensor = None,
                 keypoints2d_conf: torch.Tensor = None,
                 keypoints3d: torch.Tensor = None,
                 keypoints3d_conf: torch.Tensor = None,
                 init_global_orient: torch.Tensor = None,
                 init_transl: torch.Tensor = None,
                 init_body_pose: torch.Tensor = None,
                 init_betas: torch.Tensor = None,
                 init_left_hand_pose: torch.Tensor = None,
                 init_right_hand_pose: torch.Tensor = None,
                 init_expression: torch.Tensor = None,
                 init_jaw_pose: torch.Tensor = None,
                 init_leye_pose: torch.Tensor = None,
                 init_reye_pose: torch.Tensor = None,
                 return_verts: bool = False,
                 return_joints: bool = False,
                 return_full_pose: bool = False,
                 return_losses: bool = False) -> dict:
        """Run registration.

        Notes:
            B: batch size
            K: number of keypoints
            D: body shape dimension
            D_H: hand pose dimension
            D_E: expression dimension

        Provide only keypoints2d or keypoints3d, not both.

        Args:
            keypoints2d: 2D keypoints of shape (B, K, 2)
            keypoints2d_conf: 2D keypoint confidence of shape (B, K)
            keypoints3d: 3D keypoints of shape (B, K, 3).
            keypoints3d_conf: 3D keypoint confidence of shape (B, K)
            init_global_orient: initial global_orient of shape (B, 3)
            init_transl: initial transl of shape (B, 3)
            init_body_pose: initial body_pose of shape (B, 69)
            init_betas: initial betas of shape (B, D)
            init_left_hand_pose: initial left hand pose of shape (B, D_H)
            init_right_hand_pose: initial right hand pose of shape (B, D_H)
            init_expression: initial expression of shape (B, D_E)
            init_jaw_pose: initial jaw pose of shape (B, 3)
            init_leye_pose: initial left eye pose of shape (B, 3)
            init_reye_pose: initial right eye pose of shape (B, 3)
            return_verts: whether to return vertices (key 'vertices')
            return_joints: whether to return joints (key 'joints')
            return_full_pose: whether to return full pose (key 'full_pose')
            return_losses: whether to return the loss entries computed by
                ``evaluate`` for the fitted parameters

        Returns:
            ret: a dictionary with the fitted body model parameters
                ('global_orient', 'transl', 'body_pose', 'betas',
                'left_hand_pose', 'right_hand_pose', 'expression',
                'jaw_pose', 'leye_pose', 'reye_pose') plus the optional
                outputs requested through the ``return_*`` flags. 'betas'
                has shape (1, D) when ``use_one_betas_per_video`` is set
                and no ``init_betas`` is given.
        """
        assert keypoints2d is not None or keypoints3d is not None, \
            'Neither of 2D nor 3D keypoints are provided.'
        assert not (keypoints2d is not None and keypoints3d is not None), \
            'Do not provide both 2D and 3D keypoints.'
        batch_size = keypoints2d.shape[0] if keypoints2d is not None \
            else keypoints3d.shape[0]

        global_orient = self._match_init_batch_size(
            init_global_orient, self.body_model.global_orient, batch_size)
        transl = self._match_init_batch_size(init_transl,
                                             self.body_model.transl,
                                             batch_size)
        body_pose = self._match_init_batch_size(init_body_pose,
                                                self.body_model.body_pose,
                                                batch_size)
        left_hand_pose = self._match_init_batch_size(
            init_left_hand_pose, self.body_model.left_hand_pose, batch_size)
        right_hand_pose = self._match_init_batch_size(
            init_right_hand_pose, self.body_model.right_hand_pose, batch_size)
        expression = self._match_init_batch_size(init_expression,
                                                 self.body_model.expression,
                                                 batch_size)
        jaw_pose = self._match_init_batch_size(init_jaw_pose,
                                               self.body_model.jaw_pose,
                                               batch_size)
        leye_pose = self._match_init_batch_size(init_leye_pose,
                                                self.body_model.leye_pose,
                                                batch_size)
        reye_pose = self._match_init_batch_size(init_reye_pose,
                                                self.body_model.reye_pose,
                                                batch_size)
        if init_betas is None and self.use_one_betas_per_video:
            # One shape vector shared by all frames; it is passed through
            # ``_expand_betas`` (presumably a per-frame broadcast) whenever
            # the body model is evaluated.
            betas = torch.zeros(1, self.body_model.betas.shape[-1]).to(
                self.device)
        else:
            betas = self._match_init_batch_size(init_betas,
                                                self.body_model.betas,
                                                batch_size)

        # ``_optimize_stage`` returns None: the tensors above are optimized
        # in place, so the same objects are collected afterwards.
        for _ in range(self.num_epochs):
            for stage_config in self.stage_config:
                self._optimize_stage(
                    global_orient=global_orient,
                    transl=transl,
                    body_pose=body_pose,
                    betas=betas,
                    left_hand_pose=left_hand_pose,
                    right_hand_pose=right_hand_pose,
                    expression=expression,
                    jaw_pose=jaw_pose,
                    leye_pose=leye_pose,
                    reye_pose=reye_pose,
                    keypoints2d=keypoints2d,
                    keypoints2d_conf=keypoints2d_conf,
                    keypoints3d=keypoints3d,
                    keypoints3d_conf=keypoints3d_conf,
                    **stage_config,
                )

        ret = {
            'global_orient': global_orient,
            'transl': transl,
            'body_pose': body_pose,
            'betas': betas,
            'left_hand_pose': left_hand_pose,
            'right_hand_pose': right_hand_pose,
            'expression': expression,
            'jaw_pose': jaw_pose,
            'leye_pose': leye_pose,
            'reye_pose': reye_pose
        }

        # Only pay for an extra body model forward pass when an optional
        # output was actually requested.
        output_flags = {
            'vertices': return_verts,
            'joints': return_joints,
            'full_pose': return_full_pose,
        }
        if return_losses or any(output_flags.values()):
            # Fitting is finished, so no autograd graph is needed here.
            with torch.no_grad():
                eval_ret = self.evaluate(
                    # same per-frame betas expansion as used while fitting
                    betas=self._expand_betas(body_pose.shape[0], betas),
                    body_pose=body_pose,
                    global_orient=global_orient,
                    transl=transl,
                    left_hand_pose=left_hand_pose,
                    right_hand_pose=right_hand_pose,
                    expression=expression,
                    jaw_pose=jaw_pose,
                    leye_pose=leye_pose,
                    reye_pose=reye_pose,
                    keypoints2d=keypoints2d,
                    keypoints2d_conf=keypoints2d_conf,
                    keypoints3d=keypoints3d,
                    keypoints3d_conf=keypoints3d_conf,
                    return_verts=return_verts,
                    return_full_pose=return_full_pose,
                    return_joints=return_joints)
            for key, requested in output_flags.items():
                if requested:
                    ret[key] = eval_ret[key]
            if return_losses:
                # ``evaluate`` returns the loss dict merged with the optional
                # outputs; everything that is not an output is a loss entry.
                ret.update({
                    key: value
                    for key, value in eval_ret.items()
                    if key not in output_flags
                })

        return ret

    def _optimize_stage(self,
                        betas: torch.Tensor,
                        body_pose: torch.Tensor,
                        global_orient: torch.Tensor,
                        transl: torch.Tensor,
                        left_hand_pose: torch.Tensor,
                        right_hand_pose: torch.Tensor,
                        expression: torch.Tensor,
                        jaw_pose: torch.Tensor,
                        leye_pose: torch.Tensor,
                        reye_pose: torch.Tensor,
                        fit_global_orient: bool = True,
                        fit_transl: bool = True,
                        fit_body_pose: bool = True,
                        fit_betas: bool = True,
                        fit_left_hand_pose: bool = True,
                        fit_right_hand_pose: bool = True,
                        fit_expression: bool = True,
                        fit_jaw_pose: bool = True,
                        fit_leye_pose: bool = True,
                        fit_reye_pose: bool = True,
                        keypoints2d: torch.Tensor = None,
                        keypoints2d_conf: torch.Tensor = None,
                        keypoints2d_weight: float = None,
                        keypoints3d: torch.Tensor = None,
                        keypoints3d_conf: torch.Tensor = None,
                        keypoints3d_weight: float = None,
                        shape_prior_weight: float = None,
                        joint_prior_weight: float = None,
                        smooth_loss_weight: float = None,
                        pose_prior_weight: float = None,
                        pose_reg_weight: float = None,
                        limb_length_weight: float = None,
                        joint_weights: dict = None,
                        ftol: float = 1e-4,
                        num_iter: int = 1) -> None:
        """Optimize a stage of body model parameters according to
        configuration.

        Notes:
            B: batch size
            K: number of keypoints
            D: shape dimension
            D_H: hand pose dimension
            D_E: expression dimension

        Args:
            betas: shape (B, D)
            body_pose: shape (B, 69)
            global_orient: shape (B, 3)
            transl: shape (B, 3)
            left_hand_pose: shape (B, D_H)
            right_hand_pose: shape (B, D_H)
            expression: shape (B, D_E)
            jaw_pose: shape (B, 3)
            leye_pose: shape (B, 3)
            reye_pose: shape (B, 3)
            fit_global_orient: whether to optimize global_orient
            fit_transl: whether to optimize transl
            fit_body_pose: whether to optimize body_pose
            fit_betas: whether to optimize betas
            fit_left_hand_pose: whether to optimize left hand pose
            fit_right_hand_pose: whether to optimize right hand pose
            fit_expression: whether to optimize expression
            fit_jaw_pose: whether to optimize jaw pose
            fit_leye_pose: whether to optimize left eye pose
            fit_reye_pose: whether to optimize right eye pose
            keypoints2d: 2D keypoints of shape (B, K, 2)
            keypoints2d_conf: 2D keypoint confidence of shape (B, K)
            keypoints2d_weight: weight of 2D keypoint loss
            keypoints3d: 3D keypoints of shape (B, K, 3).
            keypoints3d_conf: 3D keypoint confidence of shape (B, K)
            keypoints3d_weight: weight of 3D keypoint loss
            shape_prior_weight: weight of shape prior loss
            joint_prior_weight: weight of joint prior loss
            smooth_loss_weight: weight of smooth loss
            pose_prior_weight: weight of pose prior loss
            pose_reg_weight: weight of pose regularization loss
            limb_length_weight: weight of limb length loss
            joint_weights: per joint weight of shape (K, ); None is
                treated as an empty dict
            num_iter: number of iterations
            ftol: early stop tolerance for relative change in loss

        Returns:
            None
        """
        # Avoid a shared mutable default; {} matches the former default.
        if joint_weights is None:
            joint_weights = {}

        # Registration order is kept identical to define the same optimizer
        # parameter layout.
        parameters = OptimizableParameters()
        for fit_flag, param in (
            (fit_global_orient, global_orient),
            (fit_transl, transl),
            (fit_body_pose, body_pose),
            (fit_betas, betas),
            (fit_left_hand_pose, left_hand_pose),
            (fit_right_hand_pose, right_hand_pose),
            (fit_expression, expression),
            (fit_jaw_pose, jaw_pose),
            (fit_leye_pose, leye_pose),
            (fit_reye_pose, reye_pose),
        ):
            parameters.set_param(fit_flag, param)

        optimizer = build_optimizer(parameters, self.optimizer)

        def closure():
            """Recompute the total loss and its gradients (optimizer
            closure, which some optimizers call several times per step)."""
            optimizer.zero_grad()
            # expansion must happen inside the closure so gradients flow
            # back into the (possibly shared) betas being optimized
            betas_video = self._expand_betas(body_pose.shape[0], betas)
            loss_dict = self.evaluate(
                global_orient=global_orient,
                body_pose=body_pose,
                betas=betas_video,
                transl=transl,
                left_hand_pose=left_hand_pose,
                right_hand_pose=right_hand_pose,
                expression=expression,
                jaw_pose=jaw_pose,
                leye_pose=leye_pose,
                reye_pose=reye_pose,
                keypoints2d=keypoints2d,
                keypoints2d_conf=keypoints2d_conf,
                keypoints2d_weight=keypoints2d_weight,
                keypoints3d=keypoints3d,
                keypoints3d_conf=keypoints3d_conf,
                keypoints3d_weight=keypoints3d_weight,
                joint_prior_weight=joint_prior_weight,
                shape_prior_weight=shape_prior_weight,
                smooth_loss_weight=smooth_loss_weight,
                pose_prior_weight=pose_prior_weight,
                pose_reg_weight=pose_reg_weight,
                limb_length_weight=limb_length_weight,
                joint_weights=joint_weights)
            loss = loss_dict['total_loss']
            loss.backward()
            return loss

        pre_loss = None
        for iter_idx in range(num_iter):
            loss_value = optimizer.step(closure).item()
            # pre_loss is only None on the first iteration
            if pre_loss is not None and ftol > 0:
                loss_rel_change = self._compute_relative_change(
                    pre_loss, loss_value)
                if loss_rel_change < ftol:
                    print(f'[ftol={ftol}] Early stop at {iter_idx} iter!')
                    break
            pre_loss = loss_value

    def evaluate(
        self,
        betas: torch.Tensor = None,
        body_pose: torch.Tensor = None,
        global_orient: torch.Tensor = None,
        transl: torch.Tensor = None,
        left_hand_pose: torch.Tensor = None,
        right_hand_pose: torch.Tensor = None,
        expression: torch.Tensor = None,
        jaw_pose: torch.Tensor = None,
        leye_pose: torch.Tensor = None,
        reye_pose: torch.Tensor = None,
        keypoints2d: torch.Tensor = None,
        keypoints2d_conf: torch.Tensor = None,
        keypoints2d_weight: float = None,
        keypoints3d: torch.Tensor = None,
        keypoints3d_conf: torch.Tensor = None,
        keypoints3d_weight: float = None,
        shape_prior_weight: float = None,
        joint_prior_weight: float = None,
        smooth_loss_weight: float = None,
        pose_prior_weight: float = None,
        pose_reg_weight: float = None,
        limb_length_weight: float = None,
        joint_weights: dict = None,
        return_verts: bool = False,
        return_full_pose: bool = False,
        return_joints: bool = False,
        reduction_override: str = None,
    ):
        """Evaluate fitted parameters through loss computation. This function
        serves two purposes: 1) internally, for loss backpropagation 2)
        externally, for fitting quality evaluation.

        Notes:
            B: batch size
            K: number of keypoints
            D: body shape dimension
            D_H: hand pose dimension
            D_E: expression dimension

        Args:
            betas: shape (B, D)
            body_pose: shape (B, 69)
            global_orient: shape (B, 3)
            transl: shape (B, 3)
            left_hand_pose: shape (B, D_H)
            right_hand_pose: shape (B, D_H)
            expression: shape (B, D_E)
            jaw_pose: shape (B, 3)
            leye_pose: shape (B, 3)
            reye_pose: shape (B, 3)
            keypoints2d: 2D keypoints of shape (B, K, 2)
            keypoints2d_conf: 2D keypoint confidence of shape (B, K)
            keypoints2d_weight: weight of 2D keypoint loss
            keypoints3d: 3D keypoints of shape (B, K, 3).
            keypoints3d_conf: 3D keypoint confidence of shape (B, K)
            keypoints3d_weight: weight of 3D keypoint loss
            shape_prior_weight: weight of shape prior loss
            joint_prior_weight: weight of joint prior loss
            smooth_loss_weight: weight of smooth loss
            pose_prior_weight: weight of pose prior loss
            pose_reg_weight: weight of pose regularization loss
            limb_length_weight: weight of limb length loss
            joint_weights: per joint weight of shape (K, ); None is
                treated as an empty dict
            return_verts: whether to return vertices
            return_joints: whether to return joints
            return_full_pose: whether to return full pose
            reduction_override: reduction method, e.g., 'none', 'sum', 'mean'

        Returns:
            ret: the loss dict from ``_compute_loss`` (including
                'total_loss'), plus 'vertices', 'full_pose' and 'joints'
                when the corresponding flags are set
        """
        # Avoid a shared mutable default; {} matches the former default.
        if joint_weights is None:
            joint_weights = {}

        body_model_output = self.body_model(global_orient=global_orient,
                                            body_pose=body_pose,
                                            betas=betas,
                                            transl=transl,
                                            left_hand_pose=left_hand_pose,
                                            right_hand_pose=right_hand_pose,
                                            expression=expression,
                                            jaw_pose=jaw_pose,
                                            leye_pose=leye_pose,
                                            reye_pose=reye_pose,
                                            return_verts=return_verts,
                                            return_full_pose=return_full_pose)

        model_joints = body_model_output['joints']
        model_joint_mask = body_model_output['joint_mask']

        loss_dict = self._compute_loss(model_joints,
                                       model_joint_mask,
                                       keypoints2d=keypoints2d,
                                       keypoints2d_conf=keypoints2d_conf,
                                       keypoints2d_weight=keypoints2d_weight,
                                       keypoints3d=keypoints3d,
                                       keypoints3d_conf=keypoints3d_conf,
                                       keypoints3d_weight=keypoints3d_weight,
                                       joint_prior_weight=joint_prior_weight,
                                       shape_prior_weight=shape_prior_weight,
                                       smooth_loss_weight=smooth_loss_weight,
                                       pose_prior_weight=pose_prior_weight,
                                       pose_reg_weight=pose_reg_weight,
                                       limb_length_weight=limb_length_weight,
                                       joint_weights=joint_weights,
                                       reduction_override=reduction_override,
                                       body_pose=body_pose,
                                       betas=betas)

        ret = dict(loss_dict)
        if return_verts:
            ret['vertices'] = body_model_output['vertices']
        if return_full_pose:
            ret['full_pose'] = body_model_output['full_pose']
        if return_joints:
            ret['joints'] = model_joints
        return ret

    def _set_keypoint_idxs(self):
        """Set keypoint indices to 1) body parts to be assigned different
        weights 2) be ignored for keypoint loss computation.

        Returns:
            None
        """
        convention = self.body_model.keypoint_dst

        # Ignored keypoints: only set when ``ignore_keypoints`` is given;
        # an index of -1 (presumably a name missing from the convention)
        # is skipped.
        if self.ignore_keypoints is not None:
            self.ignore_keypoint_idxs = [
                keypoint_idx for keypoint_idx in (
                    get_keypoint_idx(keypoint_name, convention=convention)
                    for keypoint_name in self.ignore_keypoints)
                if keypoint_idx != -1
            ]

        # Body part keypoint indices.
        self.shoulder_hip_keypoint_idxs = [
            *get_keypoint_idxs_by_part('shoulder', convention=convention),
            *get_keypoint_idxs_by_part('hip', convention=convention),
        ]

        # head keypoints include all facial landmarks
        self.face_keypoint_idxs = get_keypoint_idxs_by_part(
            'head', convention=convention)

        self.hand_keypoint_idxs = [
            *get_keypoint_idxs_by_part('left_hand', convention=convention),
            *get_keypoint_idxs_by_part('right_hand', convention=convention),
        ]

        self.body_keypoint_idxs = get_keypoint_idxs_by_part(
            'body', convention=convention)

    def _get_weight(self,
                    use_shoulder_hip_only: bool = False,
                    body_weight: float = 1.0,
                    hand_weight: float = 1.0,
                    face_weight: float = 1.0):
        """Get per keypoint weight.

        Notes:
            K: number of keypoints

        Args:
            use_shoulder_hip_only: whether to use only shoulder and hip
                keypoints for loss computation. This is useful in the
                warming-up stage to find a reasonably good initialization.
            body_weight: weight of body keypoints. Body part segmentation
                definition is included in the HumanData convention.
            hand_weight: weight of hand keypoints.
            face_weight: weight of face keypoints.

        Returns:
            weight: per keypoint weight tensor of shape (K)
        """
        num_keypoint = self.body_model.num_joints

        if use_shoulder_hip_only:
            weight = torch.zeros(num_keypoint, device=self.device)
            weight[self.shoulder_hip_keypoint_idxs] = 1.0
        else:
            weight = torch.ones(num_keypoint, device=self.device)
            # multiplicative, so a keypoint listed in several parts gets
            # the product of their weights
            weight[self.body_keypoint_idxs] *= body_weight
            weight[self.hand_keypoint_idxs] *= hand_weight
            weight[self.face_keypoint_idxs] *= face_weight

        # ``ignore_keypoint_idxs`` only exists when ``ignore_keypoints`` was
        # set in ``_set_keypoint_idxs``.
        if hasattr(self, 'ignore_keypoint_idxs'):
            weight[self.ignore_keypoint_idxs] = 0.0

        return weight