from abc import ABCMeta, abstractmethod from typing import Optional, Tuple, Union import torch import torch.nn.functional as F import detrsmpl.core.visualization.visualize_smpl as visualize_smpl from detrsmpl.core.conventions.keypoints_mapping import get_keypoint_idx from detrsmpl.models.utils import FitsDict from detrsmpl.utils.geometry import ( batch_rodrigues, estimate_translation, project_points, rotation_matrix_to_angle_axis, ) from ..backbones.builder import build_backbone from ..body_models.builder import build_body_model from ..discriminators.builder import build_discriminator from ..heads.builder import build_head from ..losses.builder import build_loss from ..necks.builder import build_neck from ..registrants.builder import build_registrant from .base_architecture import BaseArchitecture def set_requires_grad(nets, requires_grad=False): """Set requies_grad for all the networks. Args: nets (nn.Module | list[nn.Module]): A list of networks or a single network. requires_grad (bool): Whether the networks require gradients or not """ if not isinstance(nets, list): nets = [nets] for net in nets: if net is not None: for param in net.parameters(): param.requires_grad = requires_grad class BodyModelEstimator(BaseArchitecture, metaclass=ABCMeta): """BodyModelEstimator Architecture. Args: backbone (dict | None, optional): Backbone config dict. Default: None. neck (dict | None, optional): Neck config dict. Default: None head (dict | None, optional): Regressor config dict. Default: None. disc (dict | None, optional): Discriminator config dict. Default: None. registration (dict | None, optional): Registration config dict. Default: None. body_model_train (dict | None, optional): SMPL config dict during training. Default: None. body_model_test (dict | None, optional): SMPL config dict during test. Default: None. convention (str, optional): Keypoints convention. Default: "human_data" loss_keypoints2d (dict | None, optional): Losses config dict for 2D keypoints. Default: None. loss_keypoints3d (dict | None, optional): Losses config dict for 3D keypoints. Default: None. loss_vertex (dict | None, optional): Losses config dict for mesh vertices. Default: None loss_smpl_pose (dict | None, optional): Losses config dict for smpl pose. Default: None loss_smpl_betas (dict | None, optional): Losses config dict for smpl betas. Default: None loss_camera (dict | None, optional): Losses config dict for predicted camera parameters. Default: None loss_adv (dict | None, optional): Losses config for adversial training. Default: None. loss_segm_mask (dict | None, optional): Losses config for predicted part segmentation. Default: None. init_cfg (dict or list[dict], optional): Initialization config dict. Default: None. """ def __init__(self, backbone: Optional[Union[dict, None]] = None, neck: Optional[Union[dict, None]] = None, head: Optional[Union[dict, None]] = None, disc: Optional[Union[dict, None]] = None, registration: Optional[Union[dict, None]] = None, body_model_train: Optional[Union[dict, None]] = None, body_model_test: Optional[Union[dict, None]] = None, convention: Optional[str] = 'human_data', loss_keypoints2d: Optional[Union[dict, None]] = None, loss_keypoints3d: Optional[Union[dict, None]] = None, loss_vertex: Optional[Union[dict, None]] = None, loss_smpl_pose: Optional[Union[dict, None]] = None, loss_smpl_betas: Optional[Union[dict, None]] = None, loss_camera: Optional[Union[dict, None]] = None, loss_adv: Optional[Union[dict, None]] = None, loss_segm_mask: Optional[Union[dict, None]] = None, init_cfg: Optional[Union[list, dict, None]] = None): super(BodyModelEstimator, self).__init__(init_cfg) self.backbone = build_backbone(backbone) self.neck = build_neck(neck) self.head = build_head(head) self.disc = build_discriminator(disc) self.body_model_train = build_body_model(body_model_train) self.body_model_test = build_body_model(body_model_test) self.convention = convention # TODO: support HMR+ self.registration = registration if registration is not None: self.fits_dict = FitsDict(fits='static') self.registration_mode = self.registration['mode'] self.registrant = build_registrant(registration['registrant']) else: self.registrant = None self.loss_keypoints2d = build_loss(loss_keypoints2d) self.loss_keypoints3d = build_loss(loss_keypoints3d) self.loss_vertex = build_loss(loss_vertex) self.loss_smpl_pose = build_loss(loss_smpl_pose) self.loss_smpl_betas = build_loss(loss_smpl_betas) self.loss_adv = build_loss(loss_adv) self.loss_camera = build_loss(loss_camera) self.loss_segm_mask = build_loss(loss_segm_mask) set_requires_grad(self.body_model_train, False) set_requires_grad(self.body_model_test, False) def train_step(self, data_batch, optimizer, **kwargs): """Train step function. In this function, the detector will finish the train step following the pipeline: 1. get fake and real SMPL parameters 2. optimize discriminator (if have) 3. optimize generator If `self.train_cfg.disc_step > 1`, the train step will contain multiple iterations for optimizing discriminator with different input data and only one iteration for optimizing generator after `disc_step` iterations for discriminator. Args: data_batch (torch.Tensor): Batch of data as input. optimizer (dict[torch.optim.Optimizer]): Dict with optimizers for generator and discriminator (if have). Returns: outputs (dict): Dict with loss, information for logger, the number of samples. """ if self.backbone is not None: img = data_batch['img'] features = self.backbone(img) else: features = data_batch['features'] if self.neck is not None: features = self.neck(features) predictions = self.head(features) targets = self.prepare_targets(data_batch) # optimize discriminator (if have) if self.disc is not None: self.optimize_discrinimator(predictions, data_batch, optimizer) if self.registration is not None: targets = self.run_registration(predictions, targets) losses = self.compute_losses(predictions, targets) # optimizer generator part if self.disc is not None: adv_loss = self.optimize_generator(predictions) losses.update(adv_loss) loss, log_vars = self._parse_losses(losses) for key in optimizer.keys(): optimizer[key].zero_grad() loss.backward() for key in optimizer.keys(): optimizer[key].step() outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(next(iter(data_batch.values())))) return outputs def run_registration( self, predictions: dict, targets: dict, threshold: Optional[float] = 10.0, focal_length: Optional[float] = 5000.0, img_res: Optional[Union[Tuple[int], int]] = 224) -> dict: """Run registration on 2D keypoinst in predictions to obtain SMPL parameters as pseudo ground truth. Args: predictions (dict): predicted SMPL parameters are used for initialization. targets (dict): existing ground truths with 2D keypoints threshold (float, optional): the threshold to update fits dictionary. Default: 10.0. focal_length (tuple(int) | int, optional): camera focal_length img_res (int, optional): image resolution Returns: targets: contains additional SMPL parameters """ img_metas = targets['img_metas'] dataset_name = [meta['dataset_name'] for meta in img_metas ] # name of the dataset the image comes from indices = targets['sample_idx'].squeeze() is_flipped = targets['is_flipped'].squeeze().bool( ) # flag that indicates whether image was flipped # during data augmentation rot_angle = targets['rotation'].squeeze( ) # rotation angle used for data augmentation Q gt_betas = targets['smpl_betas'].float() gt_global_orient = targets['smpl_global_orient'].float() gt_pose = targets['smpl_body_pose'].float().view(-1, 69) pred_rotmat = predictions['pred_pose'].detach().clone() pred_betas = predictions['pred_shape'].detach().clone() pred_cam = predictions['pred_cam'].detach().clone() pred_cam_t = torch.stack([ pred_cam[:, 1], pred_cam[:, 2], 2 * focal_length / (img_res * pred_cam[:, 0] + 1e-9) ], dim=-1) gt_keypoints_2d = targets['keypoints2d'].float() num_keypoints = gt_keypoints_2d.shape[1] has_smpl = targets['has_smpl'].view( -1).bool() # flag that indicates whether SMPL parameters are valid batch_size = has_smpl.shape[0] device = has_smpl.device # Get GT vertices and model joints # Note that gt_model_joints is different from gt_joints as # it comes from SMPL gt_out = self.body_model_train(betas=gt_betas, body_pose=gt_pose, global_orient=gt_global_orient) # TODO: support more convention assert num_keypoints == 49 gt_model_joints = gt_out['joints'] gt_vertices = gt_out['vertices'] # Get current best fits from the dictionary opt_pose, opt_betas = self.fits_dict[(dataset_name, indices.cpu(), rot_angle.cpu(), is_flipped.cpu())] opt_pose = opt_pose.to(device) opt_betas = opt_betas.to(device) opt_output = self.body_model_train(betas=opt_betas, body_pose=opt_pose[:, 3:], global_orient=opt_pose[:, :3]) opt_joints = opt_output['joints'] opt_vertices = opt_output['vertices'] gt_keypoints_2d_orig = gt_keypoints_2d.clone() # Estimate camera translation given the model joints and 2D keypoints # by minimizing a weighted least squares loss gt_cam_t = estimate_translation(gt_model_joints, gt_keypoints_2d_orig, focal_length=focal_length, img_size=img_res) opt_cam_t = estimate_translation(opt_joints, gt_keypoints_2d_orig, focal_length=focal_length, img_size=img_res) with torch.no_grad(): loss_dict = self.registrant.evaluate( global_orient=opt_pose[:, :3], body_pose=opt_pose[:, 3:], betas=opt_betas, transl=opt_cam_t, keypoints2d=gt_keypoints_2d_orig[:, :, :2], keypoints2d_conf=gt_keypoints_2d_orig[:, :, 2], reduction_override='none') opt_joint_loss = loss_dict['keypoint2d_loss'].sum(dim=-1).sum(dim=-1) if self.registration_mode == 'in_the_loop': # Convert predicted rotation matrices to axis-angle pred_rotmat_hom = torch.cat([ pred_rotmat.detach().view(-1, 3, 3).detach(), torch.tensor([0, 0, 1], dtype=torch.float32, device=device).view(1, 3, 1).expand( batch_size * 24, -1, -1) ], dim=-1) pred_pose = rotation_matrix_to_angle_axis( pred_rotmat_hom).contiguous().view(batch_size, -1) # tgm.rotation_matrix_to_angle_axis returns NaN for 0 rotation, # so manually hack it pred_pose[torch.isnan(pred_pose)] = 0.0 registrant_output = self.registrant( keypoints2d=gt_keypoints_2d_orig[:, :, :2], keypoints2d_conf=gt_keypoints_2d_orig[:, :, 2], init_global_orient=pred_pose[:, :3], init_transl=pred_cam_t, init_body_pose=pred_pose[:, 3:], init_betas=pred_betas, return_joints=True, return_verts=True, return_losses=True) new_opt_vertices = registrant_output[ 'vertices'] - pred_cam_t.unsqueeze(1) new_opt_joints = registrant_output[ 'joints'] - pred_cam_t.unsqueeze(1) new_opt_global_orient = registrant_output['global_orient'] new_opt_body_pose = registrant_output['body_pose'] new_opt_pose = torch.cat( [new_opt_global_orient, new_opt_body_pose], dim=1) new_opt_betas = registrant_output['betas'] new_opt_cam_t = registrant_output['transl'] new_opt_joint_loss = registrant_output['keypoint2d_loss'].sum( dim=-1).sum(dim=-1) # Will update the dictionary for the examples where the new loss # is less than the current one update = (new_opt_joint_loss < opt_joint_loss) opt_joint_loss[update] = new_opt_joint_loss[update] opt_vertices[update, :] = new_opt_vertices[update, :] opt_joints[update, :] = new_opt_joints[update, :] opt_pose[update, :] = new_opt_pose[update, :] opt_betas[update, :] = new_opt_betas[update, :] opt_cam_t[update, :] = new_opt_cam_t[update, :] self.fits_dict[(dataset_name, indices.cpu(), rot_angle.cpu(), is_flipped.cpu(), update.cpu())] = (opt_pose.cpu(), opt_betas.cpu()) # Replace extreme betas with zero betas opt_betas[(opt_betas.abs() > 3).any(dim=-1)] = 0. # Replace the optimized parameters with the ground truth parameters, # if available opt_vertices[has_smpl, :, :] = gt_vertices[has_smpl, :, :] opt_cam_t[has_smpl, :] = gt_cam_t[has_smpl, :] opt_joints[has_smpl, :, :] = gt_model_joints[has_smpl, :, :] opt_pose[has_smpl, 3:] = gt_pose[has_smpl, :] opt_pose[has_smpl, :3] = gt_global_orient[has_smpl, :] opt_betas[has_smpl, :] = gt_betas[has_smpl, :] # Assert whether a fit is valid by comparing the joint loss with # the threshold valid_fit = (opt_joint_loss < threshold).to(device) valid_fit = valid_fit | has_smpl targets['valid_fit'] = valid_fit targets['opt_vertices'] = opt_vertices targets['opt_cam_t'] = opt_cam_t targets['opt_joints'] = opt_joints targets['opt_pose'] = opt_pose targets['opt_betas'] = opt_betas return targets def optimize_discrinimator(self, predictions: dict, data_batch: dict, optimizer: dict): """Optimize discrinimator during adversarial training.""" set_requires_grad(self.disc, True) fake_data = self.make_fake_data(predictions, requires_grad=False) real_data = self.make_real_data(data_batch) fake_score = self.disc(fake_data) real_score = self.disc(real_data) disc_losses = {} disc_losses['real_loss'] = self.loss_adv(real_score, target_is_real=True, is_disc=True) disc_losses['fake_loss'] = self.loss_adv(fake_score, target_is_real=False, is_disc=True) loss_disc, log_vars_d = self._parse_losses(disc_losses) optimizer['disc'].zero_grad() loss_disc.backward() optimizer['disc'].step() def optimize_generator(self, predictions: dict): """Optimize generator during adversarial training.""" set_requires_grad(self.disc, False) fake_data = self.make_fake_data(predictions, requires_grad=True) pred_score = self.disc(fake_data) loss_adv = self.loss_adv(pred_score, target_is_real=True, is_disc=False) loss = dict(adv_loss=loss_adv) return loss def compute_keypoints3d_loss( self, pred_keypoints3d: torch.Tensor, gt_keypoints3d: torch.Tensor, has_keypoints3d: Optional[torch.Tensor] = None): """Compute loss for 3d keypoints.""" keypoints3d_conf = gt_keypoints3d[:, :, 3].float().unsqueeze(-1) keypoints3d_conf = keypoints3d_conf.repeat(1, 1, 3) pred_keypoints3d = pred_keypoints3d.float() gt_keypoints3d = gt_keypoints3d[:, :, :3].float() # currently, only mpi_inf_3dhp and h36m have 3d keypoints # both datasets have right_hip_extra and left_hip_extra right_hip_idx = get_keypoint_idx('right_hip_extra', self.convention) left_hip_idx = get_keypoint_idx('left_hip_extra', self.convention) gt_pelvis = (gt_keypoints3d[:, right_hip_idx, :] + gt_keypoints3d[:, left_hip_idx, :]) / 2 pred_pelvis = (pred_keypoints3d[:, right_hip_idx, :] + pred_keypoints3d[:, left_hip_idx, :]) / 2 gt_keypoints3d = gt_keypoints3d - gt_pelvis[:, None, :] pred_keypoints3d = pred_keypoints3d - pred_pelvis[:, None, :] loss = self.loss_keypoints3d(pred_keypoints3d, gt_keypoints3d, reduction_override='none') # If has_keypoints3d is not None, then computes the losses on the # instances that have ground-truth keypoints3d. # But the zero confidence keypoints will be included in mean. # Otherwise, only compute the keypoints3d # which have positive confidence. # has_keypoints3d is None when the key has_keypoints3d # is not in the datasets if has_keypoints3d is None: valid_pos = keypoints3d_conf > 0 if keypoints3d_conf[valid_pos].numel() == 0: return torch.Tensor([0]).type_as(gt_keypoints3d) loss = torch.sum(loss * keypoints3d_conf) loss /= keypoints3d_conf[valid_pos].numel() else: keypoints3d_conf = keypoints3d_conf[has_keypoints3d == 1] if keypoints3d_conf.shape[0] == 0: return torch.Tensor([0]).type_as(gt_keypoints3d) loss = loss[has_keypoints3d == 1] loss = (loss * keypoints3d_conf).mean() return loss def compute_keypoints2d_loss( self, pred_keypoints3d: torch.Tensor, pred_cam: torch.Tensor, gt_keypoints2d: torch.Tensor, img_res: Optional[int] = 224, focal_length: Optional[int] = 5000, has_keypoints2d: Optional[torch.Tensor] = None): """Compute loss for 2d keypoints.""" keypoints2d_conf = gt_keypoints2d[:, :, 2].float().unsqueeze(-1) keypoints2d_conf = keypoints2d_conf.repeat(1, 1, 2) gt_keypoints2d = gt_keypoints2d[:, :, :2].float() pred_keypoints2d = project_points(pred_keypoints3d, pred_cam, focal_length=focal_length, img_res=img_res) # Normalize keypoints to [-1,1] # The coordinate origin of pred_keypoints_2d is # the center of the input image. pred_keypoints2d = 2 * pred_keypoints2d / (img_res - 1) # The coordinate origin of gt_keypoints_2d is # the top left corner of the input image. gt_keypoints2d = 2 * gt_keypoints2d / (img_res - 1) - 1 loss = self.loss_keypoints2d(pred_keypoints2d, gt_keypoints2d, reduction_override='none') # If has_keypoints2d is not None, then computes the losses on the # instances that have ground-truth keypoints2d. # But the zero confidence keypoints will be included in mean. # Otherwise, only compute the keypoints2d # which have positive confidence. # has_keypoints2d is None when the key has_keypoints2d # is not in the datasets if has_keypoints2d is None: valid_pos = keypoints2d_conf > 0 if keypoints2d_conf[valid_pos].numel() == 0: return torch.Tensor([0]).type_as(gt_keypoints2d) loss = torch.sum(loss * keypoints2d_conf) loss /= keypoints2d_conf[valid_pos].numel() else: keypoints2d_conf = keypoints2d_conf[has_keypoints2d == 1] if keypoints2d_conf.shape[0] == 0: return torch.Tensor([0]).type_as(gt_keypoints2d) loss = loss[has_keypoints2d == 1] loss = (loss * keypoints2d_conf).mean() return loss def compute_vertex_loss(self, pred_vertices: torch.Tensor, gt_vertices: torch.Tensor, has_smpl: torch.Tensor): """Compute loss for vertices.""" gt_vertices = gt_vertices.float() conf = has_smpl.float().view(-1, 1, 1) conf = conf.repeat(1, gt_vertices.shape[1], gt_vertices.shape[2]) loss = self.loss_vertex(pred_vertices, gt_vertices, reduction_override='none') valid_pos = conf > 0 if conf[valid_pos].numel() == 0: return torch.Tensor([0]).type_as(gt_vertices) loss = torch.sum(loss * conf) / conf[valid_pos].numel() return loss def compute_smpl_pose_loss(self, pred_rotmat: torch.Tensor, gt_pose: torch.Tensor, has_smpl: torch.Tensor): """Compute loss for smpl pose.""" conf = has_smpl.float().view(-1) valid_pos = conf > 0 if conf[valid_pos].numel() == 0: return torch.Tensor([0]).type_as(gt_pose) pred_rotmat = pred_rotmat[valid_pos] gt_pose = gt_pose[valid_pos] conf = conf[valid_pos] gt_rotmat = batch_rodrigues(gt_pose.view(-1, 3)).view(-1, 24, 3, 3) loss = self.loss_smpl_pose(pred_rotmat, gt_rotmat, reduction_override='none') loss = loss.view(loss.shape[0], -1).mean(-1) loss = torch.mean(loss * conf) return loss def compute_smpl_betas_loss(self, pred_betas: torch.Tensor, gt_betas: torch.Tensor, has_smpl: torch.Tensor): """Compute loss for smpl betas.""" conf = has_smpl.float().view(-1) valid_pos = conf > 0 if conf[valid_pos].numel() == 0: return torch.Tensor([0]).type_as(gt_betas) pred_betas = pred_betas[valid_pos] gt_betas = gt_betas[valid_pos] conf = conf[valid_pos] loss = self.loss_smpl_betas(pred_betas, gt_betas, reduction_override='none') loss = loss.view(loss.shape[0], -1).mean(-1) loss = torch.mean(loss * conf) return loss def compute_camera_loss(self, cameras: torch.Tensor): """Compute loss for predicted camera parameters.""" loss = self.loss_camera(cameras) return loss def compute_part_segmentation_loss(self, pred_heatmap: torch.Tensor, gt_vertices: torch.Tensor, gt_keypoints2d: torch.Tensor, gt_model_joints: torch.Tensor, has_smpl: torch.Tensor, img_res: Optional[int] = 224, focal_length: Optional[int] = 500): """Compute loss for part segmentations.""" device = gt_keypoints2d.device gt_keypoints2d_valid = gt_keypoints2d[has_smpl == 1] batch_size = gt_keypoints2d_valid.shape[0] gt_vertices_valid = gt_vertices[has_smpl == 1] gt_model_joints_valid = gt_model_joints[has_smpl == 1] if batch_size == 0: return torch.Tensor([0]).type_as(gt_keypoints2d) gt_cam_t = estimate_translation( gt_model_joints_valid, gt_keypoints2d_valid, focal_length=focal_length, img_size=img_res, ) K = torch.eye(3) K[0, 0] = focal_length K[1, 1] = focal_length K[2, 2] = 1 K[0, 2] = img_res / 2. K[1, 2] = img_res / 2. K = K[None, :, :] R = torch.eye(3)[None, :, :] device = gt_keypoints2d.device gt_sem_mask = visualize_smpl.render_smpl( verts=gt_vertices_valid, R=R, K=K, T=gt_cam_t, render_choice='part_silhouette', resolution=img_res, return_tensor=True, body_model=self.body_model_train, device=device, in_ndc=False, convention='pytorch3d', projection='perspective', no_grad=True, batch_size=batch_size, verbose=False, ) gt_sem_mask = torch.flip(gt_sem_mask, [1, 2]).squeeze(-1).detach() pred_heatmap_valid = pred_heatmap[has_smpl == 1] ph, pw = pred_heatmap_valid.size(2), pred_heatmap_valid.size(3) h, w = gt_sem_mask.size(1), gt_sem_mask.size(2) if ph != h or pw != w: pred_heatmap_valid = F.interpolate(input=pred_heatmap_valid, size=(h, w), mode='bilinear') loss = self.loss_segm_mask(pred_heatmap_valid, gt_sem_mask) return loss def compute_losses(self, predictions: dict, targets: dict): """Compute losses.""" pred_betas = predictions['pred_shape'].view(-1, 10) pred_pose = predictions['pred_pose'].view(-1, 24, 3, 3) pred_cam = predictions['pred_cam'].view(-1, 3) gt_keypoints3d = targets['keypoints3d'] gt_keypoints2d = targets['keypoints2d'] # pred_pose N, 24, 3, 3 if self.body_model_train is not None: pred_output = self.body_model_train( betas=pred_betas, body_pose=pred_pose[:, 1:], global_orient=pred_pose[:, 0].unsqueeze(1), pose2rot=False, num_joints=gt_keypoints2d.shape[1]) pred_keypoints3d = pred_output['joints'] pred_vertices = pred_output['vertices'] # # TODO: temp. Should we multiply confs here? # pred_keypoints3d_mask = pred_output['joint_mask'] # keypoints3d_mask = keypoints3d_mask * pred_keypoints3d_mask # TODO: temp solution if 'valid_fit' in targets: has_smpl = targets['valid_fit'].view(-1) # global_orient = targets['opt_pose'][:, :3].view(-1, 1, 3) gt_pose = targets['opt_pose'] gt_betas = targets['opt_betas'] gt_vertices = targets['opt_vertices'] else: has_smpl = targets['has_smpl'].view(-1) gt_pose = targets['smpl_body_pose'] global_orient = targets['smpl_global_orient'].view(-1, 1, 3) gt_pose = torch.cat((global_orient, gt_pose), dim=1).float() gt_betas = targets['smpl_betas'].float() # gt_pose N, 72 if self.body_model_train is not None: gt_output = self.body_model_train( betas=gt_betas, body_pose=gt_pose[:, 3:], global_orient=gt_pose[:, :3], num_joints=gt_keypoints2d.shape[1]) gt_vertices = gt_output['vertices'] gt_model_joints = gt_output['joints'] if 'has_keypoints3d' in targets: has_keypoints3d = targets['has_keypoints3d'].squeeze(-1) else: has_keypoints3d = None if 'has_keypoints2d' in targets: has_keypoints2d = targets['has_keypoints2d'].squeeze(-1) else: has_keypoints2d = None if 'pred_segm_mask' in predictions: pred_segm_mask = predictions['pred_segm_mask'] losses = {} if self.loss_keypoints3d is not None: losses['keypoints3d_loss'] = self.compute_keypoints3d_loss( pred_keypoints3d, gt_keypoints3d, has_keypoints3d=has_keypoints3d) if self.loss_keypoints2d is not None: losses['keypoints2d_loss'] = self.compute_keypoints2d_loss( pred_keypoints3d, pred_cam, gt_keypoints2d, has_keypoints2d=has_keypoints2d) if self.loss_vertex is not None: losses['vertex_loss'] = self.compute_vertex_loss( pred_vertices, gt_vertices, has_smpl) if self.loss_smpl_pose is not None: losses['smpl_pose_loss'] = self.compute_smpl_pose_loss( pred_pose, gt_pose, has_smpl) if self.loss_smpl_betas is not None: losses['smpl_betas_loss'] = self.compute_smpl_betas_loss( pred_betas, gt_betas, has_smpl) if self.loss_camera is not None: losses['camera_loss'] = self.compute_camera_loss(pred_cam) if self.loss_segm_mask is not None: losses['loss_segm_mask'] = self.compute_part_segmentation_loss( pred_segm_mask, gt_vertices, gt_keypoints2d, gt_model_joints, has_smpl) return losses @abstractmethod def make_fake_data(self, predictions, requires_grad): pass @abstractmethod def make_real_data(self, data_batch): pass @abstractmethod def prepare_targets(self, data_batch): pass def forward_train(self, **kwargs): """Forward function for general training. For mesh estimation, we do not use this interface. """ raise NotImplementedError('This interface should not be used in ' 'current training schedule. Please use ' '`train_step` for training.') @abstractmethod def forward_test(self, img, img_metas, **kwargs): """Defines the computation performed at every call when testing.""" pass class ImageBodyModelEstimator(BodyModelEstimator): def make_fake_data(self, predictions: dict, requires_grad: bool): pred_cam = predictions['pred_cam'] pred_pose = predictions['pred_pose'] pred_betas = predictions['pred_shape'] if requires_grad: fake_data = (pred_cam, pred_pose, pred_betas) else: fake_data = (pred_cam.detach(), pred_pose.detach(), pred_betas.detach()) return fake_data def make_real_data(self, data_batch: dict): transl = data_batch['adv_smpl_transl'].float() global_orient = data_batch['adv_smpl_global_orient'] body_pose = data_batch['adv_smpl_body_pose'] betas = data_batch['adv_smpl_betas'].float() pose = torch.cat((global_orient, body_pose), dim=-1).float() real_data = (transl, pose, betas) return real_data def prepare_targets(self, data_batch: dict): # Image Mesh Estimator does not need extra process for ground truth return data_batch def forward_test(self, img: torch.Tensor, img_metas: dict, **kwargs): """Defines the computation performed at every call when testing.""" if self.backbone is not None: features = self.backbone(img) else: features = kwargs['features'] if self.neck is not None: features = self.neck(features) predictions = self.head(features) pred_pose = predictions['pred_pose'] pred_betas = predictions['pred_shape'] pred_cam = predictions['pred_cam'] pred_output = self.body_model_test( betas=pred_betas, body_pose=pred_pose[:, 1:], global_orient=pred_pose[:, 0].unsqueeze(1), pose2rot=False) pred_vertices = pred_output['vertices'] pred_keypoints_3d = pred_output['joints'] all_preds = {} all_preds['keypoints_3d'] = pred_keypoints_3d.detach().cpu().numpy() all_preds['smpl_pose'] = pred_pose.detach().cpu().numpy() all_preds['smpl_beta'] = pred_betas.detach().cpu().numpy() all_preds['camera'] = pred_cam.detach().cpu().numpy() all_preds['vertices'] = pred_vertices.detach().cpu().numpy() image_path = [] for img_meta in img_metas: image_path.append(img_meta['image_path']) all_preds['image_path'] = image_path all_preds['image_idx'] = kwargs['sample_idx'] return all_preds class VideoBodyModelEstimator(BodyModelEstimator): def make_fake_data(self, predictions: dict, requires_grad: bool): B, T = predictions['pred_cam'].shape[:2] pred_cam_vec = predictions['pred_cam'] pred_betas_vec = predictions['pred_shape'] pred_pose = predictions['pred_pose'] pred_pose_vec = rotation_matrix_to_angle_axis(pred_pose.view(-1, 3, 3)) pred_pose_vec = pred_pose_vec.contiguous().view(B, T, -1) pred_theta_vec = (pred_cam_vec, pred_pose_vec, pred_betas_vec) pred_theta_vec = torch.cat(pred_theta_vec, dim=-1) if not requires_grad: pred_theta_vec = pred_theta_vec.detach() return pred_theta_vec[:, :, 6:75] def make_real_data(self, data_batch: dict): B, T = data_batch['adv_smpl_transl'].shape[:2] transl = data_batch['adv_smpl_transl'].view(B, T, -1) global_orient = \ data_batch['adv_smpl_global_orient'].view(B, T, -1) body_pose = data_batch['adv_smpl_body_pose'].view(B, T, -1) betas = data_batch['adv_smpl_betas'].view(B, T, -1) real_data = (transl, global_orient, body_pose, betas) real_data = torch.cat(real_data, dim=-1).float() return real_data[:, :, 6:75] def prepare_targets(self, data_batch: dict): # Video Mesh Estimator needs squeeze first two dimensions B, T = data_batch['smpl_body_pose'].shape[:2] output = { 'smpl_body_pose': data_batch['smpl_body_pose'].view(-1, 23, 3), 'smpl_global_orient': data_batch['smpl_global_orient'].view(-1, 3), 'smpl_betas': data_batch['smpl_betas'].view(-1, 10), 'has_smpl': data_batch['has_smpl'].view(-1), 'keypoints3d': data_batch['keypoints3d'].view(B * T, -1, 4), 'keypoints2d': data_batch['keypoints2d'].view(B * T, -1, 3) } return output def forward_test(self, img_metas: dict, **kwargs): """Defines the computation performed at every call when testing.""" if self.backbone is not None: features = self.backbone(kwargs['img']) else: features = kwargs['features'] if self.neck is not None: features = self.neck(features) B, T = features.shape[:2] predictions = self.head(features) pred_pose = predictions['pred_pose'].view(-1, 24, 3, 3) pred_betas = predictions['pred_shape'].view(-1, 10) pred_cam = predictions['pred_cam'].view(-1, 3) pred_output = self.body_model_test( betas=pred_betas, body_pose=pred_pose[:, 1:], global_orient=pred_pose[:, 0].unsqueeze(1), pose2rot=False) pred_vertices = pred_output['vertices'] pred_keypoints_3d = pred_output['joints'] all_preds = {} all_preds['keypoints_3d'] = pred_keypoints_3d.detach().cpu().numpy() all_preds['smpl_pose'] = pred_pose.detach().cpu().numpy() all_preds['smpl_beta'] = pred_betas.detach().cpu().numpy() all_preds['camera'] = pred_cam.detach().cpu().numpy() all_preds['vertices'] = pred_vertices.detach().cpu().numpy() all_preds['image_idx'] = \ kwargs['sample_idx'].detach().cpu().numpy().reshape((-1)) return all_preds