AiOS / detrsmpl /models /architectures /expressive_mesh_estimator.py
ttxskk
update
d7e58f0
from abc import ABCMeta, abstractmethod
from typing import Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from detrsmpl.core.conventions.keypoints_mapping import (
get_keypoint_idx,
get_keypoint_idxs_by_part,
)
from detrsmpl.utils.geometry import (
batch_rodrigues,
weak_perspective_projection,
)
from ..backbones.builder import build_backbone
from ..body_models.builder import build_body_model
from ..heads.builder import build_head
from ..losses.builder import build_loss
from ..necks.builder import build_neck
from ..utils import (
SMPLXFaceCropFunc,
SMPLXFaceMergeFunc,
SMPLXHandCropFunc,
SMPLXHandMergeFunc,
)
from .base_architecture import BaseArchitecture
def set_requires_grad(nets, requires_grad=False):
"""Set requies_grad for all the networks.
Args:
nets (nn.Module | list[nn.Module]): A list of networks or a single
network.
requires_grad (bool): Whether the networks require gradients or not
"""
if not isinstance(nets, list):
nets = [nets]
for net in nets:
if net is not None:
for param in net.parameters():
param.requires_grad = requires_grad
def pose2rotmat(pred_pose):
"""aa2rotmat."""
if len(pred_pose.shape) == 3:
num_joints = pred_pose.shape[1]
pred_pose = batch_rodrigues(pred_pose.view(-1, 3)).view(
-1, num_joints, 3, 3)
return pred_pose
class SMPLXBodyModelEstimator(BaseArchitecture, metaclass=ABCMeta):
"""BodyModelEstimator Architecture.
Args:
backbone (dict | None, optional): Backbone config dict. Default: None.
neck (dict | None, optional): Neck config dict. Default: None
head (dict | None, optional): Regressor config dict. Default: None.
body_model_train (dict | None, optional): SMPL config dict during
training. Default: None.
body_model_test (dict | None, optional): SMPL config dict during
test. Default: None.
convention (str, optional): Keypoints convention. Default: "human_data"
loss_keypoints2d (dict | None, optional): Losses config dict for
2D keypoints. Default: None.
loss_keypoints3d (dict | None, optional): Losses config dict for
3D keypoints. Default: None.
loss_smplx_global_orient (dict | None, optional): Losses config dict
for smplx global orient. Default: None
loss_smplx_body_pose (dict | None, optional): Losses config dict
for smplx body pose. Default: None
loss_smplx_hand_pose (dict | None, optional): Losses config dict
for smplx hand pose. Default: None
loss_smplx_jaw_pose (dict | None, optional): Losses config dict
for smplx jaw pose. Default: None
loss_smplx_expression (dict | None, optional): Losses config dict
for smplx expression. Default: None
loss_smplx_betas (dict | None, optional): Losses config dict for smplx
betas. Default: None
loss_camera (dict | None, optional): Losses config dict for predicted
camera parameters. Default: None
extra_hand_model_cfg (dict | None, optional) : Hand model config for
refining body model prediction. Default: None
extra_face_model_cfg (dict | None, optional) : Face model config for
refining body model prediction. Default: None
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(self,
backbone: Optional[Union[dict, None]] = None,
neck: Optional[Union[dict, None]] = None,
head: Optional[Union[dict, None]] = None,
body_model_train: Optional[Union[dict, None]] = None,
body_model_test: Optional[Union[dict, None]] = None,
convention: Optional[str] = 'human_data',
loss_keypoints2d: Optional[Union[dict, None]] = None,
loss_keypoints3d: Optional[Union[dict, None]] = None,
loss_smplx_global_orient: Optional[Union[dict, None]] = None,
loss_smplx_body_pose: Optional[Union[dict, None]] = None,
loss_smplx_hand_pose: Optional[Union[dict, None]] = None,
loss_smplx_jaw_pose: Optional[Union[dict, None]] = None,
loss_smplx_expression: Optional[Union[dict, None]] = None,
loss_smplx_betas: Optional[Union[dict, None]] = None,
loss_smplx_betas_prior: Optional[Union[dict, None]] = None,
loss_camera: Optional[Union[dict, None]] = None,
extra_hand_model_cfg: Optional[Union[dict, None]] = None,
extra_face_model_cfg: Optional[Union[dict, None]] = None,
frozen_batchnorm: bool = False,
init_cfg: Optional[Union[list, dict, None]] = None):
super(SMPLXBodyModelEstimator, self).__init__(init_cfg)
self.backbone = build_backbone(backbone)
self.neck = build_neck(neck)
self.head = build_head(head)
if frozen_batchnorm:
for param in self.backbone.parameters():
param.requires_grad = False
for param in self.head.parameters():
param.requires_grad = False
self.backbone = FrozenBatchNorm2d.convert_frozen_batchnorm(
self.backbone)
self.head = FrozenBatchNorm2d.convert_frozen_batchnorm(self.head)
self.body_model_train = build_body_model(body_model_train)
self.body_model_test = build_body_model(body_model_test)
self.convention = convention
self.apply_hand_model = False
self.apply_face_model = False
if extra_hand_model_cfg is not None:
self.hand_backbone = build_backbone(
extra_hand_model_cfg.get('backbone', None))
self.hand_neck = build_neck(extra_hand_model_cfg.get('neck', None))
self.hand_head = build_head(extra_hand_model_cfg.get('head', None))
crop_cfg = extra_hand_model_cfg.get('crop_cfg', None)
if crop_cfg is not None:
self.crop_hand_func = SMPLXHandCropFunc(
self.hand_head,
self.body_model_train,
convention=self.convention,
**crop_cfg)
self.hand_merge_func = SMPLXHandMergeFunc(
self.body_model_train, self.convention)
self.hand_crop_loss = build_loss(
extra_hand_model_cfg.get('loss_hand_crop', None))
self.apply_hand_model = True
self.left_hand_idxs = get_keypoint_idxs_by_part(
'left_hand', self.convention)
self.left_hand_idxs.append(
get_keypoint_idx('left_wrist', self.convention))
self.left_hand_idxs = sorted(self.left_hand_idxs)
self.right_hand_idxs = get_keypoint_idxs_by_part(
'right_hand', self.convention)
self.right_hand_idxs.append(
get_keypoint_idx('right_wrist', self.convention))
self.right_hand_idxs = sorted(self.right_hand_idxs)
if extra_face_model_cfg is not None:
self.face_backbone = build_backbone(
extra_face_model_cfg.get('backbone', None))
self.face_neck = build_neck(extra_face_model_cfg.get('neck', None))
self.face_head = build_head(extra_face_model_cfg.get('head', None))
crop_cfg = extra_face_model_cfg.get('crop_cfg', None)
if crop_cfg is not None:
self.crop_face_func = SMPLXFaceCropFunc(
self.face_head,
self.body_model_train,
convention=self.convention,
**crop_cfg)
self.face_merge_func = SMPLXFaceMergeFunc(
self.body_model_train, self.convention)
self.face_crop_loss = build_loss(
extra_face_model_cfg.get('loss_face_crop', None))
self.apply_face_model = True
self.face_idxs = get_keypoint_idxs_by_part('head', self.convention)
self.face_idxs = sorted(self.face_idxs)
self.loss_keypoints2d = build_loss(loss_keypoints2d)
self.loss_keypoints3d = build_loss(loss_keypoints3d)
self.loss_smplx_global_orient = build_loss(loss_smplx_global_orient)
self.loss_smplx_body_pose = build_loss(loss_smplx_body_pose)
self.loss_smplx_hand_pose = build_loss(loss_smplx_hand_pose)
self.loss_smplx_jaw_pose = build_loss(loss_smplx_jaw_pose)
self.loss_smplx_expression = build_loss(loss_smplx_expression)
self.loss_smplx_betas = build_loss(loss_smplx_betas)
self.loss_smplx_betas_piror = build_loss(loss_smplx_betas_prior)
self.loss_camera = build_loss(loss_camera)
set_requires_grad(self.body_model_train, False)
set_requires_grad(self.body_model_test, False)
def train_step(self, data_batch, optimizer, **kwargs):
"""Train step function.
Args:
data_batch (torch.Tensor): Batch of data as input.
optimizer (dict[torch.optim.Optimizer]): Dict with optimizers for
generator.
Returns:
outputs (dict): Dict with loss, information for logger,
the number of samples.
"""
if self.backbone is not None:
img = data_batch['img']
features = self.backbone(img)
else:
features = data_batch['features']
if self.neck is not None:
features = self.neck(features)
predictions = self.head(features)
if self.apply_hand_model:
hand_input_img, hand_mean, hand_crop_info = self.crop_hand_func(
predictions, data_batch['img_metas'])
hand_features = self.hand_backbone(hand_input_img)
if self.neck is not None:
hand_features = self.hand_neck(hand_features)
hand_predictions = self.hand_head(hand_features, cond=hand_mean)
predictions = self.hand_merge_func(predictions, hand_predictions)
predictions['hand_crop_info'] = hand_crop_info
if self.apply_face_model:
face_input_img, face_mean, face_crop_info = self.crop_face_func(
predictions, data_batch['img_metas'])
face_features = self.face_backbone(face_input_img)
if self.neck is not None:
face_features = self.face_neck(face_features)
face_predictions = self.face_head(face_features, cond=face_mean)
predictions = self.face_merge_func(predictions, face_predictions)
predictions['face_crop_info'] = face_crop_info
targets = self.prepare_targets(data_batch)
losses = self.compute_losses(predictions, targets)
loss, log_vars = self._parse_losses(losses)
if self.backbone is not None:
optimizer['backbone'].zero_grad()
if self.neck is not None:
optimizer['neck'].zero_grad()
if self.head is not None:
optimizer['head'].zero_grad()
if self.apply_hand_model:
if self.hand_backbone is not None:
optimizer['hand_backbone'].zero_grad()
if self.hand_neck is not None:
optimizer['hand_neck'].zero_grad()
if self.hand_head is not None:
optimizer['hand_head'].zero_grad()
if self.apply_face_model:
if self.face_backbone is not None:
optimizer['face_backbone'].zero_grad()
if self.face_neck is not None:
optimizer['face_neck'].zero_grad()
if self.face_head is not None:
optimizer['face_head'].zero_grad()
loss.backward()
if self.backbone is not None:
optimizer['backbone'].step()
if self.neck is not None:
optimizer['neck'].step()
if self.head is not None:
optimizer['head'].step()
if self.apply_hand_model:
if self.hand_backbone is not None:
optimizer['hand_backbone'].step()
if self.hand_neck is not None:
optimizer['hand_neck'].step()
if self.hand_head is not None:
optimizer['hand_head'].step()
if self.apply_face_model:
if self.face_backbone is not None:
optimizer['face_backbone'].step()
if self.face_neck is not None:
optimizer['face_neck'].step()
if self.face_head is not None:
optimizer['face_head'].step()
outputs = dict(loss=loss,
log_vars=log_vars,
num_samples=len(next(iter(data_batch.values()))))
return outputs
def compute_keypoints3d_loss(
self,
pred_keypoints3d: torch.Tensor,
gt_keypoints3d: torch.Tensor,
has_keypoints3d: Optional[torch.Tensor] = None):
"""Compute loss for 3d keypoints."""
keypoints3d_conf = gt_keypoints3d[:, :, 3].float().unsqueeze(-1)
keypoints3d_conf = keypoints3d_conf.repeat(1, 1, 3)
pred_keypoints3d = pred_keypoints3d.float()
gt_keypoints3d = gt_keypoints3d[:, :, :3].float()
if has_keypoints3d is None:
has_keypoints3d = torch.ones((keypoints3d_conf.shape[0]))
if keypoints3d_conf[has_keypoints3d == 1].numel() == 0:
return torch.Tensor([0]).type_as(gt_keypoints3d)
# Center the predictions using the pelvis
target_idxs = has_keypoints3d == 1
pred_keypoints3d = pred_keypoints3d[target_idxs]
gt_keypoints3d = gt_keypoints3d[target_idxs]
pred_pelvis = pred_keypoints3d[:, [1, 2], :].mean(dim=1, keepdim=True)
pred_keypoints3d = pred_keypoints3d - pred_pelvis
gt_pelvis = gt_keypoints3d[:, [1, 2], :].mean(dim=1, keepdim=True)
gt_keypoints3d = gt_keypoints3d - gt_pelvis
loss = self.loss_keypoints3d(pred_keypoints3d,
gt_keypoints3d,
weight=keypoints3d_conf[target_idxs])
loss /= gt_keypoints3d.shape[0]
return loss
def compute_keypoints2d_loss(
self,
pred_keypoints3d: torch.Tensor,
pred_cam: torch.Tensor,
gt_keypoints2d: torch.Tensor,
img_res: Optional[int] = 224,
focal_length: Optional[int] = 5000,
has_keypoints2d: Optional[torch.Tensor] = None):
"""Compute loss for 2d keypoints."""
keypoints2d_conf = gt_keypoints2d[:, :, 2].float().unsqueeze(-1)
keypoints2d_conf = keypoints2d_conf.repeat(1, 1, 2)
gt_keypoints2d = gt_keypoints2d[:, :, :2].float()
if has_keypoints2d is None:
has_keypoints2d = torch.ones((keypoints2d_conf.shape[0]))
if keypoints2d_conf[has_keypoints2d == 1].numel() == 0:
return torch.Tensor([0]).type_as(gt_keypoints2d)
# Expose use weak_perspective_projection
pred_keypoints2d = weak_perspective_projection(
pred_keypoints3d,
scale=pred_cam[:, 0],
translation=pred_cam[:, 1:3])
gt_keypoints2d = 2 * gt_keypoints2d / (img_res - 1) - 1
target_idxs = has_keypoints2d == 1
pred_keypoints2d = pred_keypoints2d[target_idxs]
gt_keypoints2d = gt_keypoints2d[target_idxs]
loss = self.loss_keypoints2d(pred_keypoints2d,
gt_keypoints2d,
weight=keypoints2d_conf[target_idxs])
loss /= gt_keypoints2d.shape[0]
return loss
def compute_smplx_body_pose_loss(self, pred_rotmat: torch.Tensor,
gt_pose: torch.Tensor,
has_smplx_body_pose: torch.Tensor):
"""Compute loss for smplx body pose."""
num_joints = pred_rotmat.shape[1]
target_idxs = has_smplx_body_pose == 1
if gt_pose[target_idxs].numel() == 0:
return torch.Tensor([0]).type_as(gt_pose)
gt_rotmat = batch_rodrigues(gt_pose.view(-1, 3)).view(
-1, num_joints, 3, 3)
loss = self.loss_smplx_body_pose(pred_rotmat[target_idxs],
gt_rotmat[target_idxs])
return loss
def compute_smplx_global_orient_loss(
self, pred_rotmat: torch.Tensor, gt_global_orient: torch.Tensor,
has_smplx_global_orient: torch.Tensor):
"""Compute loss for smplx global orient."""
target_idxs = has_smplx_global_orient == 1
if gt_global_orient[target_idxs].numel() == 0:
return torch.Tensor([0]).type_as(gt_global_orient)
gt_rotmat = batch_rodrigues(gt_global_orient.view(-1, 3)).view(
-1, 1, 3, 3)
loss = self.loss_smplx_global_orient(pred_rotmat[target_idxs],
gt_rotmat[target_idxs])
return loss
def compute_smplx_jaw_pose_loss(self, pred_rotmat: torch.Tensor,
gt_jaw_pose: torch.Tensor,
has_smplx_jaw_pose: torch.Tensor,
face_conf: torch.Tensor):
"""Compute loss for smplx jaw pose."""
target_idxs = has_smplx_jaw_pose == 1
if gt_jaw_pose[target_idxs].numel() == 0:
return torch.Tensor([0]).type_as(gt_jaw_pose)
gt_rotmat = batch_rodrigues(gt_jaw_pose.view(-1, 3)).view(-1, 1, 3, 3)
conf = face_conf.mean(axis=1).float()
conf = conf.view(-1, 1, 1, 1)
loss = self.loss_smplx_jaw_pose(pred_rotmat[target_idxs],
gt_rotmat[target_idxs],
weight=conf[target_idxs])
return loss
def compute_smplx_hand_pose_loss(self, pred_rotmat: torch.Tensor,
gt_hand_pose: torch.Tensor,
has_smplx_hand_pose: torch.Tensor,
hand_conf: torch.Tensor):
"""Compute loss for smplx left/right hand pose."""
joint_num = pred_rotmat.shape[1]
target_idxs = has_smplx_hand_pose == 1
if gt_hand_pose[target_idxs].numel() == 0:
return torch.Tensor([0]).type_as(gt_hand_pose)
gt_rotmat = batch_rodrigues(gt_hand_pose.view(-1, 3)).view(
-1, joint_num, 3, 3)
conf = hand_conf.mean(axis=1,
keepdim=True).float().expand(-1, joint_num)
conf = conf.view(-1, joint_num, 1, 1)
loss = self.loss_smplx_hand_pose(pred_rotmat[target_idxs],
gt_rotmat[target_idxs],
weight=conf[target_idxs])
return loss
def compute_smplx_betas_loss(self, pred_betas: torch.Tensor,
gt_betas: torch.Tensor,
has_smplx_betas: torch.Tensor):
"""Compute loss for smplx betas."""
target_idxs = has_smplx_betas == 1
if gt_betas[target_idxs].numel() == 0:
return torch.Tensor([0]).type_as(gt_betas)
loss = self.loss_smplx_betas(pred_betas[target_idxs],
gt_betas[target_idxs])
loss = loss / gt_betas[target_idxs].shape[0]
return loss
def compute_smplx_betas_prior_loss(self, pred_betas: torch.Tensor):
"""Compute prior loss for smplx betas."""
loss = self.loss_smplx_betas_piror(pred_betas)
return loss
def compute_smplx_expression_loss(self, pred_expression: torch.Tensor,
gt_expression: torch.Tensor,
has_smplx_expression: torch.Tensor,
face_conf: torch.Tensor):
"""Compute loss for smplx betas."""
target_idxs = has_smplx_expression == 1
if gt_expression[target_idxs].numel() == 0:
return torch.Tensor([0]).type_as(gt_expression)
conf = face_conf.mean(axis=1).float()
conf = conf.view(-1, 1)
loss = self.loss_smplx_expression(pred_expression[target_idxs],
gt_expression[target_idxs],
weight=conf[target_idxs])
loss = loss / gt_expression[target_idxs].shape[0]
return loss
def compute_camera_loss(self, cameras: torch.Tensor):
"""Compute loss for predicted camera parameters."""
loss = self.loss_camera(cameras)
return loss
def compute_face_crop_loss(self,
pred_keypoints3d: torch.Tensor,
pred_cam: torch.Tensor,
gt_keypoints2d: torch.Tensor,
face_crop_info: dict,
img_res: Optional[int] = 224,
has_keypoints2d: Optional[torch.Tensor] = None):
"""Compute face crop loss for 2d keypoints."""
keypoints2d_conf = gt_keypoints2d[:, :, 2].float().unsqueeze(-1)
keypoints2d_conf = keypoints2d_conf.repeat(1, 1, 2)
gt_keypoints2d = gt_keypoints2d[:, :, :2].float()
if has_keypoints2d is None:
has_keypoints2d = torch.ones((keypoints2d_conf.shape[0]))
if keypoints2d_conf[has_keypoints2d == 1].numel() == 0:
return torch.Tensor([0]).type_as(gt_keypoints2d)
# Expose use weak_perspective_projection
pred_keypoints2d = weak_perspective_projection(
pred_keypoints3d,
scale=pred_cam[:, 0],
translation=pred_cam[:, 1:3])
target_idxs = has_keypoints2d == 1
pred_keypoints2d = pred_keypoints2d[target_idxs]
gt_keypoints2d = gt_keypoints2d[target_idxs]
pred_keypoints2d = (0.5 * pred_keypoints2d + 0.5) * (img_res - 1)
face_inv_crop_transforms = face_crop_info['face_inv_crop_transforms']
pred_keypoints2d_hd = torch.einsum('bij,bkj->bki', [
face_inv_crop_transforms[:, :2, :2], pred_keypoints2d
]) + face_inv_crop_transforms[:, :2, 2].unsqueeze(dim=1)
gt_keypoints2d_hd = torch.einsum('bij,bkj->bki', [
face_inv_crop_transforms[:, :2, :2], gt_keypoints2d
]) + face_inv_crop_transforms[:, :2, 2].unsqueeze(dim=1)
pred_face_keypoints_hd = pred_keypoints2d_hd[:, self.face_idxs]
face_crop_transform = face_crop_info['face_crop_transform']
inv_face_crop_transf = torch.inverse(face_crop_transform)
face_img_keypoints = torch.einsum('bij,bkj->bki', [
inv_face_crop_transf[:, :2, :2], pred_face_keypoints_hd
]) + inv_face_crop_transf[:, :2, 2].unsqueeze(dim=1)
gt_face_keypoints_hd = gt_keypoints2d_hd[:, self.face_idxs]
gt_face_keypoints = torch.einsum('bij,bkj->bki', [
inv_face_crop_transf[:, :2, :2], gt_face_keypoints_hd
]) + inv_face_crop_transf[:, :2, 2].unsqueeze(dim=1)
loss = self.face_crop_loss(
face_img_keypoints,
gt_face_keypoints,
weight=keypoints2d_conf[target_idxs][:, self.face_idxs])
loss /= gt_face_keypoints.shape[0]
return loss
def compute_hand_crop_loss(self,
pred_keypoints3d: torch.Tensor,
pred_cam: torch.Tensor,
gt_keypoints2d: torch.Tensor,
hand_crop_info: dict,
img_res: Optional[int] = 224,
has_keypoints2d: Optional[torch.Tensor] = None):
"""Compute hand crop loss for 2d keypoints."""
keypoints2d_conf = gt_keypoints2d[:, :, 2].float().unsqueeze(-1)
keypoints2d_conf = keypoints2d_conf.repeat(1, 1, 2)
gt_keypoints2d = gt_keypoints2d[:, :, :2].float()
if has_keypoints2d is None:
has_keypoints2d = torch.ones((keypoints2d_conf.shape[0]))
if keypoints2d_conf[has_keypoints2d == 1].numel() == 0:
return torch.Tensor([0]).type_as(gt_keypoints2d)
# Expose use weak_perspective_projection
pred_keypoints2d = weak_perspective_projection(
pred_keypoints3d,
scale=pred_cam[:, 0],
translation=pred_cam[:, 1:3])
target_idxs = has_keypoints2d == 1
pred_keypoints2d = pred_keypoints2d[target_idxs]
gt_keypoints2d = gt_keypoints2d[target_idxs]
pred_keypoints2d = (0.5 * pred_keypoints2d + 0.5) * (img_res - 1)
hand_inv_crop_transforms = hand_crop_info['hand_inv_crop_transforms']
pred_keypoints2d_hd = torch.einsum('bij,bkj->bki', [
hand_inv_crop_transforms[:, :2, :2], pred_keypoints2d
]) + hand_inv_crop_transforms[:, :2, 2].unsqueeze(dim=1)
gt_keypoints2d_hd = torch.einsum('bij,bkj->bki', [
hand_inv_crop_transforms[:, :2, :2], gt_keypoints2d
]) + hand_inv_crop_transforms[:, :2, 2].unsqueeze(dim=1)
pred_left_hand_keypoints_hd = pred_keypoints2d_hd[:,
self.left_hand_idxs]
left_hand_crop_transform = hand_crop_info['left_hand_crop_transform']
inv_left_hand_crop_transf = torch.inverse(left_hand_crop_transform)
left_hand_img_keypoints = torch.einsum('bij,bkj->bki', [
inv_left_hand_crop_transf[:, :2, :2], pred_left_hand_keypoints_hd
]) + inv_left_hand_crop_transf[:, :2, 2].unsqueeze(dim=1)
gt_left_hand_keypoints_hd = gt_keypoints2d_hd[:, self.left_hand_idxs]
gt_left_hand_keypoints = torch.einsum('bij,bkj->bki', [
inv_left_hand_crop_transf[:, :2, :2], gt_left_hand_keypoints_hd
]) + inv_left_hand_crop_transf[:, :2, 2].unsqueeze(dim=1)
pred_right_hand_keypoints_hd = pred_keypoints2d_hd[:, self.
right_hand_idxs]
right_hand_crop_transform = hand_crop_info['right_hand_crop_transform']
inv_right_hand_crop_transf = torch.inverse(right_hand_crop_transform)
right_hand_img_keypoints = torch.einsum('bij,bkj->bki', [
inv_right_hand_crop_transf[:, :2, :2], pred_right_hand_keypoints_hd
]) + inv_right_hand_crop_transf[:, :2, 2].unsqueeze(dim=1)
gt_right_hand_keypoints_hd = gt_keypoints2d_hd[:, self.right_hand_idxs]
gt_right_hand_keypoints = torch.einsum('bij,bkj->bki', [
inv_right_hand_crop_transf[:, :2, :2], gt_right_hand_keypoints_hd
]) + inv_right_hand_crop_transf[:, :2, 2].unsqueeze(dim=1)
left_loss = self.hand_crop_loss(
left_hand_img_keypoints,
gt_left_hand_keypoints,
weight=keypoints2d_conf[target_idxs][:, self.left_hand_idxs])
left_loss /= gt_left_hand_keypoints.shape[0]
right_loss = self.hand_crop_loss(
right_hand_img_keypoints,
gt_right_hand_keypoints,
weight=keypoints2d_conf[target_idxs][:, self.right_hand_idxs])
right_loss /= gt_right_hand_keypoints.shape[0]
return left_loss + right_loss
def compute_losses(self, predictions: dict, targets: dict):
"""Compute losses."""
pred_param = predictions['pred_param']
pred_cam = predictions['pred_cam']
gt_keypoints3d = targets['keypoints3d']
gt_keypoints2d = targets['keypoints2d']
if self.body_model_train is not None:
pred_output = self.body_model_train(**pred_param)
pred_keypoints3d = pred_output['joints']
if 'has_keypoints3d' in targets:
has_keypoints3d = targets['has_keypoints3d'].squeeze(-1)
else:
has_keypoints3d = None
if 'has_keypoints2d' in targets:
has_keypoints2d = targets['has_keypoints2d'].squeeze(-1)
else:
has_keypoints2d = None
losses = {}
if self.loss_keypoints3d is not None:
losses['keypoints3d_loss'] = self.compute_keypoints3d_loss(
pred_keypoints3d,
gt_keypoints3d,
has_keypoints3d=has_keypoints3d)
if self.loss_keypoints2d is not None:
losses['keypoints2d_loss'] = self.compute_keypoints2d_loss(
pred_keypoints3d,
pred_cam,
gt_keypoints2d,
img_res=targets['img'].shape[-1],
has_keypoints2d=has_keypoints2d)
if self.loss_smplx_global_orient is not None:
pred_global_orient = pred_param['global_orient']
pred_global_orient = pose2rotmat(pred_global_orient)
gt_global_orient = targets['smplx_global_orient']
has_smplx_global_orient = targets[
'has_smplx_global_orient'].squeeze(-1)
losses['smplx_global_orient_loss'] = \
self.compute_smplx_global_orient_loss(
pred_global_orient, gt_global_orient,
has_smplx_global_orient)
if self.loss_smplx_body_pose is not None:
pred_pose = pred_param['body_pose']
pred_pose = pose2rotmat(pred_pose)
gt_pose = targets['smplx_body_pose']
has_smplx_body_pose = targets['has_smplx_body_pose'].squeeze(-1)
losses['smplx_body_pose_loss'] = \
self.compute_smplx_body_pose_loss(
pred_pose, gt_pose, has_smplx_body_pose)
if self.loss_smplx_jaw_pose is not None:
pred_jaw_pose = pred_param['jaw_pose']
pred_jaw_pose = pose2rotmat(pred_jaw_pose)
gt_jaw_pose = targets['smplx_jaw_pose']
face_conf = get_keypoint_idxs_by_part('head', self.convention)
has_smplx_jaw_pose = targets['has_smplx_jaw_pose'].squeeze(-1)
losses['smplx_jaw_pose_loss'] = self.compute_smplx_jaw_pose_loss(
pred_jaw_pose, gt_jaw_pose, has_smplx_jaw_pose,
gt_keypoints2d[:, face_conf, 2])
if self.loss_smplx_hand_pose is not None:
pred_right_hand_pose = pred_param['right_hand_pose']
pred_right_hand_pose = pose2rotmat(pred_right_hand_pose)
gt_right_hand_pose = targets['smplx_right_hand_pose']
right_hand_conf = get_keypoint_idxs_by_part(
'right_hand', self.convention)
has_smplx_right_hand_pose = targets[
'has_smplx_right_hand_pose'].squeeze(-1)
losses['smplx_right_hand_pose_loss'] = \
self.compute_smplx_hand_pose_loss(
pred_right_hand_pose, gt_right_hand_pose,
has_smplx_right_hand_pose,
gt_keypoints2d[:, right_hand_conf, 2])
if 'left_hand_pose' in pred_param:
pred_left_hand_pose = pred_param['left_hand_pose']
pred_left_hand_pose = pose2rotmat(pred_left_hand_pose)
gt_left_hand_pose = targets['smplx_left_hand_pose']
left_hand_conf = get_keypoint_idxs_by_part(
'left_hand', self.convention)
has_smplx_left_hand_pose = targets[
'has_smplx_left_hand_pose'].squeeze(-1)
losses['smplx_left_hand_pose_loss'] = \
self.compute_smplx_hand_pose_loss(
pred_left_hand_pose, gt_left_hand_pose,
has_smplx_left_hand_pose,
gt_keypoints2d[:, left_hand_conf, 2])
if self.loss_smplx_betas is not None:
pred_betas = pred_param['betas']
gt_betas = targets['smplx_betas']
has_smplx_betas = targets['has_smplx_betas'].squeeze(-1)
losses['smplx_betas_loss'] = self.compute_smplx_betas_loss(
pred_betas, gt_betas, has_smplx_betas)
if self.loss_smplx_expression is not None:
pred_expression = pred_param['expression']
gt_expression = targets['smplx_expression']
face_conf = get_keypoint_idxs_by_part('head', self.convention)
has_smplx_expression = targets['has_smplx_expression'].squeeze(-1)
losses[
'smplx_expression_loss'] = self.compute_smplx_expression_loss(
pred_expression, gt_expression, has_smplx_expression,
gt_keypoints2d[:, face_conf, 2])
if self.loss_smplx_betas_piror is not None:
pred_betas = pred_param['betas']
losses['smplx_betas_prior_loss'] = \
self.compute_smplx_betas_prior_loss(
pred_betas)
if self.loss_camera is not None:
losses['camera_loss'] = self.compute_camera_loss(pred_cam)
if self.apply_hand_model and self.hand_crop_loss is not None:
losses['hand_crop_loss'] = self.compute_hand_crop_loss(
pred_keypoints3d, pred_cam, gt_keypoints2d,
predictions['hand_crop_info'], targets['img'].shape[-1],
has_keypoints2d)
if self.apply_face_model and self.face_crop_loss is not None:
losses['face_crop_loss'] = self.compute_face_crop_loss(
pred_keypoints3d, pred_cam, gt_keypoints2d,
predictions['face_crop_info'], targets['img'].shape[-1],
has_keypoints2d)
return losses
@abstractmethod
def prepare_targets(self, data_batch):
pass
def forward_train(self, **kwargs):
"""Forward function for general training.
For mesh estimation, we do not use this interface.
"""
raise NotImplementedError('This interface should not be used in '
'current training schedule. Please use '
'`train_step` for training.')
@abstractmethod
def forward_test(self, img, img_metas, **kwargs):
"""Defines the computation performed at every call when testing."""
pass
class SMPLXImageBodyModelEstimator(SMPLXBodyModelEstimator):
def prepare_targets(self, data_batch: dict):
# Image Mesh Estimator does not need extra process for ground truth
return data_batch
def forward_test(self, img: torch.Tensor, img_metas: dict, **kwargs):
"""Defines the computation performed at every call when testing."""
if self.backbone is not None:
features = self.backbone(img)
else:
features = kwargs['features']
if self.neck is not None:
features = self.neck(features)
predictions = self.head(features)
if self.apply_hand_model:
hand_input_img, hand_mean, hand_crop_info = self.crop_hand_func(
predictions, img_metas)
hand_features = self.hand_backbone(hand_input_img)
if self.neck is not None:
hand_features = self.hand_neck(hand_features)
hand_predictions = self.hand_head(hand_features, cond=hand_mean)
predictions = self.hand_merge_func(predictions, hand_predictions)
predictions['hand_crop_info'] = hand_crop_info
if self.apply_face_model:
face_input_img, face_mean, face_crop_info = self.crop_face_func(
predictions, img_metas)
face_features = self.face_backbone(face_input_img)
if self.neck is not None:
face_features = self.face_neck(face_features)
face_predictions = self.face_head(face_features, cond=face_mean)
predictions = self.face_merge_func(predictions, face_predictions)
predictions['face_crop_info'] = face_crop_info
pred_param = predictions['pred_param']
pred_cam = predictions['pred_cam']
pred_output = self.body_model_test(**pred_param)
pred_vertices = pred_output['vertices']
pred_keypoints_3d = pred_output['joints']
all_preds = {}
all_preds['keypoints_3d'] = pred_keypoints_3d.detach().cpu().numpy()
for value in pred_param.values():
if isinstance(value, torch.Tensor):
value = value.detach().cpu().numpy()
all_preds['param'] = pred_param
all_preds['camera'] = pred_cam.detach().cpu().numpy()
all_preds['vertices'] = pred_vertices.detach().cpu().numpy()
image_path = []
for img_meta in img_metas:
image_path.append(img_meta['image_path'])
all_preds['image_path'] = image_path
all_preds['image_idx'] = kwargs['sample_idx']
return all_preds
class FrozenBatchNorm2d(nn.Module):
"""BatchNorm2d where the batch statistics and the affine parameters are
fixed."""
def __init__(self, n):
super(FrozenBatchNorm2d, self).__init__()
self.register_buffer('weight', torch.ones(n))
self.register_buffer('bias', torch.zeros(n))
self.register_buffer('running_mean', torch.zeros(n))
self.register_buffer('running_var', torch.ones(n))
@staticmethod
def from_bn(module: nn.BatchNorm2d):
"""Initializes a frozen batch norm module from a batch norm module."""
dim = len(module.weight.data)
frozen_module = FrozenBatchNorm2d(dim)
frozen_module.weight.data = module.weight.data
missing, not_found = frozen_module.load_state_dict(module.state_dict(),
strict=False)
return frozen_module
@classmethod
def convert_frozen_batchnorm(cls, module):
"""Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.
Args:
module (torch.nn.Module):
Returns:
If module is BatchNorm/SyncBatchNorm, returns a new module.
Otherwise, in-place convert module and return it.
Similar to convert_sync_batchnorm in
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
"""
bn_module = nn.modules.batchnorm
bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
res = module
if isinstance(module, bn_module):
res = cls(module.num_features)
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
else:
for name, child in module.named_children():
new_child = cls.convert_frozen_batchnorm(child)
if new_child is not child:
res.add_module(name, new_child)
return res
def forward(self, x):
# Cast all fixed parameters to half() if necessary
if x.dtype == torch.float16:
self.weight = self.weight.half()
self.bias = self.bias.half()
self.running_mean = self.running_mean.half()
self.running_var = self.running_var.half()
return F.batch_norm(x, self.running_mean, self.running_var,
self.weight, self.bias, False)