# Copyright (c) OpenMMLab. All rights reserved.
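"""Loss module for DETR-style multi-person SMPL estimation.

``DETRLoss`` matches each query's SMPL prediction to a ground-truth person
with a Hungarian assigner, builds dense targets for the matched queries, and
computes 2D/3D keypoint, vertex and SMPL parameter losses for every decoder
layer.
"""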
from abc import ABCMeta
from typing import Optional, Union

import numpy as np
import torch
from scipy.optimize import linear_sum_assignment

from detrsmpl.core.conventions.keypoints_mapping import (convert_kps,
                                                         get_keypoint_idx)
from detrsmpl.core.post_processing.bbox.assigners import build_assigner
from detrsmpl.core.post_processing.bbox.samplers import build_sampler
from detrsmpl.utils.geometry import batch_rodrigues, project_points
from detrsmpl.utils.misc import multi_apply

from ..backbones.builder import build_backbone
from ..body_models.builder import build_body_model
from ..heads.builder import build_head
from ..losses.builder import build_loss
from ..necks.builder import build_neck
from .base_architecture import BaseArchitecture

# from mmdet.core import bbox2result


class DETRLoss(BaseArchitecture, metaclass=ABCMeta):

    def __init__(self,
                 body_model_train: Optional[dict] = None,
                 body_model_test: Optional[dict] = None,
                 convention: Optional[str] = 'human_data',
                 loss_keypoints2d: Optional[dict] = None,
                 loss_keypoints3d: Optional[dict] = None,
                 loss_vertex: Optional[dict] = None,
                 loss_smpl_pose: Optional[dict] = None,
                 loss_smpl_betas: Optional[dict] = None,
                 loss_camera: Optional[dict] = None,
                 loss_cls: Optional[dict] = dict(type='CrossEntropyLoss',
                                                 bg_cls_weight=0.1,
                                                 use_sigmoid=False,
                                                 loss_weight=1.0,
                                                 class_weight=1.0),
                 loss_bbox: Optional[dict] = dict(type='L1Loss',
                                                  loss_weight=5.0),
                 loss_iou: Optional[dict] = dict(type='GIoULoss',
                                                 loss_weight=2.0),
                 init_cfg: Optional[Union[list, dict]] = None,
                 train_cfg: Optional[dict] = dict(assigner=dict(
                     type='HungarianAssigner',
                     kp3d_cost=dict(type='Keypoints3DCost',
                                    convention='smpl_54',
                                    weight=5.0),
                     kp2d_cost=dict(type='Keypoints2DCost',
                                    convention='smpl_54',
                                    weight=5.0),
                     # cls_cost=dict(type='ClassificationCost', weight=1.),
                     # reg_cost=dict(type='BBoxL1Cost', weight=5.0),
                     # iou_cost=dict(
                     #     type='IoUCost', iou_mode='giou', weight=2.0),
                 )),
                 test_cfg: Optional[dict] = None):
        super(DETRLoss, self).__init__(init_cfg)
        if train_cfg:
            assert 'assigner' in train_cfg, 'assigner should be provided '\
                'when train_cfg is set.'
            assigner = train_cfg['assigner']
            # TODO: update these
            # assert loss_cls['loss_weight'] == assigner['kp3d_cost']['weight'], \
            #     'The classification weight for loss and matcher should be' \
            #     'exactly the same.'
            # assert loss_bbox['loss_weight'] == assigner['kp3d_cost'][
            #     'weight'], 'The regression L1 weight for loss and matcher ' \
            #     'should be exactly the same.'
            # assert loss_iou['loss_weight'] == assigner['kp3d_cost']['weight'], \
            #     'The regression iou weight for loss and matcher should be' \
            #     'exactly the same.'
            self.assigner = build_assigner(assigner)
            # DETR sampling=False, so use PseudoSampler
            sampler_cfg = dict(type='PseudoSampler')
            self.sampler = build_sampler(sampler_cfg, context=self)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        # build losses
        self.loss_keypoints2d = build_loss(loss_keypoints2d)
        self.loss_keypoints3d = build_loss(loss_keypoints3d)
        self.loss_vertex = build_loss(loss_vertex)
        self.loss_smpl_pose = build_loss(loss_smpl_pose)
        self.loss_smpl_betas = build_loss(loss_smpl_betas)
        self.loss_camera = build_loss(loss_camera)
        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)
        self.loss_iou = build_loss(loss_iou)

        # body models used to obtain joints/vertices from SMPL parameters
        self.body_model_train = build_body_model(body_model_train)
        self.body_model_test = build_body_model(body_model_test)
        self.convention = convention

    def forward_train(self, preds, targets):
        """Unused stub; the loss computation happens in ``forward``."""
        pass

    def forward(self, preds, targets):
        """Compute SMPL losses for every decoder layer.

        Args:
            preds (dict): Network predictions, containing:
                - 'pred_pose' (Tensor): predicted rotation matrices of shape
                  [num_decoder_layers, B, num_queries, 24, 3, 3].
                - 'pred_betas' (Tensor): predicted SMPL betas of shape
                  [num_decoder_layers, B, num_queries, 10].
                - 'pred_cameras' (Tensor): predicted cameras of shape
                  [num_decoder_layers, B, num_queries, 3].
            targets (dict): Ground truth, containing per-image SMPL
                parameters ('smpl_body_pose', 'smpl_global_orient',
                'smpl_betas', 'smpl_transl'), keypoints ('keypoints2d',
                'keypoints3d'), the optional 'has_*' indicators, 'img'
                and 'img_metas'.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        # NOTE the batched image size information may be useful, e.g.
        # in DETR, this is needed for the construction of masks, which is
        # then used for the transformer_head.
        pred_pose = preds['pred_pose']
        pred_betas = preds['pred_betas']
        pred_cameras = preds['pred_cameras']

        has_smpl = targets['has_smpl']
        gt_smpl_body_pose = targets[
            'smpl_body_pose']  # [bs_0: [ins_num, 23, 3]]
        gt_smpl_global_orient = targets['smpl_global_orient']
        # prepend the global orientation so each pose covers all 24 joints
        gt_smpl_body_pose = \
            [torch.cat((gt_smpl_global_orient[i].view(-1, 1, 3),
                        gt_smpl_body_pose[i]), dim=1).float()
             for i in range(len(gt_smpl_body_pose))]
        gt_smpl_betas = targets['smpl_betas']
        gt_smpl_transl = targets['smpl_transl']
        gt_keypoints2d = targets['keypoints2d']
        gt_keypoints3d = targets['keypoints3d']  # [bs_0: [N, K, D], ...]
        img_metas = targets['img_metas']

        if 'has_keypoints3d' in targets:
            has_keypoints3d = targets['has_keypoints3d']
        else:
            has_keypoints3d = None
        if 'has_keypoints2d' in targets:
            has_keypoints2d = targets['has_keypoints2d']
        else:
            has_keypoints2d = None

        img = targets['img']
        batch_input_shape = tuple(img[0].size()[-2:])
        for img_meta in img_metas:
            img_meta['batch_input_shape'] = batch_input_shape
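        # pred_pose: [L, B, N, 24, 3, 3], where L is the number of decoder
        # layers, B the batch size and N the number of queries per image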
        L, B, N = pred_pose.shape[:3]
        if self.body_model_train is not None:
            pred_output = self.body_model_train(
                betas=pred_betas.reshape(L * B * N, 10),
                body_pose=pred_pose.reshape(L * B * N, 24, 3, 3)[:, 1:],
                global_orient=pred_pose.reshape(L * B * N, 24, 3,
                                                3)[:, 0].unsqueeze(1),
                pose2rot=False,
                num_joints=gt_keypoints2d[0].shape[1])
            pred_keypoints3d = pred_output['joints'].reshape(L, B, N, -1, 3)
            pred_vertices = pred_output['vertices'].reshape(L, B, N, 6890, 3)

        # replicate the ground truth for every decoder layer
        num_dec_layers = pred_pose.shape[0]
        all_gt_smpl_body_pose_list = [
            gt_smpl_body_pose for _ in range(num_dec_layers)
        ]
        all_gt_smpl_global_orient_list = [
            gt_smpl_global_orient for _ in range(num_dec_layers)
        ]
        all_gt_smpl_betas_list = [gt_smpl_betas for _ in range(num_dec_layers)]
        all_gt_smpl_transl_list = [
            gt_smpl_transl for _ in range(num_dec_layers)
        ]
        all_gt_keypoints2d_list = [
            gt_keypoints2d for _ in range(num_dec_layers)
        ]
        all_gt_keypoints3d_list = [
            gt_keypoints3d for _ in range(num_dec_layers)
        ]
        all_has_smpl_list = [has_smpl for _ in range(num_dec_layers)]
        all_has_keypoints3d_list = [
            has_keypoints3d for _ in range(num_dec_layers)
        ]
        all_has_keypoints2d_list = [
            has_keypoints2d for _ in range(num_dec_layers)
        ]
        all_gt_ignore_list = [None for _ in range(num_dec_layers)]
        img_metas_list = [img_metas for _ in range(num_dec_layers)]
        # all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
        # all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
        # all_gt_bboxes_ignore_list = [
        #     gt_bboxes_ignore for _ in range(num_dec_layers)
        # ]
        # compute the loss for each decoder layer
        (kp2d_loss, kp3d_loss, vert_loss, pose_loss, beta_loss) = multi_apply(
            self.compute_losses, pred_pose, pred_betas, pred_keypoints3d,
            pred_vertices, pred_cameras, all_gt_smpl_body_pose_list,
            all_gt_smpl_betas_list, all_gt_keypoints2d_list,
            all_gt_keypoints3d_list, all_has_keypoints2d_list,
            all_has_keypoints3d_list, all_has_smpl_list, img_metas_list,
            all_gt_ignore_list)
        losses = {}
        losses['keypoints2d_loss'] = kp2d_loss[-1]
        losses['keypoints3d_loss'] = kp3d_loss[-1]
        losses['vertex_loss'] = vert_loss[-1]
        losses['smpl_pose_loss'] = pose_loss[-1]
        losses['smpl_betas_loss'] = beta_loss[-1]
        # loss from other decoder layers
        num_dec_layer = 0
        for (kp2d_loss_i, kp3d_loss_i, vert_loss_i, pose_loss_i,
             beta_loss_i) in zip(kp2d_loss[:-1], kp3d_loss[:-1],
                                 vert_loss[:-1], pose_loss[:-1],
                                 beta_loss[:-1]):
            losses[f'd{num_dec_layer}.keypoints2d_loss'] = kp2d_loss_i
            losses[f'd{num_dec_layer}.keypoints3d_loss'] = kp3d_loss_i
            losses[f'd{num_dec_layer}.vertex_loss'] = vert_loss_i
            losses[f'd{num_dec_layer}.smpl_pose_loss'] = pose_loss_i
            losses[f'd{num_dec_layer}.smpl_betas_loss'] = beta_loss_i
            num_dec_layer += 1
        return losses

    def compute_losses(self,
                       outputs_poses,
                       outputs_shapes,
                       outputs_kp3ds,
                       outputs_verts,
                       outputs_cameras,
                       all_gt_smpl_body_pose_list,
                       all_gt_smpl_betas_list,
                       all_gt_kp2d_list,
                       all_gt_kp3d_list,
                       all_has_keypoints2d_list,
                       all_has_keypoints3d_list,
                       all_has_smpl_list,
                       img_metas_list,
                       all_gt_ignore_list=None):
        """Compute the losses of a single decoder layer.

        Matches queries to ground-truth instances per image
        (``prepare_targets``), flattens the per-image targets, and
        evaluates each enabled loss term.

        Args:
            outputs_poses (Tensor): rotation matrices, [B, N, 24, 3, 3].
            outputs_shapes (Tensor): SMPL betas, [B, N, 10].
            outputs_kp3ds (Tensor): 3D keypoints, [B, N, K, 3].
            outputs_verts (Tensor): SMPL vertices, [B, N, 6890, 3].
            outputs_cameras (Tensor): camera parameters, [B, N, 3].
            all_gt_smpl_body_pose_list (list[Tensor]): per-image GT poses.
            all_gt_smpl_betas_list (list[Tensor]): per-image GT betas.
            all_gt_kp2d_list (list[Tensor]): per-image GT 2D keypoints.
            all_gt_kp3d_list (list[Tensor]): per-image GT 3D keypoints,
                each of shape [num_instances, K, D].
            all_has_keypoints2d_list (list[Tensor] | None): 2D flags.
            all_has_keypoints3d_list (list[Tensor] | None): 3D flags.
            all_has_smpl_list (list[Tensor] | None): SMPL flags.
            img_metas_list (list[dict]): per-image meta information.
            all_gt_ignore_list (list | None): GT to ignore; unused.
        """
        num_img = outputs_poses.size(0)  # batch size
        all_pred_smpl_pose_list = [outputs_poses[i] for i in range(num_img)]
        all_pred_smpl_shape_list = [outputs_shapes[i] for i in range(num_img)]
        all_pred_kp3d_list = [outputs_kp3ds[i] for i in range(num_img)]
        all_pred_vert_list = [outputs_verts[i] for i in range(num_img)]
        all_pred_cam_list = [outputs_cameras[i] for i in range(num_img)]
        gt_bboxes_ignore_list = [all_gt_ignore_list for _ in range(num_img)]
        # broadcast missing 'has_*' flags as None for every image
        if all_has_keypoints2d_list is None:
            all_has_keypoints2d_list = [None for _ in range(num_img)]
        if all_has_keypoints3d_list is None:
            all_has_keypoints3d_list = [None for _ in range(num_img)]
        if all_has_smpl_list is None:
            all_has_smpl_list = [None for _ in range(num_img)]
        # match and build targets for each image in the batch
        (kp2d_list, kp2d_weight_list, kp3d_list, kp3d_weight_list,
         smpl_pose_list, smpl_pose_weight_list, smpl_shape_list,
         smpl_shape_weight_list, vert_list, vert_weight_list, has_smpl_list,
         has_keypoints2d_list, has_keypoints3d_list, pos_inds_list,
         neg_inds_list) = multi_apply(
             self.prepare_targets,
             all_pred_smpl_pose_list,
             all_pred_smpl_shape_list,
             all_pred_kp3d_list,
             all_pred_vert_list,
             all_pred_cam_list,
             all_gt_smpl_body_pose_list,
             all_gt_smpl_betas_list,
             all_gt_kp2d_list,
             all_gt_kp3d_list,
             all_has_keypoints2d_list,
             all_has_keypoints3d_list,
             all_has_smpl_list,
             img_metas_list,
             gt_bboxes_ignore_list,
         )
        num_total_pos = sum((inds.numel() for inds in pos_inds_list))
        num_total_neg = sum((inds.numel() for inds in neg_inds_list))
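        # flatten the per-image targets so that predictions and targets
        # align along a single (B * N) query axis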
        K = outputs_kp3ds.shape[-2]
        gt_kp2d = torch.cat(kp2d_list, 0)
        kp2d_weight = torch.cat(kp2d_weight_list, 0)
        pred_cam = outputs_cameras.reshape(-1, 3)

        gt_kp3d = torch.cat(kp3d_list, 0)
        kp3d_weight = torch.cat(kp3d_weight_list, 0)
        pred_kp3d = outputs_kp3ds.reshape(-1, K, 3)

        gt_smpl_pose = torch.cat(smpl_pose_list, 0)
        smpl_pose_weight = torch.cat(smpl_pose_weight_list, 0)
        pred_smpl_pose = outputs_poses.reshape(-1, 24, 3, 3)

        gt_smpl_shape = torch.cat(smpl_shape_list, 0)
        smpl_shape_weight = torch.cat(smpl_shape_weight_list, 0)
        pred_smpl_shape = outputs_shapes.reshape(-1, 10)

        gt_vert = torch.cat(vert_list, 0)
        vert_weight = torch.cat(vert_weight_list, 0)
        pred_verts = outputs_verts.reshape(-1, 6890, 3)

        has_smpl = torch.cat(has_smpl_list, 0).squeeze()
        has_keypoints2d = torch.cat(has_keypoints2d_list, 0).squeeze()
        has_keypoints3d = torch.cat(has_keypoints3d_list, 0).squeeze()
        if self.loss_keypoints2d is not None:
            keypoints2d_loss = self.compute_keypoints2d_loss(
                pred_kp3d, pred_cam, gt_kp2d, has_keypoints2d=has_keypoints2d)
        else:
            keypoints2d_loss = 0.0
        if self.loss_keypoints3d is not None:
            keypoints3d_loss = self.compute_keypoints3d_loss(
                pred_kp3d,
                gt_kp3d,
                has_keypoints3d=has_keypoints3d,
            )
        else:
            keypoints3d_loss = 0.0
        if self.loss_vertex is not None:
            vertex_loss = self.compute_vertex_loss(pred_verts,
                                                   gt_vert,
                                                   has_smpl=has_smpl)
        else:
            vertex_loss = 0.0
        if self.loss_smpl_pose is not None:
            smpl_pose_loss = self.compute_smpl_pose_loss(pred_smpl_pose,
                                                         gt_smpl_pose,
                                                         has_smpl=has_smpl)
        else:
            smpl_pose_loss = 0.0
        if self.loss_smpl_betas is not None:
            smpl_betas_loss = self.compute_smpl_betas_loss(pred_smpl_shape,
                                                           gt_smpl_shape,
                                                           has_smpl=has_smpl)
        else:
            smpl_betas_loss = 0.0
        # if self.loss_iou is not None:
        #     losses['iou_loss'] = self.loss_iou()
        # if self.loss_bbox is not None:
        #     losses['bbox_loss'] = self.loss_bbox()
        # if self.loss_cls is not None:
        #     losses['cls_loss'] = self.loss_cls()
        return (keypoints2d_loss, keypoints3d_loss, vertex_loss,
                smpl_pose_loss, smpl_betas_loss)

    def prepare_targets(self, pred_smpl_pose, pred_smpl_shape, pred_kp3d,
                        pred_vert, pred_cam, gt_smpl_pose, gt_smpl_shape,
                        gt_kp2d, gt_kp3d, has_keypoints2d, has_keypoints3d,
                        has_smpl, img_meta, gt_bboxes_ignore):
        """Assign queries to ground-truth instances for a single image and
        build the regression targets.

        Args:
            pred_smpl_pose (Tensor): rotation matrices, [N, 24, 3, 3].
            pred_smpl_shape (Tensor): SMPL betas, [N, 10].
            pred_kp3d (Tensor): 3D keypoints, [N, K, 3].
            pred_vert (Tensor): SMPL vertices, [N, 6890, 3].
            pred_cam (Tensor): camera parameters, [N, 3].
            gt_smpl_pose (Tensor): GT axis-angle poses per instance.
            gt_smpl_shape (Tensor): GT betas per instance.
            gt_kp2d (Tensor): GT 2D keypoints with confidence.
            gt_kp3d (Tensor): GT 3D keypoints, [num_instances, K, D].
            has_keypoints2d (Tensor | None): 2D flags per instance.
            has_keypoints3d (Tensor | None): 3D flags per instance.
            has_smpl (Tensor | None): SMPL flags per instance.
            img_meta (dict): image meta information.
            gt_bboxes_ignore (None): GT to ignore; unused.
        """
        num_query = pred_smpl_pose.shape[0]
        assign_result = self.assigner.assign(pred_smpl_pose, pred_smpl_shape,
                                             pred_kp3d, pred_vert, pred_cam,
                                             gt_smpl_pose, gt_smpl_shape,
                                             gt_kp2d, gt_kp3d,
                                             has_keypoints2d, has_keypoints3d,
                                             has_smpl, img_meta,
                                             gt_bboxes_ignore)
        gt_smpl_pose = gt_smpl_pose.float()
        gt_smpl_shape = gt_smpl_shape.float()
        gt_kp2d = gt_kp2d.float()
        gt_kp3d = gt_kp3d.float()
        has_keypoints2d = has_keypoints2d.float()
        has_keypoints3d = has_keypoints3d.float()
        has_smpl = has_smpl.float()
        sampling_result = self.sampler.sample(assign_result, pred_smpl_pose,
                                              gt_smpl_pose)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
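        # Targets default to zero with zero weights; only the queries matched
        # to a ground-truth instance (pos_inds) receive real targets below.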
        # kp2d targets
        kp2d_targets = torch.zeros_like(pred_kp3d[..., :2])
        kp2d_weights = torch.zeros_like(pred_kp3d[..., :2])
        kp2d_targets[pos_inds] = gt_kp2d[
            sampling_result.pos_assigned_gt_inds][..., :2]
        kp2d_weights[pos_inds] = gt_kp2d[
            sampling_result.pos_assigned_gt_inds][..., [2]].repeat(1, 1, 2)
        # append the confidence as the last channel
        kp2d_targets = torch.cat(
            [kp2d_targets, kp2d_weights[..., 0].unsqueeze(-1)], dim=-1)

        # kp3d targets
        kp3d_targets = torch.zeros_like(pred_kp3d)
        kp3d_weights = torch.zeros_like(pred_kp3d)
        kp3d_targets[pos_inds] = gt_kp3d[
            sampling_result.pos_assigned_gt_inds][..., :3]
        kp3d_weights[pos_inds] = gt_kp3d[
            sampling_result.pos_assigned_gt_inds][..., [3]].repeat(1, 1, 3)
        kp3d_targets = torch.cat(
            [kp3d_targets, kp3d_weights[..., 0].unsqueeze(-1)], dim=-1)

        # smpl pose targets (axis-angle converted to rotation matrices)
        smpl_pose_targets = torch.zeros_like(pred_smpl_pose)
        smpl_pose_weights = torch.zeros_like(pred_smpl_pose)
        gt_smpl_pose_rotmat = batch_rodrigues(gt_smpl_pose.view(-1, 3)).view(
            -1, 24, 3, 3)
        smpl_pose_targets[pos_inds] = gt_smpl_pose_rotmat[
            sampling_result.pos_assigned_gt_inds]
        smpl_pose_weights[pos_inds] = 1.0

        # smpl betas targets
        smpl_shape_targets = torch.zeros_like(pred_smpl_shape)
        smpl_shape_weights = torch.zeros_like(pred_smpl_shape)
        smpl_shape_targets[pos_inds] = gt_smpl_shape[
            sampling_result.pos_assigned_gt_inds]
        smpl_shape_weights[pos_inds] = 1.0

        # vertex targets, regressed from the GT SMPL parameters
        if self.body_model_train is not None:
            gt_output = self.body_model_train(
                betas=gt_smpl_shape,
                body_pose=gt_smpl_pose_rotmat[:, 1:],
                global_orient=gt_smpl_pose_rotmat[:, 0].unsqueeze(1),
                pose2rot=False)
            gt_vertices = gt_output['vertices']
            gt_model_joints = gt_output['joints']
        vert_targets = torch.zeros_like(pred_vert)
        vert_weights = torch.zeros_like(pred_vert)
        vert_targets[pos_inds] = gt_vertices[
            sampling_result.pos_assigned_gt_inds]
        vert_weights[pos_inds] = 1.0
        # per-query 'has_*' flags, gathered from the matched instances
        if has_keypoints2d is not None:
            has_keypoints2d_ = torch.zeros(
                (num_query, 1)).to(smpl_pose_targets.device)
            has_keypoints2d_[pos_inds] = has_keypoints2d[
                sampling_result.pos_assigned_gt_inds]
        else:
            has_keypoints2d_ = None
        if has_keypoints3d is not None:
            has_keypoints3d_ = torch.zeros(
                (num_query, 1)).to(smpl_pose_targets.device)
            has_keypoints3d_[pos_inds] = has_keypoints3d[
                sampling_result.pos_assigned_gt_inds]
        else:
            has_keypoints3d_ = None
        if has_smpl is not None:
            has_smpl_ = torch.zeros(
                (num_query, 1)).to(smpl_pose_targets.device)
            has_smpl_[pos_inds] = has_smpl[
                sampling_result.pos_assigned_gt_inds]
        else:
            has_smpl_ = None
        return (kp2d_targets, kp2d_weights, kp3d_targets, kp3d_weights,
                smpl_pose_targets, smpl_pose_weights, smpl_shape_targets,
                smpl_shape_weights, vert_targets, vert_weights, has_smpl_,
                has_keypoints2d_, has_keypoints3d_, pos_inds, neg_inds)

    def forward_test(self, img, img_metas, **kwargs):
        """Run a test-time forward pass and select one prediction per
        image."""
        batch_input_shape = tuple(img[0].size()[-2:])
        for img_meta in img_metas:
            img_meta['batch_input_shape'] = batch_input_shape
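        # NOTE: self.backbone, self.neck and self.head are not built in
        # __init__ above; they are assumed to be attached by the enclosing
        # architecture before forward_test is called.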
        features = self.backbone(img)
        if self.neck is not None:
            features = self.neck(features)
        pred_pose, pred_betas, pred_cam, _, _ = \
            self.head(features, img_metas)

        L, B, N = pred_pose.shape[:3]
        if self.body_model_test is not None:
            pred_output = self.body_model_test(
                betas=pred_betas.reshape(L * B * N, 10),
                body_pose=pred_pose.reshape(L * B * N, 24, 3, 3)[:, 1:],
                global_orient=pred_pose.reshape(L * B * N, 24, 3,
                                                3)[:, 0].unsqueeze(1),
                pose2rot=False)
        else:
            raise ValueError('Please provide a built body model.')
        pred_keypoints_3d = pred_output['joints'].reshape(L, B, N, -1, 3)
        # root-align the predicted 3D keypoints on the first joint
        pred_keypoints_3d = (pred_keypoints_3d -
                             pred_keypoints_3d[..., [0], :])
        pred_keypoints_3d = pred_keypoints_3d.detach().cpu().numpy()
        # pred_vertices = pred_output['vertices'].reshape(L, B, N, 6890, 3)
        pred_cam = pred_cam.detach().cpu().numpy()
        pred_pose = pred_pose.detach().cpu().numpy()
        pred_betas = pred_betas.detach().cpu().numpy()

        # gt_keypoints3d: [batch, instance_num, kp_num, 4]
        gt_keypoints3d = kwargs['keypoints3d'].repeat([1, N, 1, 1]).clone()
        # keypoints3d_mask = kwargs['keypoints3d_mask']
        gt_keypoints3d = gt_keypoints3d.detach().cpu().numpy()
        # gt_keypoints3d, _ = convert_kps(
        #     gt_keypoints3d,
        #     src='human_data',
        #     dst='h36m')
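        # For each image, pick the query whose root-aligned 3D keypoints are
        # closest to the ground truth. Note that the signed errors are summed
        # over joints and coordinates before taking the absolute value, and
        # that ground truth is used at test time for this selection.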
        cost = np.sum((pred_keypoints_3d[-1] - gt_keypoints3d[..., :3]),
                      axis=(2, 3))
        index = np.argmin(abs(cost), -1)
        pred_keypoints_3d_ = []
        pred_pose_ = []
        pred_betas_ = []
        pred_cam_ = []
        for batch_i in range(B):
            ind = index[batch_i]
            pred_keypoints_3d_.append(pred_keypoints_3d[-1, batch_i, ind])
            pred_pose_.append(pred_pose[-1, batch_i, ind])
            pred_betas_.append(pred_betas[-1, batch_i, ind])
            pred_cam_.append(pred_cam[-1, batch_i, ind])

        all_preds = {}
        all_preds['keypoints_3d'] = np.array(pred_keypoints_3d_)
        all_preds['smpl_pose'] = np.array(pred_pose_)
        all_preds['smpl_beta'] = np.array(pred_betas_)
        all_preds['camera'] = np.array(pred_cam_)
        # all_preds['vertices'] = pred_vertices.detach().cpu().numpy()
        image_path = []
        for img_meta in img_metas:
            image_path.append(img_meta['image_path'])
        all_preds['image_path'] = image_path
        all_preds['image_idx'] = kwargs['sample_idx']
        return all_preds

    # loss computation helpers
    def compute_keypoints3d_loss(
            self,
            pred_keypoints3d: torch.Tensor,
            gt_keypoints3d: torch.Tensor,
            has_keypoints3d: Optional[torch.Tensor] = None):
        """Compute loss for 3d keypoints."""
        keypoints3d_conf = gt_keypoints3d[:, :, 3].float().unsqueeze(-1)
        keypoints3d_conf = keypoints3d_conf.repeat(1, 1, 3)
        pred_keypoints3d = pred_keypoints3d.float()
        gt_keypoints3d = gt_keypoints3d[:, :, :3].float()
        # currently, only mpi_inf_3dhp and h36m have 3d keypoints
        # both datasets have right_hip_extra and left_hip_extra
        right_hip_idx = get_keypoint_idx('right_hip_extra', self.convention)
        left_hip_idx = get_keypoint_idx('left_hip_extra', self.convention)
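        # center both sets of keypoints on the pelvis (midpoint of the hips)
        # so that the loss is invariant to global translation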
        gt_pelvis = (gt_keypoints3d[:, right_hip_idx, :] +
                     gt_keypoints3d[:, left_hip_idx, :]) / 2
        pred_pelvis = (pred_keypoints3d[:, right_hip_idx, :] +
                       pred_keypoints3d[:, left_hip_idx, :]) / 2
        gt_keypoints3d = gt_keypoints3d - gt_pelvis[:, None, :]
        pred_keypoints3d = pred_keypoints3d - pred_pelvis[:, None, :]
        loss = self.loss_keypoints3d(pred_keypoints3d,
                                     gt_keypoints3d,
                                     reduction_override='none')
        # If has_keypoints3d is not None, the loss is computed on the
        # instances that have ground-truth keypoints3d (zero-confidence
        # keypoints are still included in the mean). Otherwise, only
        # keypoints with positive confidence contribute.
        # has_keypoints3d is None when the key 'has_keypoints3d'
        # is not in the dataset.
        if has_keypoints3d is None:
            valid_pos = keypoints3d_conf > 0
            if keypoints3d_conf[valid_pos].numel() == 0:
                return torch.Tensor([0]).type_as(gt_keypoints3d)
            loss = torch.sum(loss * keypoints3d_conf)
            loss /= keypoints3d_conf[valid_pos].numel()
        else:
            keypoints3d_conf = keypoints3d_conf[has_keypoints3d == 1]
            if keypoints3d_conf.shape[0] == 0:
                return torch.Tensor([0]).type_as(gt_keypoints3d)
            loss = loss[has_keypoints3d == 1]
            loss = (loss * keypoints3d_conf).mean()
        return loss

    def compute_keypoints2d_loss(
            self,
            pred_keypoints3d: torch.Tensor,
            pred_cam: torch.Tensor,
            gt_keypoints2d: torch.Tensor,
            img_res: Optional[int] = 512,
            focal_length: Optional[float] = 5000.,
            has_keypoints2d: Optional[torch.Tensor] = None):
        """Compute loss for 2d keypoints."""
        keypoints2d_conf = gt_keypoints2d[:, :, 2].float().unsqueeze(-1)
        keypoints2d_conf = keypoints2d_conf.repeat(1, 1, 2)
        gt_keypoints2d = gt_keypoints2d[:, :, :2].float()
        pred_keypoints2d = project_points(pred_keypoints3d,
                                          pred_cam,
                                          focal_length=focal_length,
                                          img_res=img_res)
        # Normalize keypoints to [-1, 1].
        # The coordinate origin of pred_keypoints_2d is
        # the center of the input image.
        pred_keypoints2d = 2 * pred_keypoints2d / (img_res - 1)
        # The coordinate origin of gt_keypoints_2d is
        # the top left corner of the input image.
        gt_keypoints2d = 2 * gt_keypoints2d / (img_res - 1) - 1
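        # e.g. with img_res=512, a GT pixel at 0 maps to -1 and a pixel at
        # 511 maps to +1, while a centered prediction at 0 stays at 0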
        loss = self.loss_keypoints2d(pred_keypoints2d,
                                     gt_keypoints2d,
                                     reduction_override='none')
        # If has_keypoints2d is not None, the loss is computed on the
        # instances that have ground-truth keypoints2d (zero-confidence
        # keypoints are still included in the mean). Otherwise, only
        # keypoints with positive confidence contribute.
        # has_keypoints2d is None when the key 'has_keypoints2d'
        # is not in the dataset.
        if has_keypoints2d is None:
            valid_pos = keypoints2d_conf > 0
            if keypoints2d_conf[valid_pos].numel() == 0:
                return torch.Tensor([0]).type_as(gt_keypoints2d)
            loss = torch.sum(loss * keypoints2d_conf)
            loss /= keypoints2d_conf[valid_pos].numel()
        else:
            keypoints2d_conf = keypoints2d_conf[has_keypoints2d == 1]
            if keypoints2d_conf.shape[0] == 0:
                return torch.Tensor([0]).type_as(gt_keypoints2d)
            loss = loss[has_keypoints2d == 1]
            loss = (loss * keypoints2d_conf).mean()
        return loss

    def compute_vertex_loss(self, pred_vertices: torch.Tensor,
                            gt_vertices: torch.Tensor,
                            has_smpl: torch.Tensor):
        """Compute loss for vertices."""
        gt_vertices = gt_vertices.float()
        conf = has_smpl.float().view(-1, 1, 1)
        conf = conf.repeat(1, gt_vertices.shape[1], gt_vertices.shape[2])
        loss = self.loss_vertex(pred_vertices,
                                gt_vertices,
                                reduction_override='none')
        valid_pos = conf > 0
        if conf[valid_pos].numel() == 0:
            return torch.Tensor([0]).type_as(gt_vertices)
        loss = torch.sum(loss * conf) / conf[valid_pos].numel()
        return loss

    def compute_smpl_pose_loss(self, pred_pose: torch.Tensor,
                               gt_pose: torch.Tensor, has_smpl: torch.Tensor):
        """Compute loss for smpl pose."""
        conf = has_smpl.float().view(-1)
        valid_pos = conf > 0
        if conf[valid_pos].numel() == 0:
            return torch.Tensor([0]).type_as(gt_pose)
        pred_pose = pred_pose[valid_pos]
        gt_pose = gt_pose[valid_pos]
        conf = conf[valid_pos]
        loss = self.loss_smpl_pose(pred_pose,
                                   gt_pose,
                                   reduction_override='none')
        loss = loss.view(loss.shape[0], -1).mean(-1)
        loss = torch.mean(loss * conf)
        return loss

    def compute_smpl_betas_loss(self, pred_betas: torch.Tensor,
                                gt_betas: torch.Tensor,
                                has_smpl: torch.Tensor):
        """Compute loss for smpl betas."""
        conf = has_smpl.float().view(-1)
        valid_pos = conf > 0
        if conf[valid_pos].numel() == 0:
            return torch.Tensor([0]).type_as(gt_betas)
        pred_betas = pred_betas[valid_pos]
        gt_betas = gt_betas[valid_pos]
        conf = conf[valid_pos]
        loss = self.loss_smpl_betas(pred_betas,
                                    gt_betas,
                                    reduction_override='none')
        loss = loss.view(loss.shape[0], -1).mean(-1)
        loss = torch.mean(loss * conf)
        return loss

    def compute_camera_loss(self, cameras: torch.Tensor):
        """Compute loss for predicted camera parameters."""
        loss = self.loss_camera(cameras)
        return loss
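
# Example construction (a minimal sketch; 'L1Loss'/'MSELoss' and the SMPL
# body-model config are assumed to be registered in the respective builders,
# following the usual OpenMMLab pattern — adjust types and weights to your
# registry):
#
#   loss_module = DETRLoss(
#       convention='smpl_54',
#       loss_keypoints2d=dict(type='L1Loss', loss_weight=5.0),
#       loss_keypoints3d=dict(type='L1Loss', loss_weight=5.0),
#       loss_vertex=dict(type='L1Loss', loss_weight=1.0),
#       loss_smpl_pose=dict(type='MSELoss', loss_weight=3.0),
#       loss_smpl_betas=dict(type='MSELoss', loss_weight=0.02),
#       body_model_train=dict(type='SMPL', keypoint_src='smpl_54',
#                             keypoint_dst='smpl_54'),
#   )
#   losses = loss_module(preds, targets)  # dict of per-layer loss tensors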