from abc import ABCMeta, abstractmethod
from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from detrsmpl.core.conventions.keypoints_mapping import (
    get_keypoint_idx,
    get_keypoint_idxs_by_part,
)
from detrsmpl.utils.geometry import (
    batch_rodrigues,
    weak_perspective_projection,
)

from ..backbones.builder import build_backbone
from ..body_models.builder import build_body_model
from ..heads.builder import build_head
from ..losses.builder import build_loss
from ..necks.builder import build_neck
from ..utils import (
    SMPLXFaceCropFunc,
    SMPLXFaceMergeFunc,
    SMPLXHandCropFunc,
    SMPLXHandMergeFunc,
)
from .base_architecture import BaseArchitecture


def set_requires_grad(nets, requires_grad=False):
    """Set requires_grad for all the networks.

    Args:
        nets (nn.Module | list[nn.Module]): A list of networks or a single
            network.
        requires_grad (bool): Whether the networks require gradients or not.
    """
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad


def pose2rotmat(pred_pose):
    """Convert axis-angle poses to rotation matrices.

    A tensor of shape (B, J, 3) is converted to (B, J, 3, 3); inputs that are
    already rotation matrices are returned unchanged.
    """
    if len(pred_pose.shape) == 3:
        num_joints = pred_pose.shape[1]
        pred_pose = batch_rodrigues(pred_pose.view(-1, 3)).view(
            -1, num_joints, 3, 3)
    return pred_pose
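

# A minimal usage sketch of `pose2rotmat` (illustrative only): a batch of
# axis-angle poses of shape (B, J, 3) becomes (B, J, 3, 3) rotation matrices,
# while inputs that are already 4-D pass through unchanged, e.g.:
#
#     aa_pose = torch.zeros(2, 21, 3)   # zero rotation for every joint
#     rotmats = pose2rotmat(aa_pose)    # -> torch.Size([2, 21, 3, 3])
#     assert torch.allclose(rotmats[0, 0], torch.eye(3))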


class SMPLXBodyModelEstimator(BaseArchitecture, metaclass=ABCMeta):
    """BodyModelEstimator Architecture.

    Args:
        backbone (dict | None, optional): Backbone config dict. Default: None.
        neck (dict | None, optional): Neck config dict. Default: None.
        head (dict | None, optional): Regressor config dict. Default: None.
        body_model_train (dict | None, optional): SMPL-X config dict used
            during training. Default: None.
        body_model_test (dict | None, optional): SMPL-X config dict used
            during testing. Default: None.
        convention (str, optional): Keypoints convention. Default: "human_data".
        loss_keypoints2d (dict | None, optional): Losses config dict for
            2D keypoints. Default: None.
        loss_keypoints3d (dict | None, optional): Losses config dict for
            3D keypoints. Default: None.
        loss_smplx_global_orient (dict | None, optional): Losses config dict
            for smplx global orient. Default: None.
        loss_smplx_body_pose (dict | None, optional): Losses config dict
            for smplx body pose. Default: None.
        loss_smplx_hand_pose (dict | None, optional): Losses config dict
            for smplx hand pose. Default: None.
        loss_smplx_jaw_pose (dict | None, optional): Losses config dict
            for smplx jaw pose. Default: None.
        loss_smplx_expression (dict | None, optional): Losses config dict
            for smplx expression. Default: None.
        loss_smplx_betas (dict | None, optional): Losses config dict for
            smplx betas. Default: None.
        loss_smplx_betas_prior (dict | None, optional): Losses config dict for
            the prior on smplx betas. Default: None.
        loss_camera (dict | None, optional): Losses config dict for predicted
            camera parameters. Default: None.
        extra_hand_model_cfg (dict | None, optional): Hand model config for
            refining the body model prediction. Default: None.
        extra_face_model_cfg (dict | None, optional): Face model config for
            refining the body model prediction. Default: None.
        frozen_batchnorm (bool, optional): Whether to freeze all BatchNorm
            layers in the backbone and head. Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 backbone: Optional[dict] = None,
                 neck: Optional[dict] = None,
                 head: Optional[dict] = None,
                 body_model_train: Optional[dict] = None,
                 body_model_test: Optional[dict] = None,
                 convention: str = 'human_data',
                 loss_keypoints2d: Optional[dict] = None,
                 loss_keypoints3d: Optional[dict] = None,
                 loss_smplx_global_orient: Optional[dict] = None,
                 loss_smplx_body_pose: Optional[dict] = None,
                 loss_smplx_hand_pose: Optional[dict] = None,
                 loss_smplx_jaw_pose: Optional[dict] = None,
                 loss_smplx_expression: Optional[dict] = None,
                 loss_smplx_betas: Optional[dict] = None,
                 loss_smplx_betas_prior: Optional[dict] = None,
                 loss_camera: Optional[dict] = None,
                 extra_hand_model_cfg: Optional[dict] = None,
                 extra_face_model_cfg: Optional[dict] = None,
                 frozen_batchnorm: bool = False,
                 init_cfg: Optional[Union[list, dict]] = None):
        super(SMPLXBodyModelEstimator, self).__init__(init_cfg)
        self.backbone = build_backbone(backbone)
        self.neck = build_neck(neck)
        self.head = build_head(head)
        if frozen_batchnorm:
            # Freeze the main backbone and head and replace their BatchNorm
            # layers with FrozenBatchNorm2d so the running statistics and
            # affine parameters stay fixed.
            for param in self.backbone.parameters():
                param.requires_grad = False
            for param in self.head.parameters():
                param.requires_grad = False
            self.backbone = FrozenBatchNorm2d.convert_frozen_batchnorm(
                self.backbone)
            self.head = FrozenBatchNorm2d.convert_frozen_batchnorm(self.head)
        self.body_model_train = build_body_model(body_model_train)
        self.body_model_test = build_body_model(body_model_test)
        self.convention = convention

        self.apply_hand_model = False
        self.apply_face_model = False
        if extra_hand_model_cfg is not None:
            # Optional hand sub-network that refines the body prediction
            # from hand crops.
            self.hand_backbone = build_backbone(
                extra_hand_model_cfg.get('backbone', None))
            self.hand_neck = build_neck(extra_hand_model_cfg.get('neck', None))
            self.hand_head = build_head(extra_hand_model_cfg.get('head', None))
            crop_cfg = extra_hand_model_cfg.get('crop_cfg', None)
            if crop_cfg is not None:
                self.crop_hand_func = SMPLXHandCropFunc(
                    self.hand_head,
                    self.body_model_train,
                    convention=self.convention,
                    **crop_cfg)
            self.hand_merge_func = SMPLXHandMergeFunc(self.body_model_train,
                                                      self.convention)
            self.hand_crop_loss = build_loss(
                extra_hand_model_cfg.get('loss_hand_crop', None))
            self.apply_hand_model = True

            self.left_hand_idxs = get_keypoint_idxs_by_part(
                'left_hand', self.convention)
            self.left_hand_idxs.append(
                get_keypoint_idx('left_wrist', self.convention))
            self.left_hand_idxs = sorted(self.left_hand_idxs)
            self.right_hand_idxs = get_keypoint_idxs_by_part(
                'right_hand', self.convention)
            self.right_hand_idxs.append(
                get_keypoint_idx('right_wrist', self.convention))
            self.right_hand_idxs = sorted(self.right_hand_idxs)
        if extra_face_model_cfg is not None:
            # Optional face sub-network that refines the body prediction
            # from face crops.
            self.face_backbone = build_backbone(
                extra_face_model_cfg.get('backbone', None))
            self.face_neck = build_neck(extra_face_model_cfg.get('neck', None))
            self.face_head = build_head(extra_face_model_cfg.get('head', None))
            crop_cfg = extra_face_model_cfg.get('crop_cfg', None)
            if crop_cfg is not None:
                self.crop_face_func = SMPLXFaceCropFunc(
                    self.face_head,
                    self.body_model_train,
                    convention=self.convention,
                    **crop_cfg)
            self.face_merge_func = SMPLXFaceMergeFunc(self.body_model_train,
                                                      self.convention)
            self.face_crop_loss = build_loss(
                extra_face_model_cfg.get('loss_face_crop', None))
            self.apply_face_model = True

            self.face_idxs = get_keypoint_idxs_by_part('head', self.convention)
            self.face_idxs = sorted(self.face_idxs)

        self.loss_keypoints2d = build_loss(loss_keypoints2d)
        self.loss_keypoints3d = build_loss(loss_keypoints3d)
        self.loss_smplx_global_orient = build_loss(loss_smplx_global_orient)
        self.loss_smplx_body_pose = build_loss(loss_smplx_body_pose)
        self.loss_smplx_hand_pose = build_loss(loss_smplx_hand_pose)
        self.loss_smplx_jaw_pose = build_loss(loss_smplx_jaw_pose)
        self.loss_smplx_expression = build_loss(loss_smplx_expression)
        self.loss_smplx_betas = build_loss(loss_smplx_betas)
        self.loss_smplx_betas_prior = build_loss(loss_smplx_betas_prior)
        self.loss_camera = build_loss(loss_camera)
        set_requires_grad(self.body_model_train, False)
        set_requires_grad(self.body_model_test, False)
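
    # A minimal, hypothetical construction sketch (the concrete dict keys and
    # registered module names depend on this project's builder registries; the
    # names below are assumptions, not the canonical configuration):
    #
    #     estimator = SMPLXImageBodyModelEstimator(
    #         backbone=dict(type='ResNet', depth=50),
    #         head=dict(type='SMPLXHead'),
    #         body_model_train=dict(type='SMPLX', model_path='data/body_models/smplx'),
    #         loss_keypoints3d=dict(type='L1Loss', loss_weight=1.0),
    #     )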

    def train_step(self, data_batch, optimizer, **kwargs):
        """Train step function.

        Args:
            data_batch (dict): Batch of data as input.
            optimizer (dict[torch.optim.Optimizer]): Dict with optimizers for
                each trainable sub-module.

        Returns:
            outputs (dict): Dict with loss, information for logger,
                the number of samples.
        """
        if self.backbone is not None:
            img = data_batch['img']
            features = self.backbone(img)
        else:
            features = data_batch['features']

        if self.neck is not None:
            features = self.neck(features)
        predictions = self.head(features)

        if self.apply_hand_model:
            hand_input_img, hand_mean, hand_crop_info = self.crop_hand_func(
                predictions, data_batch['img_metas'])
            hand_features = self.hand_backbone(hand_input_img)
            if self.hand_neck is not None:
                hand_features = self.hand_neck(hand_features)
            hand_predictions = self.hand_head(hand_features, cond=hand_mean)
            predictions = self.hand_merge_func(predictions, hand_predictions)
            predictions['hand_crop_info'] = hand_crop_info
        if self.apply_face_model:
            face_input_img, face_mean, face_crop_info = self.crop_face_func(
                predictions, data_batch['img_metas'])
            face_features = self.face_backbone(face_input_img)
            if self.face_neck is not None:
                face_features = self.face_neck(face_features)
            face_predictions = self.face_head(face_features, cond=face_mean)
            predictions = self.face_merge_func(predictions, face_predictions)
            predictions['face_crop_info'] = face_crop_info

        targets = self.prepare_targets(data_batch)
        losses = self.compute_losses(predictions, targets)
        loss, log_vars = self._parse_losses(losses)

        if self.backbone is not None:
            optimizer['backbone'].zero_grad()
        if self.neck is not None:
            optimizer['neck'].zero_grad()
        if self.head is not None:
            optimizer['head'].zero_grad()
        if self.apply_hand_model:
            if self.hand_backbone is not None:
                optimizer['hand_backbone'].zero_grad()
            if self.hand_neck is not None:
                optimizer['hand_neck'].zero_grad()
            if self.hand_head is not None:
                optimizer['hand_head'].zero_grad()
        if self.apply_face_model:
            if self.face_backbone is not None:
                optimizer['face_backbone'].zero_grad()
            if self.face_neck is not None:
                optimizer['face_neck'].zero_grad()
            if self.face_head is not None:
                optimizer['face_head'].zero_grad()

        loss.backward()

        if self.backbone is not None:
            optimizer['backbone'].step()
        if self.neck is not None:
            optimizer['neck'].step()
        if self.head is not None:
            optimizer['head'].step()
        if self.apply_hand_model:
            if self.hand_backbone is not None:
                optimizer['hand_backbone'].step()
            if self.hand_neck is not None:
                optimizer['hand_neck'].step()
            if self.hand_head is not None:
                optimizer['hand_head'].step()
        if self.apply_face_model:
            if self.face_backbone is not None:
                optimizer['face_backbone'].step()
            if self.face_neck is not None:
                optimizer['face_neck'].step()
            if self.face_head is not None:
                optimizer['face_head'].step()

        outputs = dict(loss=loss,
                       log_vars=log_vars,
                       num_samples=len(next(iter(data_batch.values()))))
        return outputs
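
    # Note: `optimizer` is expected to be a dict of per-submodule optimizers
    # keyed by the names used above ('backbone', 'neck', 'head', and, when the
    # extra models are enabled, 'hand_backbone', 'hand_neck', 'hand_head',
    # 'face_backbone', 'face_neck', 'face_head').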

    def compute_keypoints3d_loss(
            self,
            pred_keypoints3d: torch.Tensor,
            gt_keypoints3d: torch.Tensor,
            has_keypoints3d: Optional[torch.Tensor] = None):
        """Compute loss for 3d keypoints."""
        keypoints3d_conf = gt_keypoints3d[:, :, 3].float().unsqueeze(-1)
        keypoints3d_conf = keypoints3d_conf.repeat(1, 1, 3)
        pred_keypoints3d = pred_keypoints3d.float()
        gt_keypoints3d = gt_keypoints3d[:, :, :3].float()
        if has_keypoints3d is None:
            has_keypoints3d = torch.ones((keypoints3d_conf.shape[0]))
        if keypoints3d_conf[has_keypoints3d == 1].numel() == 0:
            return torch.Tensor([0]).type_as(gt_keypoints3d)
        # Center the predictions using the pelvis
        target_idxs = has_keypoints3d == 1
        pred_keypoints3d = pred_keypoints3d[target_idxs]
        gt_keypoints3d = gt_keypoints3d[target_idxs]
        pred_pelvis = pred_keypoints3d[:, [1, 2], :].mean(dim=1, keepdim=True)
        pred_keypoints3d = pred_keypoints3d - pred_pelvis
        gt_pelvis = gt_keypoints3d[:, [1, 2], :].mean(dim=1, keepdim=True)
        gt_keypoints3d = gt_keypoints3d - gt_pelvis

        loss = self.loss_keypoints3d(pred_keypoints3d,
                                     gt_keypoints3d,
                                     weight=keypoints3d_conf[target_idxs])
        loss /= gt_keypoints3d.shape[0]
        return loss

    def compute_keypoints2d_loss(
            self,
            pred_keypoints3d: torch.Tensor,
            pred_cam: torch.Tensor,
            gt_keypoints2d: torch.Tensor,
            img_res: Optional[int] = 224,
            focal_length: Optional[int] = 5000,
            has_keypoints2d: Optional[torch.Tensor] = None):
        """Compute loss for 2d keypoints."""
        keypoints2d_conf = gt_keypoints2d[:, :, 2].float().unsqueeze(-1)
        keypoints2d_conf = keypoints2d_conf.repeat(1, 1, 2)
        gt_keypoints2d = gt_keypoints2d[:, :, :2].float()
        if has_keypoints2d is None:
            has_keypoints2d = torch.ones((keypoints2d_conf.shape[0]))
        if keypoints2d_conf[has_keypoints2d == 1].numel() == 0:
            return torch.Tensor([0]).type_as(gt_keypoints2d)
        # ExPose uses a weak-perspective projection.
        pred_keypoints2d = weak_perspective_projection(
            pred_keypoints3d,
            scale=pred_cam[:, 0],
            translation=pred_cam[:, 1:3])
        # Normalize ground-truth pixel coordinates to [-1, 1] so they match
        # the projected predictions.
        gt_keypoints2d = 2 * gt_keypoints2d / (img_res - 1) - 1

        target_idxs = has_keypoints2d == 1
        pred_keypoints2d = pred_keypoints2d[target_idxs]
        gt_keypoints2d = gt_keypoints2d[target_idxs]
        loss = self.loss_keypoints2d(pred_keypoints2d,
                                     gt_keypoints2d,
                                     weight=keypoints2d_conf[target_idxs])
        loss /= gt_keypoints2d.shape[0]
        return loss
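
    # A quick numeric sanity check of the normalization above (illustrative
    # only, assuming img_res = 224): pixel coordinates in [0, img_res - 1]
    # are mapped into the [-1, 1] range produced by the projection, e.g.
    #
    #     x = torch.tensor([0.0, 111.5, 223.0])
    #     print(2 * x / (224 - 1) - 1)  # tensor([-1., 0., 1.])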

    def compute_smplx_body_pose_loss(self, pred_rotmat: torch.Tensor,
                                     gt_pose: torch.Tensor,
                                     has_smplx_body_pose: torch.Tensor):
        """Compute loss for smplx body pose."""
        num_joints = pred_rotmat.shape[1]
        target_idxs = has_smplx_body_pose == 1
        if gt_pose[target_idxs].numel() == 0:
            return torch.Tensor([0]).type_as(gt_pose)
        gt_rotmat = batch_rodrigues(gt_pose.view(-1, 3)).view(
            -1, num_joints, 3, 3)
        loss = self.loss_smplx_body_pose(pred_rotmat[target_idxs],
                                         gt_rotmat[target_idxs])
        return loss

    def compute_smplx_global_orient_loss(
            self, pred_rotmat: torch.Tensor, gt_global_orient: torch.Tensor,
            has_smplx_global_orient: torch.Tensor):
        """Compute loss for smplx global orient."""
        target_idxs = has_smplx_global_orient == 1
        if gt_global_orient[target_idxs].numel() == 0:
            return torch.Tensor([0]).type_as(gt_global_orient)
        gt_rotmat = batch_rodrigues(gt_global_orient.view(-1, 3)).view(
            -1, 1, 3, 3)
        loss = self.loss_smplx_global_orient(pred_rotmat[target_idxs],
                                             gt_rotmat[target_idxs])
        return loss

    def compute_smplx_jaw_pose_loss(self, pred_rotmat: torch.Tensor,
                                    gt_jaw_pose: torch.Tensor,
                                    has_smplx_jaw_pose: torch.Tensor,
                                    face_conf: torch.Tensor):
        """Compute loss for smplx jaw pose."""
        target_idxs = has_smplx_jaw_pose == 1
        if gt_jaw_pose[target_idxs].numel() == 0:
            return torch.Tensor([0]).type_as(gt_jaw_pose)
        gt_rotmat = batch_rodrigues(gt_jaw_pose.view(-1, 3)).view(-1, 1, 3, 3)
        conf = face_conf.mean(axis=1).float()
        conf = conf.view(-1, 1, 1, 1)
        loss = self.loss_smplx_jaw_pose(pred_rotmat[target_idxs],
                                        gt_rotmat[target_idxs],
                                        weight=conf[target_idxs])
        return loss

    def compute_smplx_hand_pose_loss(self, pred_rotmat: torch.Tensor,
                                     gt_hand_pose: torch.Tensor,
                                     has_smplx_hand_pose: torch.Tensor,
                                     hand_conf: torch.Tensor):
        """Compute loss for smplx left/right hand pose."""
        joint_num = pred_rotmat.shape[1]
        target_idxs = has_smplx_hand_pose == 1
        if gt_hand_pose[target_idxs].numel() == 0:
            return torch.Tensor([0]).type_as(gt_hand_pose)
        gt_rotmat = batch_rodrigues(gt_hand_pose.view(-1, 3)).view(
            -1, joint_num, 3, 3)
        conf = hand_conf.mean(axis=1,
                              keepdim=True).float().expand(-1, joint_num)
        conf = conf.view(-1, joint_num, 1, 1)
        loss = self.loss_smplx_hand_pose(pred_rotmat[target_idxs],
                                         gt_rotmat[target_idxs],
                                         weight=conf[target_idxs])
        return loss

    def compute_smplx_betas_loss(self, pred_betas: torch.Tensor,
                                 gt_betas: torch.Tensor,
                                 has_smplx_betas: torch.Tensor):
        """Compute loss for smplx betas."""
        target_idxs = has_smplx_betas == 1
        if gt_betas[target_idxs].numel() == 0:
            return torch.Tensor([0]).type_as(gt_betas)
        loss = self.loss_smplx_betas(pred_betas[target_idxs],
                                     gt_betas[target_idxs])
        loss = loss / gt_betas[target_idxs].shape[0]
        return loss

    def compute_smplx_betas_prior_loss(self, pred_betas: torch.Tensor):
        """Compute prior loss for smplx betas."""
        loss = self.loss_smplx_betas_prior(pred_betas)
        return loss

    def compute_smplx_expression_loss(self, pred_expression: torch.Tensor,
                                      gt_expression: torch.Tensor,
                                      has_smplx_expression: torch.Tensor,
                                      face_conf: torch.Tensor):
        """Compute loss for smplx expression."""
        target_idxs = has_smplx_expression == 1
        if gt_expression[target_idxs].numel() == 0:
            return torch.Tensor([0]).type_as(gt_expression)
        conf = face_conf.mean(axis=1).float()
        conf = conf.view(-1, 1)
        loss = self.loss_smplx_expression(pred_expression[target_idxs],
                                          gt_expression[target_idxs],
                                          weight=conf[target_idxs])
        loss = loss / gt_expression[target_idxs].shape[0]
        return loss

    def compute_camera_loss(self, cameras: torch.Tensor):
        """Compute loss for predicted camera parameters."""
        loss = self.loss_camera(cameras)
        return loss

    def compute_face_crop_loss(self,
                               pred_keypoints3d: torch.Tensor,
                               pred_cam: torch.Tensor,
                               gt_keypoints2d: torch.Tensor,
                               face_crop_info: dict,
                               img_res: Optional[int] = 224,
                               has_keypoints2d: Optional[torch.Tensor] = None):
        """Compute face crop loss for 2d keypoints."""
        keypoints2d_conf = gt_keypoints2d[:, :, 2].float().unsqueeze(-1)
        keypoints2d_conf = keypoints2d_conf.repeat(1, 1, 2)
        gt_keypoints2d = gt_keypoints2d[:, :, :2].float()
        if has_keypoints2d is None:
            has_keypoints2d = torch.ones((keypoints2d_conf.shape[0]))
        if keypoints2d_conf[has_keypoints2d == 1].numel() == 0:
            return torch.Tensor([0]).type_as(gt_keypoints2d)
        # ExPose uses a weak-perspective projection.
        pred_keypoints2d = weak_perspective_projection(
            pred_keypoints3d,
            scale=pred_cam[:, 0],
            translation=pred_cam[:, 1:3])
        target_idxs = has_keypoints2d == 1
        pred_keypoints2d = pred_keypoints2d[target_idxs]
        gt_keypoints2d = gt_keypoints2d[target_idxs]
        # Map normalized predictions back to pixel coordinates, then into the
        # original (HD) image frame and finally into the face crop frame.
        pred_keypoints2d = (0.5 * pred_keypoints2d + 0.5) * (img_res - 1)
        face_inv_crop_transforms = face_crop_info['face_inv_crop_transforms']
        pred_keypoints2d_hd = torch.einsum(
            'bij,bkj->bki',
            [face_inv_crop_transforms[:, :2, :2], pred_keypoints2d
             ]) + face_inv_crop_transforms[:, :2, 2].unsqueeze(dim=1)
        gt_keypoints2d_hd = torch.einsum(
            'bij,bkj->bki',
            [face_inv_crop_transforms[:, :2, :2], gt_keypoints2d
             ]) + face_inv_crop_transforms[:, :2, 2].unsqueeze(dim=1)
        pred_face_keypoints_hd = pred_keypoints2d_hd[:, self.face_idxs]

        face_crop_transform = face_crop_info['face_crop_transform']
        inv_face_crop_transf = torch.inverse(face_crop_transform)
        face_img_keypoints = torch.einsum(
            'bij,bkj->bki',
            [inv_face_crop_transf[:, :2, :2], pred_face_keypoints_hd
             ]) + inv_face_crop_transf[:, :2, 2].unsqueeze(dim=1)

        gt_face_keypoints_hd = gt_keypoints2d_hd[:, self.face_idxs]
        gt_face_keypoints = torch.einsum(
            'bij,bkj->bki',
            [inv_face_crop_transf[:, :2, :2], gt_face_keypoints_hd
             ]) + inv_face_crop_transf[:, :2, 2].unsqueeze(dim=1)

        loss = self.face_crop_loss(
            face_img_keypoints,
            gt_face_keypoints,
            weight=keypoints2d_conf[target_idxs][:, self.face_idxs])
        loss /= gt_face_keypoints.shape[0]
        return loss

    def compute_hand_crop_loss(self,
                               pred_keypoints3d: torch.Tensor,
                               pred_cam: torch.Tensor,
                               gt_keypoints2d: torch.Tensor,
                               hand_crop_info: dict,
                               img_res: Optional[int] = 224,
                               has_keypoints2d: Optional[torch.Tensor] = None):
        """Compute hand crop loss for 2d keypoints."""
        keypoints2d_conf = gt_keypoints2d[:, :, 2].float().unsqueeze(-1)
        keypoints2d_conf = keypoints2d_conf.repeat(1, 1, 2)
        gt_keypoints2d = gt_keypoints2d[:, :, :2].float()
        if has_keypoints2d is None:
            has_keypoints2d = torch.ones((keypoints2d_conf.shape[0]))
        if keypoints2d_conf[has_keypoints2d == 1].numel() == 0:
            return torch.Tensor([0]).type_as(gt_keypoints2d)
        # ExPose uses a weak-perspective projection.
        pred_keypoints2d = weak_perspective_projection(
            pred_keypoints3d,
            scale=pred_cam[:, 0],
            translation=pred_cam[:, 1:3])
        target_idxs = has_keypoints2d == 1
        pred_keypoints2d = pred_keypoints2d[target_idxs]
        gt_keypoints2d = gt_keypoints2d[target_idxs]
        # Map normalized predictions back to pixel coordinates, then into the
        # original (HD) image frame and finally into each hand crop frame.
        pred_keypoints2d = (0.5 * pred_keypoints2d + 0.5) * (img_res - 1)
        hand_inv_crop_transforms = hand_crop_info['hand_inv_crop_transforms']
        pred_keypoints2d_hd = torch.einsum(
            'bij,bkj->bki',
            [hand_inv_crop_transforms[:, :2, :2], pred_keypoints2d
             ]) + hand_inv_crop_transforms[:, :2, 2].unsqueeze(dim=1)
        gt_keypoints2d_hd = torch.einsum(
            'bij,bkj->bki',
            [hand_inv_crop_transforms[:, :2, :2], gt_keypoints2d
             ]) + hand_inv_crop_transforms[:, :2, 2].unsqueeze(dim=1)

        pred_left_hand_keypoints_hd = pred_keypoints2d_hd[:,
                                                          self.left_hand_idxs]
        left_hand_crop_transform = hand_crop_info['left_hand_crop_transform']
        inv_left_hand_crop_transf = torch.inverse(left_hand_crop_transform)
        left_hand_img_keypoints = torch.einsum(
            'bij,bkj->bki',
            [inv_left_hand_crop_transf[:, :2, :2], pred_left_hand_keypoints_hd
             ]) + inv_left_hand_crop_transf[:, :2, 2].unsqueeze(dim=1)
        gt_left_hand_keypoints_hd = gt_keypoints2d_hd[:, self.left_hand_idxs]
        gt_left_hand_keypoints = torch.einsum(
            'bij,bkj->bki',
            [inv_left_hand_crop_transf[:, :2, :2], gt_left_hand_keypoints_hd
             ]) + inv_left_hand_crop_transf[:, :2, 2].unsqueeze(dim=1)

        pred_right_hand_keypoints_hd = pred_keypoints2d_hd[:, self.
                                                           right_hand_idxs]
        right_hand_crop_transform = hand_crop_info['right_hand_crop_transform']
        inv_right_hand_crop_transf = torch.inverse(right_hand_crop_transform)
        right_hand_img_keypoints = torch.einsum(
            'bij,bkj->bki',
            [inv_right_hand_crop_transf[:, :2, :2],
             pred_right_hand_keypoints_hd
             ]) + inv_right_hand_crop_transf[:, :2, 2].unsqueeze(dim=1)
        gt_right_hand_keypoints_hd = gt_keypoints2d_hd[:, self.right_hand_idxs]
        gt_right_hand_keypoints = torch.einsum(
            'bij,bkj->bki',
            [inv_right_hand_crop_transf[:, :2, :2], gt_right_hand_keypoints_hd
             ]) + inv_right_hand_crop_transf[:, :2, 2].unsqueeze(dim=1)

        left_loss = self.hand_crop_loss(
            left_hand_img_keypoints,
            gt_left_hand_keypoints,
            weight=keypoints2d_conf[target_idxs][:, self.left_hand_idxs])
        left_loss /= gt_left_hand_keypoints.shape[0]
        right_loss = self.hand_crop_loss(
            right_hand_img_keypoints,
            gt_right_hand_keypoints,
            weight=keypoints2d_conf[target_idxs][:, self.right_hand_idxs])
        right_loss /= gt_right_hand_keypoints.shape[0]
        return left_loss + right_loss

    def compute_losses(self, predictions: dict, targets: dict):
        """Compute losses."""
        pred_param = predictions['pred_param']
        pred_cam = predictions['pred_cam']
        gt_keypoints3d = targets['keypoints3d']
        gt_keypoints2d = targets['keypoints2d']
        if self.body_model_train is not None:
            pred_output = self.body_model_train(**pred_param)
            pred_keypoints3d = pred_output['joints']

        if 'has_keypoints3d' in targets:
            has_keypoints3d = targets['has_keypoints3d'].squeeze(-1)
        else:
            has_keypoints3d = None
        if 'has_keypoints2d' in targets:
            has_keypoints2d = targets['has_keypoints2d'].squeeze(-1)
        else:
            has_keypoints2d = None

        losses = {}
        if self.loss_keypoints3d is not None:
            losses['keypoints3d_loss'] = self.compute_keypoints3d_loss(
                pred_keypoints3d,
                gt_keypoints3d,
                has_keypoints3d=has_keypoints3d)
        if self.loss_keypoints2d is not None:
            losses['keypoints2d_loss'] = self.compute_keypoints2d_loss(
                pred_keypoints3d,
                pred_cam,
                gt_keypoints2d,
                img_res=targets['img'].shape[-1],
                has_keypoints2d=has_keypoints2d)
        if self.loss_smplx_global_orient is not None:
            pred_global_orient = pred_param['global_orient']
            pred_global_orient = pose2rotmat(pred_global_orient)
            gt_global_orient = targets['smplx_global_orient']
            has_smplx_global_orient = targets[
                'has_smplx_global_orient'].squeeze(-1)
            losses['smplx_global_orient_loss'] = \
                self.compute_smplx_global_orient_loss(
                    pred_global_orient, gt_global_orient,
                    has_smplx_global_orient)
        if self.loss_smplx_body_pose is not None:
            pred_pose = pred_param['body_pose']
            pred_pose = pose2rotmat(pred_pose)
            gt_pose = targets['smplx_body_pose']
            has_smplx_body_pose = targets['has_smplx_body_pose'].squeeze(-1)
            losses['smplx_body_pose_loss'] = \
                self.compute_smplx_body_pose_loss(
                    pred_pose, gt_pose, has_smplx_body_pose)
        if self.loss_smplx_jaw_pose is not None:
            pred_jaw_pose = pred_param['jaw_pose']
            pred_jaw_pose = pose2rotmat(pred_jaw_pose)
            gt_jaw_pose = targets['smplx_jaw_pose']
            face_conf = get_keypoint_idxs_by_part('head', self.convention)
            has_smplx_jaw_pose = targets['has_smplx_jaw_pose'].squeeze(-1)
            losses['smplx_jaw_pose_loss'] = self.compute_smplx_jaw_pose_loss(
                pred_jaw_pose, gt_jaw_pose, has_smplx_jaw_pose,
                gt_keypoints2d[:, face_conf, 2])
        if self.loss_smplx_hand_pose is not None:
            pred_right_hand_pose = pred_param['right_hand_pose']
            pred_right_hand_pose = pose2rotmat(pred_right_hand_pose)
            gt_right_hand_pose = targets['smplx_right_hand_pose']
            right_hand_conf = get_keypoint_idxs_by_part(
                'right_hand', self.convention)
            has_smplx_right_hand_pose = targets[
                'has_smplx_right_hand_pose'].squeeze(-1)
            losses['smplx_right_hand_pose_loss'] = \
                self.compute_smplx_hand_pose_loss(
                    pred_right_hand_pose, gt_right_hand_pose,
                    has_smplx_right_hand_pose,
                    gt_keypoints2d[:, right_hand_conf, 2])
            if 'left_hand_pose' in pred_param:
                pred_left_hand_pose = pred_param['left_hand_pose']
                pred_left_hand_pose = pose2rotmat(pred_left_hand_pose)
                gt_left_hand_pose = targets['smplx_left_hand_pose']
                left_hand_conf = get_keypoint_idxs_by_part(
                    'left_hand', self.convention)
                has_smplx_left_hand_pose = targets[
                    'has_smplx_left_hand_pose'].squeeze(-1)
                losses['smplx_left_hand_pose_loss'] = \
                    self.compute_smplx_hand_pose_loss(
                        pred_left_hand_pose, gt_left_hand_pose,
                        has_smplx_left_hand_pose,
                        gt_keypoints2d[:, left_hand_conf, 2])
        if self.loss_smplx_betas is not None:
            pred_betas = pred_param['betas']
            gt_betas = targets['smplx_betas']
            has_smplx_betas = targets['has_smplx_betas'].squeeze(-1)
            losses['smplx_betas_loss'] = self.compute_smplx_betas_loss(
                pred_betas, gt_betas, has_smplx_betas)
        if self.loss_smplx_expression is not None:
            pred_expression = pred_param['expression']
            gt_expression = targets['smplx_expression']
            face_conf = get_keypoint_idxs_by_part('head', self.convention)
            has_smplx_expression = targets['has_smplx_expression'].squeeze(-1)
            losses['smplx_expression_loss'] = \
                self.compute_smplx_expression_loss(
                    pred_expression, gt_expression, has_smplx_expression,
                    gt_keypoints2d[:, face_conf, 2])
        if self.loss_smplx_betas_prior is not None:
            pred_betas = pred_param['betas']
            losses['smplx_betas_prior_loss'] = \
                self.compute_smplx_betas_prior_loss(pred_betas)
        if self.loss_camera is not None:
            losses['camera_loss'] = self.compute_camera_loss(pred_cam)
        if self.apply_hand_model and self.hand_crop_loss is not None:
            losses['hand_crop_loss'] = self.compute_hand_crop_loss(
                pred_keypoints3d, pred_cam, gt_keypoints2d,
                predictions['hand_crop_info'], targets['img'].shape[-1],
                has_keypoints2d)
        if self.apply_face_model and self.face_crop_loss is not None:
            losses['face_crop_loss'] = self.compute_face_crop_loss(
                pred_keypoints3d, pred_cam, gt_keypoints2d,
                predictions['face_crop_info'], targets['img'].shape[-1],
                has_keypoints2d)

        return losses

    @abstractmethod
    def prepare_targets(self, data_batch):
        """Prepare targets for computing losses."""
        pass

    def forward_train(self, **kwargs):
        """Forward function for general training.

        For mesh estimation, we do not use this interface.
        """
        raise NotImplementedError('This interface should not be used in '
                                  'current training schedule. Please use '
                                  '`train_step` for training.')

    def forward_test(self, img, img_metas, **kwargs):
        """Defines the computation performed at every call when testing."""
        pass


class SMPLXImageBodyModelEstimator(SMPLXBodyModelEstimator):

    def prepare_targets(self, data_batch: dict):
        # Image Mesh Estimator does not need extra processing of ground truth.
        return data_batch

    def forward_test(self, img: torch.Tensor, img_metas: dict, **kwargs):
        """Defines the computation performed at every call when testing."""
        if self.backbone is not None:
            features = self.backbone(img)
        else:
            features = kwargs['features']

        if self.neck is not None:
            features = self.neck(features)
        predictions = self.head(features)
        if self.apply_hand_model:
            hand_input_img, hand_mean, hand_crop_info = self.crop_hand_func(
                predictions, img_metas)
            hand_features = self.hand_backbone(hand_input_img)
            if self.hand_neck is not None:
                hand_features = self.hand_neck(hand_features)
            hand_predictions = self.hand_head(hand_features, cond=hand_mean)
            predictions = self.hand_merge_func(predictions, hand_predictions)
            predictions['hand_crop_info'] = hand_crop_info
        if self.apply_face_model:
            face_input_img, face_mean, face_crop_info = self.crop_face_func(
                predictions, img_metas)
            face_features = self.face_backbone(face_input_img)
            if self.face_neck is not None:
                face_features = self.face_neck(face_features)
            face_predictions = self.face_head(face_features, cond=face_mean)
            predictions = self.face_merge_func(predictions, face_predictions)
            predictions['face_crop_info'] = face_crop_info

        pred_param = predictions['pred_param']
        pred_cam = predictions['pred_cam']
        pred_output = self.body_model_test(**pred_param)
        pred_vertices = pred_output['vertices']
        pred_keypoints_3d = pred_output['joints']

        all_preds = {}
        all_preds['keypoints_3d'] = pred_keypoints_3d.detach().cpu().numpy()
        # Convert every tensor in pred_param to numpy in place; reassigning
        # the loop variable alone would leave the dict unchanged.
        for key, value in pred_param.items():
            if isinstance(value, torch.Tensor):
                pred_param[key] = value.detach().cpu().numpy()
        all_preds['param'] = pred_param
        all_preds['camera'] = pred_cam.detach().cpu().numpy()
        all_preds['vertices'] = pred_vertices.detach().cpu().numpy()
        image_path = []
        for img_meta in img_metas:
            image_path.append(img_meta['image_path'])
        all_preds['image_path'] = image_path
        all_preds['image_idx'] = kwargs['sample_idx']
        return all_preds


class FrozenBatchNorm2d(nn.Module):
    """BatchNorm2d where the batch statistics and the affine parameters are
    fixed."""

    def __init__(self, n, eps=1e-5):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer('weight', torch.ones(n))
        self.register_buffer('bias', torch.zeros(n))
        self.register_buffer('running_mean', torch.zeros(n))
        self.register_buffer('running_var', torch.ones(n))
        self.eps = eps

    @staticmethod
    def from_bn(module: nn.BatchNorm2d):
        """Initializes a frozen batch norm module from a batch norm module."""
        dim = len(module.weight.data)
        frozen_module = FrozenBatchNorm2d(dim)
        frozen_module.weight.data = module.weight.data
        missing, not_found = frozen_module.load_state_dict(
            module.state_dict(), strict=False)
        return frozen_module

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

        Args:
            module (torch.nn.Module): Module to convert.

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data
            res.eps = module.eps
        else:
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res

    def forward(self, x):
        # Cast all fixed parameters to half() if necessary.
        if x.dtype == torch.float16:
            self.weight = self.weight.half()
            self.bias = self.bias.half()
            self.running_mean = self.running_mean.half()
            self.running_var = self.running_var.half()
        return F.batch_norm(x, self.running_mean, self.running_var,
                            self.weight, self.bias, False, eps=self.eps)
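

if __name__ == '__main__':
    # Minimal, self-contained sketch (illustrative only, not part of the
    # training pipeline): convert_frozen_batchnorm swaps every BatchNorm2d in
    # a module for a FrozenBatchNorm2d with the same statistics.
    demo = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
    )
    demo = FrozenBatchNorm2d.convert_frozen_batchnorm(demo)
    assert isinstance(demo[1], FrozenBatchNorm2d)
    out = demo(torch.randn(1, 3, 16, 16))
    print(out.shape)  # torch.Size([1, 8, 16, 16])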