Spaces:
Running
on
L40S
Running
on
L40S
import copy | |
import os | |
import math | |
from typing import List | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
from torchvision.ops.boxes import nms | |
from torch import Tensor | |
from util import box_ops | |
from util.misc import (NestedTensor, nested_tensor_from_tensor_list, accuracy, | |
get_world_size, interpolate, | |
is_dist_avail_and_initialized, inverse_sigmoid) | |
from .utils import PoseProjector, sigmoid_focal_loss, MLP, OKSLoss | |
from typing import Optional, Union | |
from detrsmpl.core.conventions.keypoints_mapping import (get_keypoint_idx, | |
convert_kps) | |
from detrsmpl.utils.geometry import (batch_rodrigues, project_points_new) | |
from config.config import cfg | |
from util.human_models import smpl_x | |
from detrsmpl.utils.transforms import rotmat_to_aa | |
class SetCriterion(nn.Module): | |
def __init__(self,
             num_classes,
             matcher,
             weight_dict,
             focal_alpha,
             losses,
             num_box_decoder_layers=2,
             num_hand_face_decoder_layers=4,
             num_body_points=17,
             num_hand_points=6,
             num_face_points=6,
             smpl_loss_config=None,
             convention='smplx_137'):
    """Store the criterion configuration and build the per-part OKS losses.

    Args:
        num_classes: Stored as ``self.num_classes``.
        matcher: Stored as ``self.matcher``.
        weight_dict: Stored as ``self.weight_dict``.
        focal_alpha: Stored as ``self.focal_alpha``.
        losses: Stored as ``self.losses``.
        num_box_decoder_layers: Stored as ``self.num_box_decoder_layers``.
        num_hand_face_decoder_layers: Stored as
            ``self.num_hand_face_decoder_layers``.
        num_body_points: Number of body keypoints; also sizes ``body_oks``.
        num_hand_points: Number of keypoints per hand; also sizes
            ``hand_oks``.
        num_face_points: Number of face keypoints; also sizes ``face_oks``.
        smpl_loss_config: Accepted but not read by this constructor.
        convention: Keypoint convention name, stored as ``self.convention``.
    """
    super().__init__()

    # Matching / classification configuration.
    self.matcher = matcher
    self.num_classes = num_classes
    self.focal_alpha = focal_alpha
    self.losses = losses
    self.weight_dict = weight_dict

    # Fixed constants kept from the original implementation.
    # NOTE(review): their meaning is not evident from this constructor.
    self.vis = 0.1
    self.abs = 1

    # Decoder layout and keypoint bookkeeping.
    self.num_box_decoder_layers = num_box_decoder_layers
    self.num_hand_face_decoder_layers = num_hand_face_decoder_layers
    self.num_body_points = num_body_points
    self.num_hand_points = num_hand_points
    self.num_face_points = num_face_points
    self.convention = convention

    def build_oks(num_keypoints):
        # All three OKS losses share every setting except the keypoint count.
        return OKSLoss(linear=True,
                       num_keypoints=num_keypoints,
                       eps=1e-6,
                       reduction='mean',
                       loss_weight=1.0)

    # Registered in body -> hand -> face order, as before.
    self.body_oks = build_oks(num_body_points)
    self.hand_oks = build_oks(num_hand_points)
    self.face_oks = build_oks(num_face_points)
def loss_labels(self,
                outputs,
                targets,
                indices,
                idx,
                num_boxes,
                data_batch,
                log=True):
    """Classification loss (binary focal loss).

    Args:
        outputs: Dict holding ``'pred_logits'``, indexed below as
            ``[batch, query, class]``.
        targets: Per-image dicts; each must provide a ``'labels'`` tensor.
        indices: Sequence whose first element is the per-image list of
            ``(query_indices, target_indices)`` pairs.
        idx: Index selecting the matched ``(batch, query)`` positions.
        num_boxes: Normaliser forwarded to ``sigmoid_focal_loss``.
        data_batch: Unused here.
        log: When true, also report ``'class_error'``.

    Returns:
        Dict with ``'loss_ce'`` and, if ``log`` is set, ``'class_error'``.
    """
    matches = indices[0]
    assert 'pred_logits' in outputs
    logits = outputs['pred_logits']

    # Ground-truth labels of the matched targets, concatenated over images.
    matched_labels = torch.cat([
        per_image['labels'][tgt_ids]
        for per_image, (_, tgt_ids) in zip(targets, matches)
    ])

    # Every query starts out as "no object" (index == num_classes); matched
    # queries are then overwritten with their ground-truth label.
    query_labels = torch.full(logits.shape[:2],
                              self.num_classes,
                              dtype=torch.int64,
                              device=logits.device)
    query_labels[idx] = matched_labels

    # One-hot with one extra trailing slot that absorbs "no object"; the
    # slot is dropped afterwards so unmatched queries become all-zero rows.
    onehot_shape = [logits.shape[0], logits.shape[1], logits.shape[2] + 1]
    onehot = torch.zeros(onehot_shape,
                         dtype=logits.dtype,
                         layout=logits.layout,
                         device=logits.device)
    onehot.scatter_(2, query_labels.unsqueeze(-1), 1)
    onehot = onehot[:, :, :-1]

    focal = sigmoid_focal_loss(logits,
                               onehot,
                               num_boxes,
                               alpha=self.focal_alpha,
                               gamma=2)
    losses = {'loss_ce': focal * logits.shape[1]}

    if log:
        # TODO this should probably be a separate loss, not hacked in this one here
        top1 = accuracy(logits[idx], matched_labels)[0]
        losses['class_error'] = 100 - top1
    return losses
def loss_cardinality(self, outputs, targets, indices, num_boxes,
                     data_batch):
    """Compute the cardinality error.

    This is the absolute error between the number of predicted non-empty
    boxes and the number of ground-truth boxes, averaged over the batch.
    It is intended for logging rather than as a training signal.

    Args:
        outputs: Dict holding ``'pred_logits'``, indexed below as
            ``[batch, query, class]``.
        targets: Per-image dicts; each must provide a ``'labels'`` tensor.
        indices: Unused here.
        num_boxes: Unused here.
        data_batch: Unused here.

    Returns:
        Dict with the single key ``'cardinality_error'``.
    """
    pred_logits = outputs['pred_logits']
    device = pred_logits.device
    tgt_lengths = torch.as_tensor([len(v['labels']) for v in targets],
                                  device=device)

    # No ground truth anywhere in the batch: report a zero that is still
    # derived from the predictions. The reduction with .all() is required
    # because tgt_lengths holds one entry per image, and a multi-element
    # tensor cannot be used directly as a Python truth value.
    if (tgt_lengths == 0).all():
        return {'cardinality_error': pred_logits.sum() * 0}

    # Count the number of predictions that are NOT "no-object" (which is
    # the last class).
    card_pred = (pred_logits.argmax(-1) !=
                 pred_logits.shape[-1] - 1).sum(1)
    card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
    return {'cardinality_error': card_err}
def loss_keypoints(self, outputs, targets, indices,
                   idx, num_boxes, data_batch,
                   face_hand_kpt=False):
    """Compute the 2D keypoint L1 and OKS losses for body, hands and face.

    Each keypoint tensor stores ``2 * K`` coordinates followed by the
    per-keypoint visibility values (the "xyxyvv" layout). For every part the
    losses are summed over matched instances and divided by the number of
    instances that have both a valid part box and at least one visible
    keypoint.

    Args:
        outputs: Dict holding ``'pred_logits'``, ``'pred_keypoints'`` and,
            optionally, ``'pred_lhand_keypoints'``,
            ``'pred_rhand_keypoints'`` and ``'pred_face_keypoints'``.
            ``'pred_smpl_cam'`` is read only when nothing is matched.
        targets: Per-image dicts with ``'keypoints'``, ``'area'`` and the
            matching ``'<part>_keypoints'`` entries.
        indices: Sequence whose first element is the per-image list of
            ``(query_indices, target_indices)`` pairs.
        idx: Index selecting the matched ``(batch, query)`` positions.
        num_boxes: Unused here.
        data_batch: Dict with per-image ``'<part>_bbox_valid'`` tensors.
        face_hand_kpt: When true, also compute the hand and face losses for
            whichever part predictions are present in ``outputs``.

    Returns:
        Dict with ``'loss_keypoints'`` and ``'loss_oks'`` plus, per enabled
        part, ``'loss_<part>_keypoints'`` and ``'loss_<part>_oks'``.
    """
    matches = indices[0]
    device = outputs['pred_logits'].device
    zero = torch.as_tensor(0., device=device)
    losses = {}

    def part_losses(src_keypoints, target_key, valid_key, num_points,
                    oks_fn):
        """Return ``(l1, oks)`` for one part; zeros if nothing is valid."""
        if len(src_keypoints) == 0:
            # Nothing matched: zero-valued losses derived from predictions.
            return src_keypoints.sum() * zero, src_keypoints.sum() * zero

        coords_pred = src_keypoints[:, :num_points * 2]
        part_targets = torch.cat(
            [t[target_key][i] for t, (_, i) in zip(targets, matches)],
            dim=0)
        areas = torch.cat(
            [t['area'][i] for t, (_, i) in zip(targets, matches)], dim=0)
        box_conf = torch.cat(
            [t[i] for t, (_, i) in zip(data_batch[valid_key], matches)],
            dim=0)

        coords_gt = part_targets[:, :num_points * 2]
        vis_gt = part_targets[:, num_points * 2:]
        # An instance counts only if it has at least one visible keypoint
        # AND its part box is flagged valid.
        kps_conf = vis_gt.sum(-1) > 0
        num_valid = (kps_conf * box_conf).sum()

        oks = oks_fn(coords_pred,
                     coords_gt,
                     vis_gt,
                     areas,
                     weight=None,
                     avg_factor=None,
                     reduction_override=None)
        oks = oks * kps_conf * box_conf

        l1 = F.l1_loss(coords_pred, coords_gt, reduction='none')
        # Each visibility value gates both the x and the y coordinate.
        l1 = l1 * vis_gt.repeat_interleave(2, dim=1)
        l1 = l1.sum(-1) * box_conf

        if num_valid > 0:
            return l1.sum() / num_valid, oks.sum() / num_valid
        return src_keypoints.sum() * zero, src_keypoints.sum() * zero

    ############################################################
    # body
    ############################################################
    src_body_keypoints = outputs['pred_keypoints'][idx]  # xyxyvv
    body_l1, body_oks = part_losses(src_body_keypoints, 'keypoints',
                                    'body_bbox_valid',
                                    self.num_body_points, self.body_oks)
    if len(src_body_keypoints) == 0:
        # Same as the original empty-match case: the zero loss also
        # includes a zero-weighted term from the camera prediction.
        body_l1 = body_l1 + outputs['pred_smpl_cam'][idx].float().sum() * 0
    losses['loss_keypoints'] = body_l1
    losses['loss_oks'] = body_oks

    ############################################################
    # lhand / rhand / face
    ############################################################
    if face_hand_kpt:
        part_specs = (
            ('lhand', self.num_hand_points, self.hand_oks),
            ('rhand', self.num_hand_points, self.hand_oks),
            ('face', self.num_face_points, self.face_oks),
        )
        for part, num_points, oks_fn in part_specs:
            pred_key = 'pred_' + part + '_keypoints'
            if pred_key not in outputs:
                continue
            part_l1, part_oks = part_losses(outputs[pred_key][idx],
                                            part + '_keypoints',
                                            part + '_bbox_valid',
                                            num_points, oks_fn)
            losses['loss_' + part + '_keypoints'] = part_l1
            losses['loss_' + part + '_oks'] = part_oks
    return losses
def loss_smpl_pose(self, outputs, targets, indices, idx, num_boxes,
                   data_batch, face_hand_kpt=False):
    """L1 loss on SMPL-X joint rotations, reported per body part.

    Ground-truth poses are converted with ``batch_rodrigues`` and viewed as
    ``(-1, 53, 3, 3)``; predictions are concatenated along dim 1 in the
    order body (22 joints incl. root), left hand (15), right hand (15), jaw.
    # NOTE(review): predictions are assumed to already be 3x3 rotation
    # matrices per joint so that they line up with the targets - confirm.

    Each part loss is the confidence-weighted L1 sum divided by the number
    of matched instances that have at least one valid joint in that part.

    Args:
        outputs: Dict holding ``'pred_logits'``, ``'pred_smpl_pose'``,
            ``'pred_smpl_lhand_pose'``, ``'pred_smpl_rhand_pose'`` and
            ``'pred_smpl_jaw_pose'``.
        targets: Unused here.
        indices: Sequence whose first element is the per-image list of
            ``(query_indices, target_indices)`` pairs.
        idx: Index selecting the matched ``(batch, query)`` positions.
        num_boxes: Unused here.
        data_batch: Dict with per-image ``'smplx_pose'`` and
            ``'smplx_pose_valid'`` tensors.
        face_hand_kpt: When false, the hand and jaw losses are multiplied
            by zero (their keys are still returned).

    Returns:
        Dict with ``'loss_smpl_pose_root'``, ``'loss_smpl_pose_body'``,
        ``'loss_smpl_pose_lhand'``, ``'loss_smpl_pose_rhand'`` and
        ``'loss_smpl_pose_jaw'``.
    """
    device = outputs['pred_logits'].device
    matches = indices[0]

    pred_smplx_pose = torch.cat((
        outputs['pred_smpl_pose'][idx],        # 22
        outputs['pred_smpl_lhand_pose'][idx],  # 15
        outputs['pred_smpl_rhand_pose'][idx],  # 15
        outputs['pred_smpl_jaw_pose'][idx],
    ), dim=1)

    gt_pose = torch.cat(
        [t[i] for t, (_, i) in zip(data_batch['smplx_pose'], matches)],
        dim=0)
    gt_pose = batch_rodrigues(gt_pose.view(-1, 3)).view(-1, 53, 3, 3)
    conf = torch.cat(
        [t[i] for t, (_, i) in zip(data_batch['smplx_pose_valid'], matches)],
        dim=0)

    # Per-instance flags: does the instance have any valid joint in a part?
    body_pose_valid = conf[:, :22].sum(-1) > 0
    lhand_pose_valid = conf[:, 22:37].sum(-1) > 0
    rhand_pose_valid = conf[:, 37:52].sum(-1) > 0
    # The jaw is a single column; compare it directly so the flag stays
    # per-instance. Reducing it with .sum(-1) would collapse the whole
    # batch into one scalar and the loss would no longer be averaged over
    # the valid instances like the other parts.
    face_pose_valid = conf[:, 52] > 0

    per_joint = F.l1_loss(pred_smplx_pose, gt_pose, reduction='none')
    per_joint = per_joint.sum([-1, -2]) * conf

    def averaged(joint_sel, valid, enabled=True):
        # Sum over the selected joints, then average over valid instances.
        total = per_joint[:, joint_sel].sum()
        if not enabled:
            total = torch.as_tensor(0., device=device) * total
        return total / (valid.sum() + 1e-6)

    return {
        'loss_smpl_pose_root': averaged(0, body_pose_valid),
        'loss_smpl_pose_body': averaged(slice(1, 22), body_pose_valid),
        'loss_smpl_pose_lhand': averaged(slice(22, 37), lhand_pose_valid,
                                         face_hand_kpt),
        'loss_smpl_pose_rhand': averaged(slice(37, 52), rhand_pose_valid,
                                         face_hand_kpt),
        'loss_smpl_pose_jaw': averaged(52, face_pose_valid, face_hand_kpt),
    }
def loss_smpl_beta(self, outputs, targets, indices, idx, num_boxes,
                   data_batch, face_hand_kpt=False):
    """L1 loss on the SMPL-X shape coefficients of the matched queries.

    Args:
        outputs: Dict holding ``'pred_smpl_beta'``.
        targets: Unused here.
        indices: Sequence whose first element is the per-image list of
            ``(query_indices, target_indices)`` pairs.
        idx: Index selecting the matched ``(batch, query)`` positions.
        num_boxes: Unused here.
        data_batch: Dict with per-image ``'smplx_shape'`` and
            ``'smplx_shape_valid'`` tensors.
        face_hand_kpt: Unused here.

    Returns:
        Dict with the single key ``'loss_smpl_beta'``.
    """
    matches = indices[0]

    def gather(key):
        # Pick the matched ground-truth rows of every image and stack them.
        return torch.cat(
            [per_image[tgt_ids]
             for per_image, (_, tgt_ids) in zip(data_batch[key], matches)],
            dim=0)

    betas_pred = outputs['pred_smpl_beta'][idx]
    betas_gt = gather('smplx_shape')
    valid = gather('smplx_shape_valid')

    # No valid shape annotation: zero loss derived from the predictions.
    if valid.sum() == 0:
        return {'loss_smpl_beta': betas_pred.sum() * 0}

    per_instance = F.l1_loss(betas_pred, betas_gt, reduction='none').sum(-1)
    weighted = per_instance * valid
    return {'loss_smpl_beta': weighted.sum() / (valid.sum() + 1e-6)}
def loss_smpl_expr(self, outputs, targets, indices, idx, num_boxes,
                   data_batch, face_hand_kpt=False):
    """L1 loss on the SMPL-X expression coefficients of matched queries.

    Args:
        outputs: Dict holding ``'pred_logits'`` and ``'pred_smpl_expr'``.
        targets: Unused here.
        indices: Sequence whose first element is the per-image list of
            ``(query_indices, target_indices)`` pairs.
        idx: Index selecting the matched ``(batch, query)`` positions.
        num_boxes: Unused here.
        data_batch: Dict with per-image ``'smplx_expr'`` and
            ``'smplx_expr_valid'`` tensors.
        face_hand_kpt: When false, the loss is multiplied by zero.

    Returns:
        Dict with the single key ``'loss_smpl_expr'``.
    """
    matches = indices[0]
    device = outputs['pred_logits'].device

    def gather(key):
        # Pick the matched ground-truth rows of every image and stack them.
        return torch.cat(
            [per_image[tgt_ids]
             for per_image, (_, tgt_ids) in zip(data_batch[key], matches)],
            dim=0)

    expr_pred = outputs['pred_smpl_expr'][idx]
    expr_gt = gather('smplx_expr')
    valid = gather('smplx_expr_valid')

    # No valid expression annotation: zero loss derived from predictions.
    if valid.sum() == 0:
        return {
            'loss_smpl_expr':
            expr_pred.sum() * torch.as_tensor(0., device=device)
        }

    per_instance = F.l1_loss(expr_pred, expr_gt, reduction='none').sum(-1)
    total = (per_instance * valid).sum()
    if not face_hand_kpt:
        # Face supervision disabled: keep the key, zero the value.
        total = torch.as_tensor(0., device=device) * total
    return {'loss_smpl_expr': total / (valid.sum() + 1e-6)}
def loss_smpl_kp3d(self,
                   outputs,
                   targets,
                   indices,
                   idx,
                   num_boxes,
                   data_batch,
                   has_keypoints3d=None,
                   face_hand_kpt=False):
    """Pelvis-relative L1 loss on 3D keypoints, reported per body part.

    Targets come from ``data_batch['joint_cam']``: the first three channels
    are used as coordinates and the remaining channel(s) as confidence,
    which is further multiplied by the per-image ``'is_3D'`` flag.

    Args:
        outputs: Dict holding ``'pred_logits'`` and ``'pred_smpl_kp3d'``.
        targets: Unused here.
        indices: Sequence whose first element is the per-image list of
            ``(query_indices, target_indices)`` pairs.
        idx: Index selecting the matched ``(batch, query)`` positions.
        num_boxes: Unused here.
        data_batch: Dict with per-image ``'joint_cam'`` and ``'is_3D'``.
        has_keypoints3d: Unused here.
        face_hand_kpt: When false, the hand and face losses are multiplied
            by zero (their keys are still returned).

    Returns:
        Dict with ``'loss_smpl_body_kp3d'``, ``'loss_smpl_lhand_kp3d'``,
        ``'loss_smpl_rhand_kp3d'`` and ``'loss_smpl_face_kp3d'``.
    """
    device = outputs['pred_logits'].device
    matches = indices[0]

    joints_pred = outputs['pred_smpl_kp3d'][idx].float()
    gt = torch.cat(
        [t[i] for t, (_, i) in zip(data_batch['joint_cam'], matches)],
        dim=0)
    gt_conf = gt[:, :, 3:].clone()
    joints_gt = gt[:, :, :3]

    # Broadcast each image's is_3D flag to all of its matched instances.
    is_3d = torch.cat([
        flag[None, None].repeat(len(i), 1, 1)
        for flag, (_, i) in zip(data_batch['is_3D'], matches)
    ], dim=0)
    gt_conf = gt_conf * is_3d

    # Express both skeletons relative to their own pelvis joint.
    pelvis_idx = get_keypoint_idx('pelvis', self.convention)
    joints_gt = joints_gt - joints_gt[..., pelvis_idx, :][:, None, :]
    joints_pred = joints_pred - joints_pred[..., pelvis_idx, :][:, None, :]

    per_coord = F.l1_loss(joints_pred, joints_gt, reduction='none')
    per_coord = per_coord * gt_conf

    part_keys = (
        ('body', 'loss_smpl_body_kp3d'),
        ('lhand', 'loss_smpl_lhand_kp3d'),
        ('rhand', 'loss_smpl_rhand_kp3d'),
        ('face', 'loss_smpl_face_kp3d'),
    )
    losses = {}
    for part, key in part_keys:
        joint_ids = smpl_x.joint_part[part]
        # Instances with at least one confident joint in this part.
        num_valid = (gt_conf[:, joint_ids].sum([-1, -2]) > 0).sum()
        part_sum = torch.sum(per_coord[:, joint_ids, :])
        if part != 'body' and not face_hand_kpt:
            part_sum = torch.as_tensor(0., device=device) * part_sum
        losses[key] = part_sum / (num_valid + 1e-6)
    return losses
def loss_smpl_kp3d_ra(self,
                      outputs,
                      targets,
                      indices,
                      idx,
                      num_boxes,
                      data_batch,
                      has_keypoints3d=None,
                      face_hand_kpt=False):
    """L1 loss on 3D keypoints with per-part root alignment of predictions.

    Targets come from ``data_batch['smplx_joint_cam']``. Both predictions
    and targets are made pelvis-relative; predicted face joints are then
    additionally expressed relative to the predicted neck, and predicted
    hand joints relative to the predicted wrists, before being compared
    with the targets.
    # NOTE(review): the targets get no neck/wrist subtraction here, so they
    # are presumably stored part-root-relative already - confirm upstream.

    Args:
        outputs: Dict holding ``'pred_smpl_kp3d'``.
        targets: Unused here.
        indices: Sequence whose first element is the per-image list of
            ``(query_indices, target_indices)`` pairs.
        idx: Index selecting the matched ``(batch, query)`` positions.
        num_boxes: Unused here.
        data_batch: Dict with per-image ``'smplx_joint_cam'`` and
            ``'is_3D'``.
        has_keypoints3d: Unused here.
        face_hand_kpt: When false, the face and hand losses are multiplied
            by zero (their keys are still returned).

    Returns:
        Dict with ``'loss_smpl_body_kp3d_ra'``, ``'loss_smpl_face_kp3d_ra'``,
        ``'loss_smpl_lhand_kp3d_ra'`` and ``'loss_smpl_rhand_kp3d_ra'``.
    """
    matches = indices[0]

    joints_pred = outputs['pred_smpl_kp3d'][idx].float()
    gt = torch.cat(
        [t[i] for t, (_, i) in zip(data_batch['smplx_joint_cam'], matches)],
        dim=0)
    gt_conf = gt[:, :, 3:].clone()
    joints_gt = gt[:, :, :3]

    # Broadcast each image's is_3D flag to all of its matched instances,
    # then tile the confidence over the three coordinates.
    is_3d = torch.cat([
        flag[None, None].repeat(len(i), 1, 1)
        for flag, (_, i) in zip(data_batch['is_3D'], matches)
    ], dim=0)
    gt_conf = (gt_conf * is_3d).repeat(1, 1, 3)

    # Express both skeletons relative to their own pelvis joint.
    pelvis_idx = get_keypoint_idx('pelvis', self.convention)
    joints_gt = joints_gt - joints_gt[..., pelvis_idx, :][:, None, :]
    joints_pred = joints_pred - joints_pred[..., pelvis_idx, :][:, None, :]

    def averaged_l1(pred_part, joint_ids):
        # Confidence-weighted L1 sum over the part, divided by the number
        # of instances with at least one confident joint in that part.
        per_coord = F.l1_loss(pred_part,
                              joints_gt[:, joint_ids, :],
                              reduction='none')
        total = torch.sum(per_coord * gt_conf[:, joint_ids, :])
        num_valid = (gt_conf[:, joint_ids].sum([-1, -2]) > 0).sum()
        return total / (num_valid + 1e-6)

    body_ids = smpl_x.joint_part['body']
    losses = {
        'loss_smpl_body_kp3d_ra':
        averaged_l1(joints_pred[:, body_ids, :], body_ids)
    }

    # calculate face and hand losses separately, each against its own root
    rooted_parts = (
        ('face', smpl_x.neck_idx, 'loss_smpl_face_kp3d_ra'),
        ('lhand', smpl_x.lwrist_idx, 'loss_smpl_lhand_kp3d_ra'),
        ('rhand', smpl_x.rwrist_idx, 'loss_smpl_rhand_kp3d_ra'),
    )
    for part, root_idx, key in rooted_parts:
        joint_ids = smpl_x.joint_part[part]
        part_pred = joints_pred[:, joint_ids, :]
        root_pred = joints_pred[:, root_idx, None, :]
        value = averaged_l1(part_pred - root_pred, joint_ids)
        losses[key] = value if face_hand_kpt else 0 * value
    return losses
def loss_smpl_kp2d(self,
                   outputs,
                   targets,
                   indices,
                   idx,
                   num_boxes,
                   data_batch,
                   focal_length=5000.,
                   has_keypoints2d=None,
                   face_hand_kpt=False):
    """Compute the L1 loss between projected 3D joints and 2D keypoints.

    Target coordinates are divided by ``cfg.output_hm_shape[2]`` (x) and
    ``cfg.output_hm_shape[1]`` (y); projected predictions are divided by the
    per-instance image width/height taken from ``data_batch['img_shape']``.

    Args:
        outputs: Dict holding ``'pred_smpl_kp3d'`` and ``'pred_smpl_cam'``.
        targets: Unused here.
        indices: Sequence whose first element is the per-image list of
            ``(query_indices, target_indices)`` pairs.
        idx: Index selecting the matched ``(batch, query)`` positions;
            ``idx[0]`` is iterated as the image index of every match.
        num_boxes: Unused here.
        data_batch: Dict with per-image ``'joint_img'`` and ``'img_shape'``.
        focal_length: Forwarded to ``project_points_new``.
        has_keypoints2d: Unused here.
        face_hand_kpt: When false, the hand and face losses are multiplied
            by zero (their keys are still returned).

    Returns:
        Dict with ``'loss_smpl_body_kp2d'``, ``'loss_smpl_lhand_kp2d'``,
        ``'loss_smpl_rhand_kp2d'`` and ``'loss_smpl_face_kp2d'``.
    """
    matches = indices[0]

    joints_pred = outputs['pred_smpl_kp3d'][idx].float()
    cam_pred = outputs['pred_smpl_cam'][idx].float()

    gt = torch.cat(
        [t[i] for t, (_, i) in zip(data_batch['joint_img'], matches)],
        dim=0)
    gt_conf = gt[:, :, 2:].clone()
    gt_xy = gt[:, :, :2].float()
    gt_xy[:, :, 0] /= cfg.output_hm_shape[2]
    gt_xy[:, :, 1] /= cfg.output_hm_shape[1]

    # img_shape entries are flipped on the last axis so the result is
    # ordered (w, h), matching the (x, y) order of the projections.
    img_wh = torch.cat(
        [data_batch['img_shape'][i][None] for i in idx[0]],
        dim=0).flip(-1)
    xy_pred = project_points_new(points_3d=joints_pred,
                                 pred_cam=cam_pred,
                                 focal_length=focal_length,
                                 camera_center=img_wh / 2)
    xy_pred = xy_pred / img_wh[:, None]

    per_coord = F.l1_loss(xy_pred, gt_xy, reduction='none') * gt_conf

    part_keys = (
        ('body', 'loss_smpl_body_kp2d'),
        ('lhand', 'loss_smpl_lhand_kp2d'),
        ('rhand', 'loss_smpl_rhand_kp2d'),
        ('face', 'loss_smpl_face_kp2d'),
    )
    losses = {}
    for part, key in part_keys:
        joint_ids = smpl_x.joint_part[part]
        # Instances with at least one confident keypoint in this part.
        num_valid = (gt_conf[:, joint_ids].sum([-1, -2]) > 0).sum()
        part_sum = torch.sum(per_coord[:, joint_ids, :])
        if part != 'body' and not face_hand_kpt:
            part_sum = 0 * part_sum
        losses[key] = part_sum / (num_valid + 1e-6)
    return losses
def loss_smpl_kp2d_ba(self,
                      outputs,
                      targets,
                      indices,
                      idx,
                      num_boxes,
                      data_batch,
                      focal_length=5000.,
                      has_keypoints2d=None,
                      face_hand_kpt=False):
    """Compute the 2D keypoint loss with hands/face expressed in box space.

    Projected predictions and target keypoints of the left hand, right hand
    and face are re-expressed relative to the corresponding ground-truth
    part box, and the predictions are shifted by the mean offset to the
    in-box targets (global translation alignment) before the L1 loss.
    # NOTE(review): "ba" presumably stands for bbox-aligned - confirm.

    Args:
        outputs: Dict holding ``'pred_logits'``, ``'pred_smpl_kp3d'`` and
            ``'pred_smpl_cam'``.
        targets: Per-image dicts with ``'lhand_boxes'``, ``'rhand_boxes'``
            and ``'face_boxes'`` in cxcywh format.
        indices: Sequence whose first element is the per-image list of
            ``(query_indices, target_indices)`` pairs.
        idx: Index selecting the matched ``(batch, query)`` positions;
            ``idx[0]`` is iterated as the image index of every match.
        num_boxes: Normaliser of the body loss.
        data_batch: Dict with per-image ``'joint_img'``, ``'img_shape'``
            and ``'<part>_bbox_valid'`` tensors.
        focal_length: Forwarded to ``project_points_new``.
        has_keypoints2d: Unused here.
        face_hand_kpt: When false, the hand and face losses are multiplied
            by zero (their keys are still returned).

    Returns:
        Dict with ``'loss_smpl_body_kp2d_ba'``, ``'loss_smpl_lhand_kp2d_ba'``,
        ``'loss_smpl_rhand_kp2d_ba'`` and ``'loss_smpl_face_kp2d_ba'``.
    """
    device = outputs['pred_logits'].device
    indices = indices[0]
    pred_smpl_kp3d = outputs['pred_smpl_kp3d'][idx].float()
    pred_cam = outputs['pred_smpl_cam'][idx].float()

    # Count matches over EVERY image. Bail out before anything is
    # concatenated over idx[0], because torch.cat rejects an empty list.
    valid_num = sum(len(src) for src, _ in indices)
    if valid_num == 0:
        zero = (torch.as_tensor(0., device=device) +
                pred_smpl_kp3d.sum() * 0 + pred_cam.sum() * 0)
        return {
            'loss_smpl_body_kp2d_ba': zero,
            'loss_smpl_lhand_kp2d_ba': zero,
            'loss_smpl_rhand_kp2d_ba': zero,
            'loss_smpl_face_kp2d_ba': zero,
        }

    targets_kp2d = torch.cat(
        [t[i] for t, (_, i) in zip(data_batch['joint_img'], indices)],
        dim=0)
    keypoints2d_conf = targets_kp2d[:, :, 2:].clone().repeat(1, 1, 2)
    targets_kp2d = targets_kp2d[:, :, :2].float()
    targets_kp2d[:, :, 0] = targets_kp2d[:, :, 0] / cfg.output_hm_shape[2]
    targets_kp2d[:, :, 1] = targets_kp2d[:, :, 1] / cfg.output_hm_shape[1]

    # img_shape entries are flipped on the last axis so the result is
    # ordered (w, h), matching the (x, y) order of the projections.
    img_wh = torch.cat(
        [data_batch['img_shape'][i][None] for i in idx[0]],
        dim=0).flip(-1)
    pred_smpl_kp2d = project_points_new(points_3d=pred_smpl_kp3d,
                                        pred_cam=pred_cam,
                                        focal_length=focal_length,
                                        camera_center=img_wh / 2)
    pred_smpl_kp2d = pred_smpl_kp2d / img_wh[:, None]

    def gather_part_boxes(part_name):
        # Validity flags and ground-truth boxes (xyxy, scaled by image
        # width/height) of the matched targets for one part.
        valid = torch.cat(
            [t[i] for t, (_, i) in
             zip(data_batch[part_name + '_bbox_valid'], indices)],
            dim=0)
        boxes = torch.cat(
            [t[part_name + '_boxes'][i]
             for t, (_, i) in zip(targets, indices)],
            dim=0)
        boxes = (box_ops.box_cxcywh_to_xyxy(boxes).reshape(-1, 2, 2) *
                 img_wh[:, None]).reshape(-1, 4)
        return valid, boxes

    rhand_bbox_valid, rhand_bbox_gt = gather_part_boxes('rhand')
    lhand_bbox_valid, lhand_bbox_gt = gather_part_boxes('lhand')
    face_bbox_valid, face_bbox_gt = gather_part_boxes('face')
    num_rhand_bbox = rhand_bbox_valid.sum()
    num_lhand_bbox = lhand_bbox_valid.sum()
    num_face_bbox = face_bbox_valid.sum()

    img_shape = torch.cat(
        [t[None].repeat(len(i), 1)
         for t, (_, i) in zip(data_batch['img_shape'], indices)],
        dim=0)

    if not (lhand_bbox_valid + rhand_bbox_valid + face_bbox_valid == 0).all():
        for part_name, bbox in (('lhand', lhand_bbox_gt),
                                ('rhand', rhand_bbox_gt),
                                ('face', face_bbox_gt)):
            part_idx = smpl_x.joint_part[part_name]
            box_w = bbox[:, None, 2] - bbox[:, None, 0] + 1e-6
            box_h = bbox[:, None, 3] - bbox[:, None, 1] + 1e-6

            x = targets_kp2d[:, part_idx, 0]
            y = targets_kp2d[:, part_idx, 1]
            trunc = keypoints2d_conf[:, part_idx, 0].clone()

            # bbox: xyxy, img_shape: hw. Move the targets into the 0-1
            # coordinate frame of the part box.
            x -= (bbox[:, None, 0] / img_shape[:, 1:])
            x *= (img_shape[:, 1:] / box_w)
            y -= (bbox[:, None, 1] / img_shape[:, :1])
            y *= (img_shape[:, :1] / box_h)

            # Only keypoints that fall inside the box take part in the
            # translation alignment below.
            trunc *= ((x >= 0) * (x <= 1) * (y >= 0) * (y <= 1))

            coord = torch.stack((x, y), 2)
            targets_kp2d = torch.cat(
                (targets_kp2d[:, :part_idx[0], :], coord,
                 targets_kp2d[:, part_idx[-1] + 1:, :]), 1)

            # Same transform for the predictions.
            x_pred = pred_smpl_kp2d[:, part_idx, 0]
            y_pred = pred_smpl_kp2d[:, part_idx, 1]
            x_pred -= (bbox[:, None, 0] / img_shape[:, 1:])
            x_pred *= (img_shape[:, 1:] / box_w)
            y_pred -= (bbox[:, None, 1] / img_shape[:, :1])
            y_pred *= (img_shape[:, :1] / box_h)
            coord_pred = torch.stack((x_pred, y_pred), 2)

            part_targets = targets_kp2d[:, part_idx, :]
            trans = []
            for bid in range(coord_pred.shape[0]):
                mask = trunc[bid] == 1
                if torch.sum(mask) == 0:
                    # Allocate on the predictions' device instead of
                    # forcing CUDA, so the loss also runs on CPU.
                    trans.append(
                        torch.zeros(2,
                                    dtype=coord_pred.dtype,
                                    device=coord_pred.device))
                else:
                    trans.append((-coord_pred[bid, mask, :2] +
                                  part_targets[bid, mask, :2]).mean(0))
            trans = torch.stack(trans)[:, None, :]
            coord_pred = coord_pred + trans  # global translation alignment

            pred_smpl_kp2d = torch.cat(
                (pred_smpl_kp2d[:, :part_idx[0], :], coord_pred,
                 pred_smpl_kp2d[:, part_idx[-1] + 1:, :]), 1)

    loss_smpl_kp2d_ba = F.l1_loss(pred_smpl_kp2d,
                                  targets_kp2d[:, :, :2],
                                  reduction='none')

    valid_pos = keypoints2d_conf > 0
    if keypoints2d_conf[valid_pos].numel() == 0:
        return {
            'loss_smpl_body_kp2d_ba': loss_smpl_kp2d_ba.sum() * 0,
            'loss_smpl_lhand_kp2d_ba': loss_smpl_kp2d_ba.sum() * 0,
            'loss_smpl_rhand_kp2d_ba': loss_smpl_kp2d_ba.sum() * 0,
            'loss_smpl_face_kp2d_ba': loss_smpl_kp2d_ba.sum() * 0,
        }

    # NOTE: the original author flagged this normalisation as "needs
    # changing"; it is kept as it was.
    loss_smpl_kp2d_ba = loss_smpl_kp2d_ba * keypoints2d_conf

    losses = {}
    losses['loss_smpl_body_kp2d_ba'] = torch.sum(
        loss_smpl_kp2d_ba[:, smpl_x.joint_part['body'], :]) / num_boxes

    for part_name, num_part_bbox in (('lhand', num_lhand_bbox),
                                     ('rhand', num_rhand_bbox),
                                     ('face', num_face_bbox)):
        key = f'loss_smpl_{part_name}_kp2d_ba'
        if num_part_bbox > 0:
            part_sum = torch.sum(
                loss_smpl_kp2d_ba[:, smpl_x.joint_part[part_name], :])
            if not face_hand_kpt:
                part_sum = 0 * part_sum
            losses[key] = part_sum / num_part_bbox
        else:
            # No valid box for this part: emit a clean zero in both modes
            # rather than dividing by zero (0 / 0 would be NaN).
            losses[key] = (torch.as_tensor(0., device=device) +
                           loss_smpl_kp2d_ba.sum() * 0)
    return losses
def loss_boxes(self, outputs, targets, indices,
               idx, num_boxes, data_batch,
               face_hand_box=False):
    """Compute the L1 and GIoU regression losses of the matched boxes.

    Target dicts must contain the key "boxes" holding a tensor of dim
    [nb_target_boxes, 4]. The target boxes are expected in format
    (center_x, center_y, w, h), normalized by the image size.

    Args:
        outputs: model outputs. Must contain ``'pred_boxes'``; may contain
            ``'pred_lhand_boxes'``, ``'pred_rhand_boxes'`` and
            ``'pred_face_boxes'``.
        targets: per-image target dicts (``'boxes'`` and, when part boxes
            are supervised, ``'lhand_boxes'`` / ``'rhand_boxes'`` /
            ``'face_boxes'``).
        indices: matcher output; ``indices[0]`` is the per-image list of
            ``(prediction_index, target_index)`` pairs.
        idx: index selecting the matched queries in the prediction tensors.
        num_boxes: normaliser of the body box losses.
        data_batch: provides the ``'<part>_bbox_valid'`` weights that mask
            out boxes without annotation.
        face_hand_box: when True, hand and face boxes are supervised too.

    Returns:
        dict with ``loss_body_bbox`` / ``loss_body_giou`` and, for every
        supervised part, ``loss_<part>_bbox`` / ``loss_<part>_giou``.
    """
    matches = indices[0]

    def _weighted_sums(pred_key, target_key, valid_key):
        """Return (L1 sum, GIoU-loss sum, validity weights) of one box type."""
        src_boxes = outputs[pred_key][idx]
        tgt_boxes = torch.cat(
            [t[target_key][i] for t, (_, i) in zip(targets, matches)],
            dim=0)
        valid = torch.cat(
            [v[i] for v, (_, i) in zip(data_batch[valid_key], matches)],
            dim=0)
        l1 = F.l1_loss(src_boxes, tgt_boxes, reduction='none')
        l1 = l1 * valid[:, None]
        # Only the diagonal pairs prediction k with its own target k.
        giou = 1 - torch.diag(
            box_ops.generalized_box_iou(
                box_ops.box_cxcywh_to_xyxy(src_boxes),
                box_ops.box_cxcywh_to_xyxy(tgt_boxes)))
        giou = giou * valid
        return l1.sum(), giou.sum(), valid

    assert 'pred_boxes' in outputs
    losses = {}

    # Body boxes are normalised by the global number of target boxes.
    l1_sum, giou_sum, _ = _weighted_sums('pred_boxes', 'boxes',
                                         'body_bbox_valid')
    losses['loss_body_bbox'] = l1_sum / num_boxes
    losses['loss_body_giou'] = giou_sum / num_boxes

    if not face_hand_box:
        return losses

    for part in ('lhand', 'rhand', 'face'):
        pred_key = f'pred_{part}_boxes'
        if pred_key not in outputs:
            continue
        l1_sum, giou_sum, valid = _weighted_sums(pred_key, f'{part}_boxes',
                                                 f'{part}_bbox_valid')
        # Part boxes are normalised by the number of valid part boxes.
        num_valid = (valid > 0).sum()
        if num_valid > 0:
            losses[f'loss_{part}_bbox'] = l1_sum / num_valid
            losses[f'loss_{part}_giou'] = giou_sum / num_valid
        else:
            # Keep a zero that is still connected to the graph.
            losses[f'loss_{part}_bbox'] = l1_sum * 0
            losses[f'loss_{part}_giou'] = giou_sum * 0
    return losses
def loss_dn_boxes(self, outputs, targets, indices, idx, num_boxes,
                  data_batch):
    """Compute the box losses of the denoising (DN) queries.

    Input (read from ``outputs`` when denoising is active):
        - dn_bbox_pred:  bs, num_dn, 4
        - dn_bbox_input: bs, num_dn, 4
        - num_tgt: normaliser passed on to ``tgt_loss_boxes``

    ``targets``, ``indices``, ``idx``, ``num_boxes`` and ``data_batch`` are
    not used; they are accepted so the method shares the signature of the
    other loss methods.

    Returns:
        dict with ``'dn_loss_bbox'`` and ``'dn_loss_giou'``.
    """
    # The guard has to run before any DN entry is read: without denoising
    # the DN keys are absent and reading them first raises KeyError.
    if 'num_tgt' not in outputs:
        logits = outputs['pred_logits']
        # Zero losses that stay attached to the graph.
        return {
            'dn_loss_bbox': logits.sum() * 0,
            'dn_loss_giou': logits.sum() * 0,
        }

    num_tgt = outputs['num_tgt']
    src_boxes = outputs['dn_bbox_pred']
    tgt_boxes = outputs['dn_bbox_input']
    return self.tgt_loss_boxes(src_boxes, tgt_boxes, num_tgt)
def loss_dn_labels(self, outputs, targets, indices, idx, num_boxes,
                   data_batch):
    """Compute the classification loss of the denoising (DN) queries.

    Input (read from ``outputs`` when denoising is active):
        - dn_class_pred:  bs, num_dn, num_classes
        - dn_class_input: bs, num_dn
        - num_tgt: normaliser passed on to ``tgt_loss_labels``

    ``targets``, ``indices``, ``idx``, ``num_boxes`` and ``data_batch`` are
    not used; they are accepted so the method shares the signature of the
    other loss methods.

    Returns:
        dict with the single key ``'dn_loss_ce'``.
    """
    if 'num_tgt' not in outputs:
        # No denoising queries: return a zero still attached to the graph.
        return {'dn_loss_ce': outputs['pred_logits'].sum() * 0}

    num_tgt = outputs['num_tgt']
    src_logits = outputs['dn_class_pred']
    tgt_labels = outputs['dn_class_input']
    return self.tgt_loss_labels(src_logits, tgt_labels, num_tgt)
def loss_matching_cost(self, outputs, targets, indices, idx, num_boxes,
                       data_batch):
    """Expose the matcher's cost statistics as loss-dict entries.

    ``indices[1]`` is expected to be a mapping of cost names to values;
    every entry is returned unchanged under the key ``'set_<name>'``.
    All other arguments are ignored.
    """
    renamed = {}
    for cost_name, cost_value in indices[1].items():
        renamed[f'set_{cost_name}'] = cost_value
    return renamed
def _get_src_permutation_idx(self, indices):
    """Flatten per-image matches into (batch index, prediction index).

    ``indices`` is a sequence of ``(src, tgt)`` index tensors, one pair per
    image. Returns two concatenated tensors: the image number of every
    match and the matched prediction index.
    """
    batch_parts = []
    src_parts = []
    for image_no, (src, _) in enumerate(indices):
        # One image number per matched prediction of this image.
        batch_parts.append(torch.full_like(src, image_no))
        src_parts.append(src)
    return torch.cat(batch_parts), torch.cat(src_parts)
def _get_tgt_permutation_idx(self, indices):
    """Flatten per-image matches into (batch index, target index).

    ``indices`` is a sequence of ``(src, tgt)`` index tensors, one pair per
    image. Returns two concatenated tensors: the image number of every
    match and the matched target index.
    """
    batch_parts = []
    tgt_parts = []
    for image_no, (_, tgt) in enumerate(indices):
        # One image number per matched target of this image.
        batch_parts.append(torch.full_like(tgt, image_no))
        tgt_parts.append(tgt)
    return torch.cat(batch_parts), torch.cat(tgt_parts)
def get_loss(self, loss, outputs, targets, data_batch, indices, num_boxes,
             **kwargs):
    """Run the loss method registered under the name ``loss``.

    The matched-prediction index is derived from ``indices[0]`` and passed
    to the selected method together with the remaining arguments; extra
    keyword arguments are forwarded untouched.

    Raises:
        AssertionError: if ``loss`` is not a registered loss name.
    """
    handlers = {
        'smpl_pose': self.loss_smpl_pose,
        'smpl_beta': self.loss_smpl_beta,
        'smpl_expr': self.loss_smpl_expr,
        'smpl_kp2d': self.loss_smpl_kp2d,
        'smpl_kp2d_ba': self.loss_smpl_kp2d_ba,
        'smpl_kp3d_ra': self.loss_smpl_kp3d_ra,
        'smpl_kp3d': self.loss_smpl_kp3d,
        'labels': self.loss_labels,
        'cardinality': self.loss_cardinality,
        'keypoints': self.loss_keypoints,
        'boxes': self.loss_boxes,
        'dn_label': self.loss_dn_labels,
        'dn_bbox': self.loss_dn_boxes,
        'matching': self.loss_matching_cost,
    }
    src_idx = self._get_src_permutation_idx(indices[0])
    assert loss in handlers, f'do you really want to compute {loss} loss?'
    handler = handlers[loss]
    return handler(outputs, targets, indices, src_idx, num_boxes,
                   data_batch, **kwargs)
def prep_for_dn2(self, mask_dict):
    """Unpack the denoising bookkeeping dict into a flat tuple.

    Returns:
        ``(known_labels, known_bboxs, output_known_class,
        output_known_coord, pad_size)`` read from ``mask_dict``.
    """
    bboxes, labels, coords, classes, pad_size = (
        mask_dict[key]
        for key in ('known_bboxs', 'known_labels', 'output_known_coord',
                    'output_known_class', 'pad_size'))
    return labels, bboxes, classes, coords, pad_size
## SMPL losses | |
def forward(self, outputs, targets, data_batch, return_indices=False):
    """Compute every configured loss for the final and auxiliary outputs.

    Parameters:
        outputs: dict of tensors, see the output specification of the
            model for the format. Optional entries ``'aux_outputs'``,
            ``'interm_outputs'`` and ``'query_expand'`` are supervised too.
        targets: list of dicts, such that len(targets) == batch_size.
            The expected keys in each dict depend on the losses applied,
            see each loss' doc.
        data_batch: batch data forwarded to the matcher and the losses.
        return_indices: used for visualisation. If True, the matching
            indices of every supervised output are returned as well
            (final-layer indices last).

    Returns:
        The dict of losses, or ``(losses, indices_list)`` when
        ``return_indices`` is True.
    """
    final_outputs = {}
    for key, value in outputs.items():
        if key != 'aux_outputs':
            final_outputs[key] = value
    device = next(iter(outputs.values())).device

    # Average number of target boxes across all nodes, used to normalise
    # the losses.
    total_boxes = sum(len(t['boxes']) for t in targets)
    num_boxes = torch.as_tensor([total_boxes],
                                dtype=torch.float,
                                device=device)
    if is_dist_avail_and_initialized():
        torch.distributed.all_reduce(num_boxes)
    num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

    # Losses of the final layer.
    indices = self.matcher(final_outputs, targets, data_batch)
    if return_indices:
        indices0_copy = indices
        indices_list = []

    losses = {}
    smpl_loss = ['smpl_pose', 'smpl_beta', 'smpl_expr', 'smpl_kp2d',
                 'smpl_kp2d_ba', 'smpl_kp3d', 'smpl_kp3d_ra']
    for loss in self.losses:
        kwargs = {}
        if loss == 'keypoints' or loss in smpl_loss:
            kwargs['face_hand_kpt'] = True
        if loss == 'boxes':
            kwargs['face_hand_box'] = True
        l_dict = self.get_loss(loss, outputs, targets, data_batch, indices,
                               num_boxes, **kwargs)
        losses.update(l_dict)

    # Auxiliary losses: repeat the process with the output of each
    # intermediate layer.
    if 'aux_outputs' in outputs:
        for layer_no, aux_outputs in enumerate(outputs['aux_outputs']):
            indices = self.matcher(aux_outputs, targets, data_batch)
            if return_indices:
                indices_list.append(indices)
            for loss in self.losses:
                if loss == 'masks':
                    continue
                kwargs = {}
                if loss == 'boxes':
                    # Hand/face boxes are supervised only from the layer
                    # with index num_box_decoder_layers onwards.
                    kwargs['face_hand_box'] = (
                        layer_no >= self.num_box_decoder_layers)
                if loss == 'keypoints' or loss in smpl_loss:
                    # Layers below num_box_decoder_layers get no keypoint
                    # or SMPL supervision at all.
                    if layer_no < self.num_box_decoder_layers:
                        continue
                    kwargs['face_hand_kpt'] = (
                        layer_no >= self.num_hand_face_decoder_layers)
                if loss == 'labels':
                    # Logging is enabled only for the last layer.
                    kwargs = {'log': False}
                l_dict = self.get_loss(loss, aux_outputs, targets,
                                       data_batch, indices, num_boxes,
                                       **kwargs)
                for name, value in l_dict.items():
                    losses[name + f'_{layer_no}'] = value

    # Losses of the intermediate (encoder) outputs.
    if 'interm_outputs' in outputs:
        interm_outputs = outputs['interm_outputs']
        # NOTE(review): unlike the calls above, the matcher is called
        # without data_batch here - confirm this is intended.
        indices = self.matcher(interm_outputs, targets)
        if return_indices:
            indices_list.append(indices)
        for loss in self.losses:
            if loss in ('dn_bbox', 'dn_label', 'keypoints'):
                continue
            if loss in smpl_loss:
                continue
            kwargs = {'log': False} if loss == 'labels' else {}
            l_dict = self.get_loss(loss, interm_outputs, targets,
                                   data_batch, indices, num_boxes,
                                   **kwargs)
            for name, value in l_dict.items():
                losses[name + '_interm'] = value

    # Losses of the expanded queries ('aux_init').
    if 'query_expand' in outputs:
        interm_outputs = outputs['query_expand']
        # NOTE(review): matcher called without data_batch, see above.
        indices = self.matcher(interm_outputs, targets)
        if return_indices:
            indices_list.append(indices)
        for loss in self.losses:
            if loss in ('dn_bbox', 'dn_label'):
                continue
            kwargs = {'log': False} if loss == 'labels' else {}
            l_dict = self.get_loss(loss, interm_outputs, targets,
                                   data_batch, indices, num_boxes,
                                   **kwargs)
            for name, value in l_dict.items():
                losses[name + '_query_expand'] = value

    if return_indices:
        indices_list.append(indices0_copy)
        return losses, indices_list
    return losses
def tgt_loss_boxes(
    self,
    src_boxes,
    tgt_boxes,
    num_tgt,
):
    """Compute the L1 and GIoU losses between aligned box tensors.

    Input:
        - src_boxes: bs, num_dn, 4
        - tgt_boxes: bs, num_dn, 4
        - num_tgt: value both summed losses are divided by

    Returns:
        dict with ``'dn_loss_bbox'`` and ``'dn_loss_giou'``.
    """
    l1_total = F.l1_loss(src_boxes, tgt_boxes, reduction='none').sum()

    # Merge batch and query dims; the diagonal of the pairwise GIoU matrix
    # pairs every predicted box with its own target.
    src_xyxy = box_ops.box_cxcywh_to_xyxy(src_boxes.flatten(0, 1))
    tgt_xyxy = box_ops.box_cxcywh_to_xyxy(tgt_boxes.flatten(0, 1))
    pairwise_giou = box_ops.generalized_box_iou(src_xyxy, tgt_xyxy)
    giou_total = (1 - torch.diag(pairwise_giou)).sum()

    return {
        'dn_loss_bbox': l1_total / num_tgt,
        'dn_loss_giou': giou_total / num_tgt,
    }
def tgt_loss_labels(self,
                    src_logits: Tensor,
                    tgt_labels: Tensor,
                    num_tgt: int,
                    log: bool = True):
    """Compute the focal classification loss against integer labels.

    Input:
        - src_logits: bs, num_dn, num_classes
        - tgt_labels: bs, num_dn
        - num_tgt: normaliser handed to ``sigmoid_focal_loss``
        - log: accepted but not used by this method

    Returns:
        dict with the single key ``'dn_loss_ce'``.
    """
    batch_size, num_queries, num_classes = src_logits.shape[:3]

    # One-hot encode with one extra trailing slot, then drop that slot so a
    # label equal to num_classes becomes an all-zero row.
    onehot = torch.zeros([batch_size, num_queries, num_classes + 1],
                         dtype=src_logits.dtype,
                         layout=src_logits.layout,
                         device=src_logits.device)
    onehot.scatter_(2, tgt_labels.unsqueeze(-1), 1)
    onehot = onehot[:, :, :-1]

    focal = sigmoid_focal_loss(src_logits,
                               onehot,
                               num_tgt,
                               alpha=self.focal_alpha,
                               gamma=2)
    return {'dn_loss_ce': focal * num_queries}
class SetCriterion_Box(nn.Module): | |
def __init__(self, | |
num_classes, | |
matcher, | |
weight_dict, | |
focal_alpha, | |
losses, | |
num_box_decoder_layers=2, | |
num_hand_face_decoder_layers=4, | |
num_body_points=17, | |
num_hand_points=6, | |
num_face_points=6, | |
smpl_loss_config=None, | |
convention='smplx_137'): | |
super().__init__() | |
self.num_classes = num_classes | |
self.matcher = matcher | |
self.weight_dict = weight_dict | |
self.losses = losses | |
self.focal_alpha = focal_alpha | |
self.vis = 0.1 | |
self.abs = 1 | |
self.num_body_points = 0 | |
self.num_hand_points = 0 | |
self.num_face_points = 0 | |
self.num_box_decoder_layers = num_box_decoder_layers | |
self.num_hand_face_decoder_layers = num_hand_face_decoder_layers | |
self.convention = convention | |
def loss_labels(self,
                outputs,
                targets,
                indices,
                idx,
                num_boxes,
                data_batch,
                log=True):
    """Classification loss (binary focal loss).

    Target dicts must contain the key "labels" holding a tensor of dim
    [nb_target_boxes].

    Args:
        outputs: model outputs; must contain ``'pred_logits'``.
        targets: per-image target dicts.
        indices: matcher output; ``indices[0]`` is the per-image list of
            ``(prediction_index, target_index)`` pairs.
        idx: index selecting the matched queries in ``pred_logits``.
        num_boxes: normaliser handed to ``sigmoid_focal_loss``.
        data_batch: unused.
        log: when True, also report ``'class_error'``.

    Returns:
        dict with ``'loss_ce'`` (and ``'class_error'`` when ``log``). When
        nothing was matched only a zero ``'loss_ce'`` is returned.
    """
    indices = indices[0]
    assert 'pred_logits' in outputs
    src_logits = outputs['pred_logits']

    # Count the matches of the WHOLE batch. Looking only at the first
    # image would zero the loss whenever that image happens to be empty
    # and would fail on an empty batch.
    num_matched = sum(len(src) for src, _ in indices)
    if num_matched == 0:
        return {'loss_ce': src_logits.sum() * 0}

    target_classes_o = torch.cat(
        [t['labels'][J] for t, (_, J) in zip(targets, indices)])
    # Every query starts out as "no object" (class index num_classes).
    target_classes = torch.full(src_logits.shape[:2],
                                self.num_classes,
                                dtype=torch.int64,
                                device=src_logits.device)
    target_classes[idx] = target_classes_o

    # One-hot with an extra trailing slot that is dropped afterwards.
    target_classes_onehot = torch.zeros([
        src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1
    ],
                                        dtype=src_logits.dtype,
                                        layout=src_logits.layout,
                                        device=src_logits.device)
    target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
    target_classes_onehot = target_classes_onehot[:, :, :-1]

    loss_ce = sigmoid_focal_loss(src_logits,
                                 target_classes_onehot,
                                 num_boxes,
                                 alpha=self.focal_alpha,
                                 gamma=2) * src_logits.shape[1]
    losses = {'loss_ce': loss_ce}
    if log:
        # TODO this should probably be a separate loss, not hacked in here
        losses['class_error'] = 100 - accuracy(src_logits[idx],
                                               target_classes_o)[0]
    return losses
def loss_cardinality(self, outputs, targets, indices, num_boxes,
                     data_batch):
    """Compute the cardinality error.

    This is the absolute error in the number of predicted non-empty boxes,
    averaged over the images. It is not really a loss and is intended for
    logging purposes only.

    NOTE(review): the other loss methods of this class take ``idx`` after
    ``indices``; this one does not - confirm against the caller.

    Args:
        outputs: model outputs; must contain ``'pred_logits'``.
        targets: per-image target dicts with a ``'labels'`` entry.
        indices, num_boxes, data_batch: unused.

    Returns:
        dict with the single key ``'cardinality_error'``.
    """
    pred_logits = outputs['pred_logits']
    device = pred_logits.device
    tgt_lengths = torch.as_tensor([len(v['labels']) for v in targets],
                                  device=device)
    # tgt_lengths has one entry per image, so it must be reduced before it
    # is used as a condition; a bare ``tgt_lengths == 0`` raises for any
    # batch with more than one image.
    if bool((tgt_lengths == 0).all()):
        return {'cardinality_error': pred_logits.sum() * 0}
    # Count the number of predictions that are NOT "no-object" (which is
    # the last class).
    card_pred = (pred_logits.argmax(-1) !=
                 pred_logits.shape[-1] - 1).sum(1)
    card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
    return {'cardinality_error': card_err}
def loss_smpl_pose(self, outputs, targets, indices, idx, num_boxes,
                   data_batch, face_hand_kpt=False):
    """L1 loss between predicted and target SMPL-X rotation matrices.

    The body, left-hand, right-hand and jaw predictions are concatenated
    along the joint axis (53 joints in total: index 0 is reported as root,
    1-21 as body, 22-36 as left hand, 37-51 as right hand, 52 as jaw).
    Targets from ``data_batch['smplx_pose']`` are converted with
    ``batch_rodrigues`` and weighted by ``data_batch['smplx_pose_valid']``.

    Args:
        outputs: model outputs with the ``pred_smpl_*_pose`` entries and
            ``'pred_logits'``.
        targets: unused.
        indices: matcher output; ``indices[0]`` is the per-image list of
            ``(prediction_index, target_index)`` pairs.
        idx: index selecting the matched queries.
        num_boxes: normaliser of all returned losses.
        data_batch: provides the target poses and their validity.
        face_hand_kpt: when False the two hand losses are multiplied by 0
            (the jaw loss is kept).

    Returns:
        dict with ``loss_smpl_pose_{root,body,lhand,rhand,jaw}``.
    """
    indices = indices[0]
    device = outputs['pred_logits'].device

    pred_smplx_pose = torch.cat(
        (outputs['pred_smpl_pose'][idx],        # 22 joints
         outputs['pred_smpl_lhand_pose'][idx],  # 15 joints
         outputs['pred_smpl_rhand_pose'][idx],  # 15 joints
         outputs['pred_smpl_jaw_pose'][idx]),
        dim=1)

    def _zero_losses():
        """Zero losses that stay connected to the prediction graph."""
        return {
            f'loss_smpl_pose_{part}':
            torch.as_tensor(0., device=device) + pred_smplx_pose.sum() * 0
            for part in ('root', 'body', 'lhand', 'rhand', 'jaw')
        }

    # Count the matches of the WHOLE batch; looking only at the first
    # image would drop the loss whenever that image has no match.
    if sum(len(src) for src, _ in indices) == 0:
        return _zero_losses()

    targets_smpl_pose = torch.cat(
        [t[i] for t, (_, i) in zip(data_batch['smplx_pose'], indices)],
        dim=0)
    targets_smpl_pose = batch_rodrigues(targets_smpl_pose.view(
        -1, 3)).view(-1, 53, 3, 3)

    conf = torch.cat([
        t[i] for t, (_, i) in zip(data_batch['smplx_pose_valid'], indices)
    ], dim=0)
    # Expand the validity values to the 3x3 layout of the rotations.
    conf = (conf.reshape(-1, 53, 3)[:, :, :, None]).repeat(1, 1, 1, 3)
    if conf.sum() == 0:
        return _zero_losses()

    per_joint = F.l1_loss(pred_smplx_pose,
                          targets_smpl_pose,
                          reduction='none')
    per_joint = (per_joint * conf).sum([-1, -2])

    root_sum = per_joint[:, 0].sum()
    body_sum = per_joint[:, 1:22].sum()
    lhand_sum = per_joint[:, 22:37].sum()
    rhand_sum = per_joint[:, 37:52].sum()
    jaw_sum = per_joint[:, 52].sum()
    if not face_hand_kpt:
        # Hands are not supervised; keep the graph but zero the value.
        lhand_sum = 0 * lhand_sum
        rhand_sum = 0 * rhand_sum

    return {
        'loss_smpl_pose_root': root_sum / num_boxes,
        'loss_smpl_pose_body': body_sum / num_boxes,
        'loss_smpl_pose_lhand': lhand_sum / num_boxes,
        'loss_smpl_pose_rhand': rhand_sum / num_boxes,
        'loss_smpl_pose_jaw': jaw_sum / num_boxes,
    }
def loss_smpl_beta(self, outputs, targets, indices, idx, num_boxes,
                   data_batch, face_hand_kpt=False):
    """L1 loss on the SMPL-X shape (beta) parameters.

    Args:
        outputs: model outputs with ``'pred_smpl_beta'`` and
            ``'pred_logits'``.
        targets: unused.
        indices: matcher output; ``indices[0]`` is the per-image list of
            ``(prediction_index, target_index)`` pairs.
        idx: index selecting the matched queries.
        num_boxes: normaliser of the returned loss.
        data_batch: provides ``'smplx_shape'`` targets and the per-instance
            ``'smplx_shape_valid'`` weights.
        face_hand_kpt: unused.

    Returns:
        dict with the single key ``'loss_smpl_beta'``.
    """
    indices = indices[0]
    device = outputs['pred_logits'].device
    pred_smpl_betas = outputs['pred_smpl_beta'][idx]

    def _zero_loss():
        """Zero loss that stays connected to the prediction graph."""
        return {
            'loss_smpl_beta':
            torch.as_tensor(0., device=device) + pred_smpl_betas.sum() * 0
        }

    # Count the matches of the WHOLE batch; looking only at the first
    # image would drop the loss whenever that image has no match.
    if sum(len(src) for src, _ in indices) == 0:
        return _zero_loss()

    targets_smpl_betas = torch.cat(
        [t[i] for t, (_, i) in zip(data_batch['smplx_shape'], indices)],
        dim=0)
    conf = torch.cat(
        [t[i] for t, (_, i) in zip(data_batch['smplx_shape_valid'], indices)],
        dim=0)
    if conf.sum() == 0:
        return _zero_loss()

    per_instance = F.l1_loss(pred_smpl_betas,
                             targets_smpl_betas,
                             reduction='none').sum(-1)
    per_instance = per_instance * conf
    return {'loss_smpl_beta': per_instance.sum() / num_boxes}
def loss_smpl_expr(self, outputs, targets, indices, idx, num_boxes,
                   data_batch, face_hand_kpt=False):
    """L1 loss on the SMPL-X expression parameters.

    Unlike the other SMPL losses, the result is normalised by the sum of
    the validity weights (plus 1e-6) rather than by ``num_boxes``.

    Args:
        outputs: model outputs with ``'pred_smpl_expr'`` and
            ``'pred_logits'``.
        targets: unused.
        indices: matcher output; ``indices[0]`` is the per-image list of
            ``(prediction_index, target_index)`` pairs.
        idx: index selecting the matched queries.
        num_boxes: unused.
        data_batch: provides ``'smplx_expr'`` targets and the per-instance
            ``'smplx_expr_valid'`` weights.
        face_hand_kpt: when False the loss is multiplied by 0.

    Returns:
        dict with the single key ``'loss_smpl_expr'``.
    """
    indices = indices[0]
    device = outputs['pred_logits'].device
    pred_smpl_expr = outputs['pred_smpl_expr'][idx]

    def _zero_loss():
        """Zero loss that stays connected to the prediction graph."""
        return {
            'loss_smpl_expr':
            torch.as_tensor(0., device=device) + pred_smpl_expr.sum() * 0
        }

    # Count the matches of the WHOLE batch; looking only at the first
    # image would drop the loss whenever that image has no match.
    if sum(len(src) for src, _ in indices) == 0:
        return _zero_loss()

    targets_smpl_expr = torch.cat(
        [t[i] for t, (_, i) in zip(data_batch['smplx_expr'], indices)],
        dim=0)
    conf = torch.cat(
        [t[i] for t, (_, i) in zip(data_batch['smplx_expr_valid'], indices)],
        dim=0)
    if conf.sum() == 0:
        return _zero_loss()

    per_instance = F.l1_loss(pred_smpl_expr,
                             targets_smpl_expr,
                             reduction='none').sum(-1)
    total = (per_instance * conf).sum()
    if not face_hand_kpt:
        # Expression is not supervised; keep the graph but zero the value.
        total = 0 * total
    return {'loss_smpl_expr': total / (conf.sum() + 1e-6)}
def loss_smpl_kp3d(self,
                   outputs,
                   targets,
                   indices,
                   idx,
                   num_boxes,
                   data_batch,
                   has_keypoints3d=None,
                   face_hand_kpt=False):
    """L1 loss on pelvis-centred 3D keypoints (``data_batch['joint_cam']``).

    Predictions and targets are both centred on their own pelvis joint.
    The last channel of the targets is used as a per-joint confidence and
    is multiplied by the per-image ``data_batch['is_3D']`` flag.

    Args:
        outputs: model outputs with ``'pred_smpl_kp3d'`` and
            ``'pred_logits'``.
        targets: unused.
        indices: matcher output; ``indices[0]`` is the per-image list of
            ``(prediction_index, target_index)`` pairs.
        idx: index selecting the matched queries.
        num_boxes: normaliser of all returned losses.
        data_batch: provides ``'joint_cam'`` and ``'is_3D'``.
        has_keypoints3d: unused.
        face_hand_kpt: when False the hand and face losses are multiplied
            by 0.

    Returns:
        dict with ``loss_smpl_{body,lhand,rhand,face}_kp3d``.
    """
    device = outputs['pred_logits'].device
    indices = indices[0]
    pred_smpl_kp3d = outputs['pred_smpl_kp3d'][idx].float()

    def _zero_losses():
        """Zero losses that stay connected to the prediction graph."""
        return {
            f'loss_smpl_{part}_kp3d':
            torch.as_tensor(0., device=device) + pred_smpl_kp3d.sum() * 0
            for part in ('body', 'lhand', 'rhand', 'face')
        }

    # Count the matches of the WHOLE batch; looking only at the first
    # image would drop the loss whenever that image has no match.
    if sum(len(src) for src, _ in indices) == 0:
        return _zero_losses()

    targets_smpl_kp3d = torch.cat(
        [t[i] for t, (_, i) in zip(data_batch['joint_cam'], indices)],
        dim=0)
    # Channels 0-2 are the coordinates, the rest is the confidence.
    targets_kp3d_conf = targets_smpl_kp3d[:, :, 3:].clone()
    targets_smpl_kp3d = targets_smpl_kp3d[:, :, :3]
    # Broadcast the per-image 3D flag to every matched instance.
    targets_is_3d = torch.cat([
        t[None, None].repeat(len(i), 1, 1)
        for t, (_, i) in zip(data_batch['is_3D'], indices)
    ],
                              dim=0)
    targets_kp3d_conf = (targets_kp3d_conf * targets_is_3d).repeat(1, 1, 3)

    # Centre both sets of joints on their pelvis.
    pelvis_idx = get_keypoint_idx('pelvis', self.convention)
    targets_pelvis = targets_smpl_kp3d[..., pelvis_idx, :]
    pred_pelvis = pred_smpl_kp3d[..., pelvis_idx, :]
    targets_smpl_kp3d = targets_smpl_kp3d - targets_pelvis[:, None, :]
    pred_smpl_kp3d = pred_smpl_kp3d - pred_pelvis[:, None, :]

    loss_smpl_kp3d = F.l1_loss(pred_smpl_kp3d,
                               targets_smpl_kp3d,
                               reduction='none')

    # Nothing to supervise when no joint has positive confidence.
    valid_pos = targets_kp3d_conf > 0
    if targets_kp3d_conf[valid_pos].numel() == 0:
        return _zero_losses()

    loss_smpl_kp3d = loss_smpl_kp3d * targets_kp3d_conf

    losses = {}
    for part in ('body', 'lhand', 'rhand', 'face'):
        part_sum = torch.sum(loss_smpl_kp3d[:, smpl_x.joint_part[part], :])
        if part != 'body' and not face_hand_kpt:
            # Hands/face are not supervised; zero the value only.
            part_sum = 0 * part_sum
        losses[f'loss_smpl_{part}_kp3d'] = part_sum / num_boxes
    return losses
def loss_smpl_kp3d_ra(self,
                      outputs,
                      targets,
                      indices,
                      idx,
                      num_boxes,
                      data_batch,
                      has_keypoints3d=None,
                      face_hand_kpt=False):
    """L1 loss on 3D keypoints (``data_batch['smplx_joint_cam']``).

    Predictions and targets are first centred on their own pelvis joint.
    For the face and hand parts the predictions are additionally expressed
    relative to the predicted neck / left wrist / right wrist joint.
    NOTE(review): the targets of those parts are only pelvis-centred here;
    presumably ``smplx_joint_cam`` already stores them relative to the
    same anchor joints - confirm against the dataset code.

    Args:
        outputs: model outputs with ``'pred_smpl_kp3d'`` and
            ``'pred_logits'``.
        targets: unused.
        indices: matcher output; ``indices[0]`` is the per-image list of
            ``(prediction_index, target_index)`` pairs.
        idx: index selecting the matched queries.
        num_boxes: normaliser of all returned losses.
        data_batch: provides ``'smplx_joint_cam'`` and ``'is_3D'``.
        has_keypoints3d: unused.
        face_hand_kpt: when False the face and hand losses are multiplied
            by 0.

    Returns:
        dict with ``loss_smpl_{body,face,lhand,rhand}_kp3d_ra``.
    """
    device = outputs['pred_logits'].device
    indices = indices[0]
    pred_smpl_kp3d = outputs['pred_smpl_kp3d'][idx].float()

    # Count the matches of the WHOLE batch; looking only at the first
    # image would drop the loss whenever that image has no match.
    if sum(len(src) for src, _ in indices) == 0:
        # Zero losses that stay connected to the prediction graph.
        return {
            f'loss_smpl_{part}_kp3d_ra':
            torch.as_tensor(0., device=device) + pred_smpl_kp3d.sum() * 0
            for part in ('body', 'face', 'lhand', 'rhand')
        }

    targets_smpl_kp3d = torch.cat([
        t[i] for t, (_, i) in zip(data_batch['smplx_joint_cam'], indices)
    ],
                                  dim=0)
    # Channels 0-2 are the coordinates, the rest is the confidence.
    targets_kp3d_conf = targets_smpl_kp3d[:, :, 3:].clone()
    targets_smpl_kp3d = targets_smpl_kp3d[:, :, :3]
    # Broadcast the per-image 3D flag to every matched instance.
    targets_is_3d = torch.cat([
        t[None, None].repeat(len(i), 1, 1)
        for t, (_, i) in zip(data_batch['is_3D'], indices)
    ],
                              dim=0)
    targets_kp3d_conf = (targets_kp3d_conf * targets_is_3d).repeat(1, 1, 3)
    targets_smpl_kp3d = targets_smpl_kp3d[..., :3].float()

    # Centre both sets of joints on their pelvis.
    pelvis_idx = get_keypoint_idx('pelvis', self.convention)
    targets_pelvis = targets_smpl_kp3d[..., pelvis_idx, :]
    pred_pelvis = pred_smpl_kp3d[..., pelvis_idx, :]
    targets_smpl_kp3d = targets_smpl_kp3d - targets_pelvis[:, None, :]
    pred_smpl_kp3d = pred_smpl_kp3d - pred_pelvis[:, None, :]

    def _weighted_l1_sum(pred_part, part_idx):
        """Confidence-weighted L1 sum of one part's joints."""
        per_coord = F.l1_loss(pred_part,
                              targets_smpl_kp3d[:, part_idx, :],
                              reduction='none')
        return torch.sum(per_coord * targets_kp3d_conf[:, part_idx, :])

    losses = {}
    body_idx = smpl_x.joint_part['body']
    losses['loss_smpl_body_kp3d_ra'] = _weighted_l1_sum(
        pred_smpl_kp3d[:, body_idx, :], body_idx) / num_boxes

    # Face and hands: predictions relative to their anchor joint.
    part_anchors = (
        ('face', smpl_x.neck_idx),
        ('lhand', smpl_x.lwrist_idx),
        ('rhand', smpl_x.rwrist_idx),
    )
    for part, anchor_idx in part_anchors:
        part_idx = smpl_x.joint_part[part]
        part_cam = pred_smpl_kp3d[:, part_idx, :]
        anchor_cam = pred_smpl_kp3d[:, anchor_idx, None, :]
        part_loss = _weighted_l1_sum(part_cam - anchor_cam,
                                     part_idx) / num_boxes
        if not face_hand_kpt:
            # Part is not supervised; keep the graph but zero the value.
            part_loss = 0 * part_loss
        losses[f'loss_smpl_{part}_kp3d_ra'] = part_loss
    return losses
def loss_smpl_kp2d(self, | |
outputs, | |
targets, | |
indices, | |
idx, | |
num_boxes, | |
data_batch, | |
focal_length=5000., | |
has_keypoints2d=None, | |
face_hand_kpt=False): | |
"""Compute loss for 2d keypoints.""" | |
device = outputs['pred_logits'].device | |
indices = indices[0] | |
valid_num=0 | |
for indice in indices[0]: | |
valid_num+=len(indice) | |
# pdb.set_trace() | |
pred_smpl_kp3d = outputs['pred_smpl_kp3d'][idx].float()#.detach() | |
# pred_smpl_kp3d = outputs['pred_smpl_kp3d'][idx].float() | |
# pelvis_idx = get_keypoint_idx('pelvis', self.convention) | |
# pred_pelvis = pred_smpl_kp3d[..., pelvis_idx, :] | |
# pred_smpl_kp3d = pred_smpl_kp3d - pred_pelvis[:, None, :] +1e-7 | |
pred_cam = outputs['pred_smpl_cam'][idx].float() | |
targets_kp2d = torch.cat([t[i] for t, (_, i) in zip(data_batch['joint_img'], indices)], dim=0) | |
keypoints2d_conf = targets_kp2d[:,:,2:].clone() | |
targets_kp2d = targets_kp2d[:,:,:2] | |
target_lhand_boxes_conf = torch.cat( | |
[t[i] for t, (_, i) in zip(data_batch['lhand_bbox_valid'], indices)], dim=0) | |
lhand_num_boxes = target_lhand_boxes_conf.sum() | |
target_rhand_boxes_conf = torch.cat( | |
[t[i] for t, (_, i) in zip(data_batch['rhand_bbox_valid'], indices)], dim=0) | |
rhand_num_boxes = target_rhand_boxes_conf.sum() | |
target_face_boxes_conf = torch.cat( | |
[t[i] for t, (_, i) in zip(data_batch['face_bbox_valid'], indices)], dim=0) | |
face_num_boxes = target_face_boxes_conf.sum() | |
# t_pose = torch.cat([t[i] for t, (_, i) in zip(data_batch['smplx_pose'], indices)], dim=0) | |
# t_shape = torch.cat([t[i] for t, (_, i) in zip(data_batch['smplx_shape'], indices)], dim=0) | |
# t_expr = torch.cat([t[i] for t, (_, i) in zip(data_batch['smplx_expr'], indices)], dim=0) | |
keypoints2d_conf = keypoints2d_conf.repeat(1, 1, 2) | |
targets_kp2d = targets_kp2d[:, :, :2].float() | |
targets_kp2d[:,:,0] = targets_kp2d[:,:,0]/cfg.output_hm_shape[2] | |
targets_kp2d[:,:,1] = targets_kp2d[:,:,1]/cfg.output_hm_shape[1] | |
# targets_kp2d = targets_kp2d*2-1 | |
img_wh = torch.cat([data_batch['img_shape'][i][None] for i in idx[0]], dim=0).flip(-1) | |
# pred_smpl_kp2d = weak_perspective_projection(pred_smpl_kp3d, scale=pred_cam[:, 0], translation=pred_cam[:, 1:3]) | |
# If kp2ds is normalized to [-1, 1], the center should be the center of the image; | |
# if normalized to 0-1, it should be at the top left corner (0, 0)? | |
pred_smpl_kp2d = project_points_new( | |
points_3d=pred_smpl_kp3d, | |
pred_cam=pred_cam, | |
focal_length=focal_length, | |
camera_center=img_wh/2 | |
) | |
pred_smpl_kp2d = pred_smpl_kp2d / img_wh[:, None] | |
vis=False | |
# if 'vis' in cfg: | |
# vis=cfg['vis'] | |
# vis = True | |
if vis: | |
import mmcv | |
import cv2 | |
import numpy as np | |
from detrsmpl.core.visualization.visualize_keypoints2d import visualize_kp2d | |
from detrsmpl.core.visualization.visualize_smpl import visualize_smpl_hmr,render_smpl | |
from detrsmpl.models.body_models.builder import build_body_model | |
from pytorch3d.io import save_obj | |
from detrsmpl.core.visualization.visualize_keypoints3d import visualize_kp3d | |
img = mmcv.imdenormalize( | |
img=(data_batch['img'][0].cpu().numpy()).transpose(1, 2, 0), | |
mean=np.array([123.675, 116.28, 103.53]), | |
std=np.array([58.395, 57.12, 57.375]), | |
to_bgr=True).astype(np.uint8) | |
cv2.imwrite('test.png', img) | |
device = outputs['pred_smpl_kp3d'].device | |
body_model = dict( | |
type='smplx', | |
keypoint_src='smplx', | |
num_expression_coeffs=10, | |
num_betas=10, | |
keypoint_dst='smplx_137', | |
model_path='data/body_models/smplx', | |
use_pca=False, | |
use_face_contour=True) | |
bm = build_body_model(body_model).to(device) | |
pred_smpl_body_pose = rotmat_to_aa(outputs['pred_smpl_pose'][idx]) | |
pred_smpl_lhand_pose = rotmat_to_aa(outputs['pred_smpl_lhand_pose'][idx]) | |
pred_smpl_rhand_pose = rotmat_to_aa(outputs['pred_smpl_rhand_pose'][idx]) | |
pred_smpl_jaw_pose = rotmat_to_aa(outputs['pred_smpl_jaw_pose'][idx]) | |
pred_smpl_shape = outputs['pred_smpl_beta'][idx] | |
pred_output = bm( | |
betas=pred_smpl_shape.reshape(-1, 10), | |
body_pose=pred_smpl_body_pose[:,1:].reshape(-1, 21*3), | |
global_orient=pred_smpl_body_pose[:,:1].reshape(-1, 3), | |
left_hand_pose=pred_smpl_lhand_pose.reshape(-1, 15*3), | |
right_hand_pose=pred_smpl_rhand_pose.reshape(-1, 15*3), | |
leye_pose=torch.zeros_like(pred_smpl_jaw_pose).reshape(-1, 3), | |
reye_pose=torch.zeros_like(pred_smpl_jaw_pose).reshape(-1, 3), | |
expression=torch.zeros_like(pred_smpl_shape).reshape(-1, 10), | |
jaw_pose=pred_smpl_jaw_pose.reshape(-1, 3)) | |
verts = pred_output['vertices'] | |
# for i_obj,v in enumerate(verts): | |
# save_obj('./figs/pred_smpl_%d.obj'%i_obj,verts = v,faces=torch.tensor([])) | |
pred_cam = outputs['pred_smpl_cam'][idx] | |
targets_smpl_pose = data_batch['smplx_pose'][0] | |
targets_shape = data_batch['smplx_shape'][0] | |
gt_kp3d = data_batch['joint_cam'][0] | |
gt_kp2d = data_batch['joint_img'][0] | |
gt_body_boxes = torch.cat( | |
[t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) | |
# gt kp3d | |
pred_smpl_kp3d = outputs['pred_smpl_kp3d'][idx].float() | |
visualize_kp3d(gt_kp3d.detach().cpu().numpy(), | |
output_path='./figs/gt3d', | |
data_source='smplx_137') | |
# visualize_kp3d(pred_smpl_kp3d.detach().cpu().numpy(), | |
# output_path='./figs/pred3d', | |
# data_source='smplx_137') | |
# gt kp2d | |
img =(data_batch['img'][0].permute(1,2,0)*255).int().cpu().numpy() | |
gt_2d= gt_kp2d.detach().cpu().numpy()[...,:2]*data_batch['img_shape'].cpu().numpy()[0,None,None,::-1] | |
gt_2d[...,0] = gt_2d[...,0]/12 | |
gt_2d[...,1] = gt_2d[...,1]/16 | |
import mmcv | |
batch_id = 0 | |
gt_bbox = (box_ops.box_cxcywh_to_xyxy(targets[batch_id]['boxes']).reshape(-1,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[batch_id, ::-1]).reshape(-1,4) | |
gt_bbox_lhand = (box_ops.box_cxcywh_to_xyxy(targets[batch_id]['lhand_boxes']).reshape(-1,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[batch_id, ::-1]).reshape(-1,4) | |
gt_bbox_rhand = (box_ops.box_cxcywh_to_xyxy(targets[batch_id]['rhand_boxes']).reshape(-1,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[batch_id, ::-1]).reshape(-1,4) | |
gt_bbox_face = (box_ops.box_cxcywh_to_xyxy(targets[batch_id]['face_boxes']).reshape(-1,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[batch_id, ::-1]).reshape(-1,4) | |
gt_bbox = np.concatenate([gt_bbox,gt_bbox_face,gt_bbox_rhand,gt_bbox_lhand],axis=0) | |
# gt_bbox = (box_ops.box_cxcywh_to_xyxy(gt_body_boxes).reshape(-1,2,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[0, ::-1][None,None,:]).reshape(-1,4) | |
img = mmcv.imshow_bboxes(img.copy(), gt_bbox, show=False) | |
gt_2d = data_batch['joint_img'][0][:,:,:2].cpu().numpy()*data_batch['img_shape'].cpu().numpy()[0,None,None,::-1]# *data_batch['joint_img'][0][:,:,2:].cpu().numpy() | |
gt_2d[...,0] = gt_2d[...,0]/12 | |
gt_2d[...,1] = gt_2d[...,1]/16 | |
# data_batch['joint_img'] | |
# gt_kp2d = gt_2d[0][keypoints2d_conf[0]!=0] | |
visualize_kp2d( | |
(gt_2d).reshape(-1,2)[None], | |
output_path='./figs/gt2d', | |
image_array=img.copy()[None], | |
# data_source='smplx_137', | |
disable_limbs = True, | |
overwrite=True) | |
img =(data_batch['img'][0].permute(1,2,0)*255).int().cpu().numpy() | |
# pred_smpl_kp2d = project_points_new( | |
# points_3d=outputs['pred_smpl_kp3d'][:,:2].reshape(-1,137,3), | |
# pred_cam=pred_cam, | |
# focal_length=focal_length, | |
# camera_center=img_wh/2 | |
# ) | |
img_shape = data_batch['img_shape'][0] | |
# pred_kp2d = pred_kp2d.cpu().detach().numpy()*img_shape.cpu().numpy()[None,None ::-1] | |
# pred_bbox_all = [] | |
# for i in idx[0]: | |
# pred_bbox_body = (box_ops.box_cxcywh_to_xyxy(outputs['pred_boxes'][0,i]).reshape(2,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[0, ::-1]).reshape(1,4) | |
# pred_bbox_lhand = (box_ops.box_cxcywh_to_xyxy(outputs['pred_lhand_boxes'][0,i]).reshape(2,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[0, ::-1]).reshape(1,4) | |
# pred_bbox_rhand = (box_ops.box_cxcywh_to_xyxy(outputs['pred_rhand_boxes'][0,i]).reshape(2,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[0, ::-1]).reshape(1,4) | |
# pred_bbox_face = (box_ops.box_cxcywh_to_xyxy(outputs['pred_face_boxes'][0,i]).reshape(2,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[0, ::-1]).reshape(1,4) | |
# pred_bbox = np.concatenate([pred_bbox_body,pred_bbox_face,pred_bbox_rhand,pred_bbox_lhand],axis=0) | |
# pred_bbox_all.append(pred_bbox) | |
# src_body_boxes = outputs['pred_boxes'][idx] | |
# pred_bbox_all = np.concatenate(pred_bbox_all,axis=0) | |
pred_bbox_body = (box_ops.box_cxcywh_to_xyxy(outputs['pred_boxes'][idx]).reshape(-1,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[1, ::-1]).reshape(-1,4) | |
pred_bbox_lhand = (box_ops.box_cxcywh_to_xyxy(outputs['pred_lhand_boxes'][idx]).reshape(-1,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[1, ::-1]).reshape(-1,4) | |
pred_bbox_rhand = (box_ops.box_cxcywh_to_xyxy(outputs['pred_rhand_boxes'][idx]).reshape(-1,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[1, ::-1]).reshape(-1,4) | |
pred_bbox_face = (box_ops.box_cxcywh_to_xyxy(outputs['pred_face_boxes'][idx]).reshape(-1,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[1, ::-1]).reshape(-1,4) | |
pred_bbox = np.concatenate([pred_bbox_body,pred_bbox_face,pred_bbox_rhand,pred_bbox_lhand],axis=0) | |
# pred_bbox_body = (box_ops.box_cxcywh_to_xyxy(src_body_boxes).reshape(-1,2,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[0, ::-1][None,None,:]).reshape(-1,4) | |
# import ipdb;ipdb.set_trace() | |
img = mmcv.imshow_bboxes(img.copy(), pred_bbox, show=False) | |
# cv2.imwrite('test.png',img) | |
visualize_kp2d( | |
(pred_smpl_kp2d*img_wh[:, None])[None].detach().cpu().numpy(), | |
output_path='./figs/pred2d', | |
image_array=img.copy()[None], | |
data_source='smplx_137', | |
overwrite=True) | |
# visualize_kp2d( | |
# (pred_smpl_kp2d*img_wh[:, None])[None].detach().cpu().numpy(), | |
# output_path='./figs/pred2d', | |
# image_array=img.copy()[None], | |
# data_source='smplx_137', | |
# overwrite=True) | |
vis_smpl=True | |
if vis_smpl: | |
gt_output = bm( | |
betas=targets_shape.reshape(-1, 10), | |
body_pose=targets_smpl_pose[:,3:66].reshape(-1, 21*3), | |
global_orient=targets_smpl_pose[:,:3].reshape(-1, 3), | |
left_hand_pose=targets_smpl_pose[:,66:111].reshape(-1, 15*3), | |
right_hand_pose=targets_smpl_pose[:,111:156].reshape(-1, 15*3), | |
leye_pose=torch.zeros_like(targets_smpl_pose[:,:3]).reshape(-1, 3), | |
reye_pose=torch.zeros_like(targets_smpl_pose[:,:3]).reshape(-1, 3), | |
expression=torch.zeros_like(targets_shape).reshape(-1, 10), | |
jaw_pose=targets_smpl_pose[:,156:].reshape(-1, 3)) | |
verts = gt_output['vertices'] | |
for i_obj,v in enumerate(verts): | |
save_obj('./figs/gt_smpl_%d.obj'%i_obj,verts = v,faces=torch.tensor([])) | |
import ipdb;ipdb.set_trace() | |
losses = {} | |
if valid_num == 0: | |
losses['loss_smpl_body_kp2d'] = torch.as_tensor(0., device=device) + pred_smpl_kp2d.sum()*0 | |
losses['loss_smpl_lhand_kp2d'] = torch.as_tensor(0., device=device) + pred_smpl_kp2d.sum()*0 | |
losses['loss_smpl_rhand_kp2d'] = torch.as_tensor(0., device=device) + pred_smpl_kp2d.sum()*0 | |
losses['loss_smpl_face_kp2d'] = torch.as_tensor(0., device=device) + pred_smpl_kp2d.sum()*0 | |
return losses | |
body_idx = smpl_x.joint_part['body'] | |
face_idx = smpl_x.joint_part['face'] | |
lhand_idx = smpl_x.joint_part['lhand'] | |
rhand_idx = smpl_x.joint_part['rhand'] | |
loss_smpl_kp2d = F.l1_loss(pred_smpl_kp2d, | |
targets_kp2d, | |
reduction='none') | |
# If has_keypoints2d is not None, then computes the losses on the | |
# instances that have ground-truth keypoints2d. | |
# But the zero confidence keypoints will be included in mean. | |
# Otherwise, only compute the keypoints2d | |
# which have positive confidence. | |
# has_keypoints2d is None when the key has_keypoints2d | |
# is not in the datasets | |
# import pdb; pdb.set_trace() | |
valid_pos = keypoints2d_conf > 0 | |
if keypoints2d_conf[valid_pos].numel() == 0: | |
return { | |
'loss_smpl_body_kp2d': torch.as_tensor(0., device=device) + loss_smpl_kp2d.sum()*0, | |
'loss_smpl_lhand_kp2d': torch.as_tensor(0., device=device) + loss_smpl_kp2d.sum()*0, | |
'loss_smpl_rhand_kp2d': torch.as_tensor(0., device=device) + loss_smpl_kp2d.sum()*0, | |
'loss_smpl_face_kp2d': torch.as_tensor(0., device=device) + loss_smpl_kp2d.sum()*0, | |
} | |
loss_smpl_kp2d = loss_smpl_kp2d * keypoints2d_conf | |
# loss /= keypoints2d_conf[valid_pos].numel() | |
if face_hand_kpt: | |
losses['loss_smpl_body_kp2d'] = torch.sum(loss_smpl_kp2d[:, body_idx, :]) / num_boxes | |
if lhand_num_boxes>0: | |
losses['loss_smpl_lhand_kp2d'] = torch.sum(loss_smpl_kp2d[:, lhand_idx, :]) / lhand_num_boxes | |
else: | |
losses['loss_smpl_lhand_kp2d'] =torch.as_tensor(0., device=device) + loss_smpl_kp2d.sum()*0 | |
if rhand_num_boxes>0: | |
losses['loss_smpl_rhand_kp2d'] = torch.sum(loss_smpl_kp2d[:, rhand_idx, :]) / rhand_num_boxes | |
else: | |
losses['loss_smpl_rhand_kp2d'] = torch.as_tensor(0., device=device) + loss_smpl_kp2d.sum()*0 | |
if face_num_boxes>0: | |
losses['loss_smpl_face_kp2d'] = torch.sum(loss_smpl_kp2d[:, face_idx, :]) / face_num_boxes | |
else: | |
losses['loss_smpl_face_kp2d'] = torch.as_tensor(0., device=device) + loss_smpl_kp2d.sum()*0 | |
else: | |
losses['loss_smpl_body_kp2d'] = torch.sum(loss_smpl_kp2d[:, body_idx, :]) / num_boxes | |
losses['loss_smpl_lhand_kp2d'] = 0*torch.sum(loss_smpl_kp2d[:, lhand_idx, :]) / (keypoints2d_conf[:, lhand_idx].sum() + 1e-6) | |
losses['loss_smpl_rhand_kp2d'] = 0*torch.sum(loss_smpl_kp2d[:, rhand_idx, :]) / (keypoints2d_conf[:, rhand_idx].sum() + 1e-6) | |
losses['loss_smpl_face_kp2d'] = 0*torch.sum(loss_smpl_kp2d[:, face_idx, :]) / (keypoints2d_conf[:, face_idx].sum() + 1e-6) | |
return losses | |
def loss_smpl_kp2d_ba(self,
                      outputs,
                      targets,
                      indices,
                      idx,
                      num_boxes,
                      data_batch,
                      focal_length=5000.,
                      has_keypoints2d=None,
                      face_hand_kpt=False):
    """Compute the bbox-aligned ("ba") L1 loss on projected 2D keypoints.

    The matched 3D keypoint predictions are projected with
    ``project_points_new`` and divided by the image width/height. For the
    left hand, right hand and face, both the targets and the projections
    are re-expressed relative to that part's ground-truth box, and the
    projection is shifted by the mean target-minus-prediction offset taken
    over the keypoints whose confidence is 1 and whose target lies inside
    the box. An L1 loss weighted by the keypoint confidence is then summed
    per part.

    Args:
        outputs: dict of model outputs. Reads ``pred_logits`` (device
            only), ``pred_smpl_kp3d`` and ``pred_smpl_cam``.
        targets: per-image target dicts. Reads ``lhand_boxes``,
            ``rhand_boxes`` and ``face_boxes`` (converted with
            ``box_ops.box_cxcywh_to_xyxy``).
        indices: matcher output; only ``indices[0]`` is used, as a
            sequence of ``(src, tgt)`` index pairs, one per image.
        idx: index applied to the prediction tensors; ``idx[0]`` also
            selects rows of ``data_batch['img_shape']``.
        num_boxes: normaliser of the body term.
        data_batch: reads ``joint_img``, ``img_shape`` and
            ``{lhand,rhand,face}_bbox_valid``.
        focal_length: forwarded to ``project_points_new``.
        has_keypoints2d: unused; kept for interface compatibility.
        face_hand_kpt: when True the hand/face terms are normalised by the
            number of valid part boxes; when False they are forced to
            zero (while staying attached to the graph).

    Returns:
        dict with the keys ``loss_smpl_body_kp2d_ba``,
        ``loss_smpl_lhand_kp2d_ba``, ``loss_smpl_rhand_kp2d_ba`` and
        ``loss_smpl_face_kp2d_ba``.
    """
    device = outputs['pred_logits'].device
    indices = indices[0]
    loss_keys = tuple(f'loss_smpl_{part}_kp2d_ba'
                      for part in ('body', 'lhand', 'rhand', 'face'))

    def _zero_losses(anchor):
        # Zero-valued losses that still depend on ``anchor`` so that every
        # loss key keeps a path into the autograd graph.
        return {
            key: torch.as_tensor(0., device=device) + anchor.sum() * 0
            for key in loss_keys
        }

    pred_smpl_kp3d = outputs['pred_smpl_kp3d'][idx].float()
    pred_cam = outputs['pred_smpl_cam'][idx].float()

    # NOTE(review): only the first image's (src, tgt) pair is counted here,
    # exactly as in the original code -- confirm this is intended.
    valid_num = sum(len(indice) for indice in indices[0])

    targets_kp2d = torch.cat(
        [t[i] for t, (_, i) in zip(data_batch['joint_img'], indices)],
        dim=0)
    # Channels from index 2 on are used as the per-keypoint weight; it is
    # repeated so it can multiply the (x, y) loss element-wise.
    keypoints2d_conf = targets_kp2d[:, :, 2:].clone().repeat(1, 1, 2)
    targets_kp2d = targets_kp2d[:, :, :2].float()
    targets_kp2d[:, :, 0] = targets_kp2d[:, :, 0] / cfg.output_hm_shape[2]
    targets_kp2d[:, :, 1] = targets_kp2d[:, :, 1] / cfg.output_hm_shape[1]

    img_wh = torch.cat(
        [data_batch['img_shape'][i][None] for i in idx[0]],
        dim=0).flip(-1)
    pred_smpl_kp2d = project_points_new(
        points_3d=pred_smpl_kp3d,
        pred_cam=pred_cam,
        focal_length=focal_length,
        camera_center=img_wh / 2
    )
    pred_smpl_kp2d = pred_smpl_kp2d / img_wh[:, None]

    if valid_num == 0:
        return _zero_losses(pred_smpl_kp2d)

    # Gather, per part, the validity flags and the ground-truth boxes
    # (xyxy, scaled by the image width/height) of the matched instances.
    part_valid = {}
    part_bbox = {}
    for part in ('rhand', 'lhand', 'face'):
        part_valid[part] = torch.cat(
            [t[i] for t, (_, i) in
             zip(data_batch[f'{part}_bbox_valid'], indices)], dim=0)
        boxes = torch.cat(
            [t[f'{part}_boxes'][i] for t, (_, i) in zip(targets, indices)],
            dim=0)
        part_bbox[part] = (box_ops.box_cxcywh_to_xyxy(boxes).reshape(
            -1, 2, 2) * img_wh[:, None]).reshape(-1, 4)
    part_count = {part: valid.sum() for part, valid in part_valid.items()}

    img_shape = torch.cat(
        [t[None].repeat(len(i), 1)
         for t, (_, i) in zip(data_batch['img_shape'], indices)], dim=0)

    any_part_valid = not (part_valid['lhand'] + part_valid['rhand'] +
                          part_valid['face'] == 0).all()
    if any_part_valid:
        for part in ('lhand', 'rhand', 'face'):
            bbox = part_bbox[part]
            part_idx = smpl_x.joint_part[part]
            first, last = part_idx[0], part_idx[-1]

            # Offsets / scales mapping image-normalised coordinates into
            # the box frame; shared by the targets and the predictions.
            off_x = bbox[:, None, 0] / img_shape[:, 1:]
            off_y = bbox[:, None, 1] / img_shape[:, :1]
            scale_x = img_shape[:, 1:] / (
                bbox[:, None, 2] - bbox[:, None, 0] + 1e-6)
            scale_y = img_shape[:, :1] / (
                bbox[:, None, 3] - bbox[:, None, 1] + 1e-6)

            x = targets_kp2d[:, part_idx, 0]
            y = targets_kp2d[:, part_idx, 1]
            trunc = keypoints2d_conf[:, part_idx, 0].clone()
            x -= off_x
            x *= scale_x
            y -= off_y
            y *= scale_y
            # Keep only the keypoints whose target lies inside the box.
            trunc *= ((x >= 0) * (x <= 1) * (y >= 0) * (y <= 1))
            coord = torch.stack((x, y), 2)
            targets_kp2d = torch.cat(
                (targets_kp2d[:, :first, :], coord,
                 targets_kp2d[:, last + 1:, :]), 1)

            x_pred = pred_smpl_kp2d[:, part_idx, 0]
            y_pred = pred_smpl_kp2d[:, part_idx, 1]
            x_pred -= off_x
            x_pred *= scale_x
            y_pred -= off_y
            y_pred *= scale_y
            coord_pred = torch.stack((x_pred, y_pred), 2)

            part_targets = targets_kp2d[:, part_idx, :]
            trans = []
            for bid in range(coord_pred.shape[0]):
                mask = trunc[bid] == 1
                if torch.sum(mask) == 0:
                    # Allocate on the prediction's own device instead of
                    # hard-coding CUDA.
                    trans.append(coord_pred.new_zeros(2))
                else:
                    trans.append(
                        (part_targets[bid, mask, :2] -
                         coord_pred[bid, mask, :2]).mean(0))
            trans = torch.stack(trans)[:, None, :]
            # Global translation alignment of the part towards its targets.
            coord_pred = coord_pred + trans
            pred_smpl_kp2d = torch.cat(
                (pred_smpl_kp2d[:, :first, :], coord_pred,
                 pred_smpl_kp2d[:, last + 1:, :]), 1)

    loss_smpl_kp2d_ba = F.l1_loss(pred_smpl_kp2d,
                                  targets_kp2d[:, :, :2],
                                  reduction='none')

    if not (keypoints2d_conf > 0).any():
        return _zero_losses(loss_smpl_kp2d_ba)

    # TODO: the original author marked this normalisation as needing change.
    loss_smpl_kp2d_ba = loss_smpl_kp2d_ba * keypoints2d_conf

    losses = {}
    losses['loss_smpl_body_kp2d_ba'] = torch.sum(
        loss_smpl_kp2d_ba[:, smpl_x.joint_part['body'], :]) / num_boxes
    for part in ('lhand', 'rhand', 'face'):
        key = f'loss_smpl_{part}_kp2d_ba'
        part_sum = torch.sum(
            loss_smpl_kp2d_ba[:, smpl_x.joint_part[part], :])
        if not face_hand_kpt:
            # Disabled term: multiply by zero without dividing by the box
            # count, which may be 0 and would turn the loss into NaN.
            losses[key] = part_sum * 0
        elif part_count[part] > 0:
            losses[key] = part_sum / part_count[part]
        else:
            losses[key] = (torch.as_tensor(0., device=device) +
                           loss_smpl_kp2d_ba.sum() * 0)
    return losses
def loss_boxes(self, outputs, targets, indices,
               idx, num_boxes, data_batch,
               face_hand_box=False):
    """Compute the L1 and GIoU box losses for the body and, optionally,
    the hand/face boxes.

    Target boxes are read from ``targets[*]['boxes']`` (and
    ``'lhand_boxes'`` / ``'rhand_boxes'`` / ``'face_boxes'``) and are passed
    through ``box_ops.box_cxcywh_to_xyxy`` for the GIoU term. Every
    per-instance loss is multiplied by the matching ``*_bbox_valid`` flag
    from ``data_batch`` and the sums are divided by ``num_boxes``.

    Args:
        outputs: dict of model outputs; must contain ``pred_boxes`` and may
            contain ``pred_lhand_boxes``, ``pred_rhand_boxes``,
            ``pred_face_boxes``.
        targets: per-image target dicts.
        indices: matcher output; only ``indices[0]`` is used, as a
            sequence of ``(src, tgt)`` index pairs, one per image.
        idx: index applied to the prediction tensors.
        num_boxes: normaliser of every term.
        data_batch: reads ``body_bbox_valid`` and the part
            ``*_bbox_valid`` entries.
        face_hand_box: when True, part losses are added for every part
            whose prediction is present in ``outputs``.

    Returns:
        dict of losses. When the first image has no matches, all eight
        ``loss_{body,lhand,rhand,face}_{bbox,giou}`` keys are returned with
        zero values that stay attached to the graph.
    """
    indices = indices[0]
    assert 'pred_boxes' in outputs

    # NOTE(review): only the first image's (src, tgt) pair is counted here,
    # exactly as in the original code -- confirm this is intended.
    valid_num = sum(len(indice) for indice in indices[0])

    def _l1_and_giou(pred_key, target_key, valid_key):
        # Per-instance L1 and GIoU losses, masked by the validity flag.
        src = outputs[pred_key][idx]
        tgt = torch.cat(
            [t[target_key][i] for t, (_, i) in zip(targets, indices)],
            dim=0)
        conf = torch.cat(
            [t[i] for t, (_, i) in zip(data_batch[valid_key], indices)],
            dim=0)
        l1 = F.l1_loss(src, tgt, reduction='none') * conf[:, None]
        giou = 1 - torch.diag(
            box_ops.generalized_box_iou(
                box_ops.box_cxcywh_to_xyxy(src),
                box_ops.box_cxcywh_to_xyxy(tgt)))
        return l1, giou * conf

    parts = ('lhand', 'rhand', 'face')

    loss_body_bbox, loss_body_giou = _l1_and_giou(
        'pred_boxes', 'boxes', 'body_bbox_valid')
    losses = {
        'loss_body_bbox': loss_body_bbox.sum() / num_boxes,
        'loss_body_giou': loss_body_giou.sum() / num_boxes,
    }

    part_l1 = {}
    if face_hand_box:
        for part in parts:
            pred_key = f'pred_{part}_boxes'
            if pred_key not in outputs:
                continue
            l1, giou = _l1_and_giou(
                pred_key, f'{part}_boxes', f'{part}_bbox_valid')
            part_l1[part] = l1
            losses[f'loss_{part}_bbox'] = l1.sum() / num_boxes
            losses[f'loss_{part}_giou'] = giou.sum() / num_boxes

    if valid_num == 0:
        body_zero = loss_body_bbox.sum() * 0
        zero_losses = {
            'loss_body_bbox': body_zero,
            'loss_body_giou': body_zero,
        }
        for part in parts:
            # Fall back to the body term when the part loss was not
            # computed (flag off or prediction missing) instead of reading
            # an unbound local.
            zero = (part_l1[part].sum() * 0
                    if part in part_l1 else body_zero)
            zero_losses[f'loss_{part}_bbox'] = zero
            zero_losses[f'loss_{part}_giou'] = zero
        return zero_losses
    return losses
def loss_dn_boxes(self, outputs, targets, indices, idx, num_boxes,
                  data_batch):
    """Compute the box losses on the denoising (dn) queries.

    Reads ``dn_bbox_pred`` and ``dn_bbox_input`` from ``outputs`` (the
    original comment describes both as ``bs, num_dn, 4``) and delegates to
    ``self.tgt_loss_boxes`` together with ``outputs['num_tgt']``.

    Zero-valued ``dn_loss_bbox`` / ``dn_loss_giou`` are returned when
    ``outputs`` has no ``num_tgt`` entry, or when the first image has no
    matches.
    """
    indices = indices[0]

    # Check for the dn outputs *before* reading them; otherwise a missing
    # key raises KeyError and this fallback is unreachable.
    if 'num_tgt' not in outputs:
        zero = outputs['pred_logits'].sum() * 0
        return {
            'dn_loss_bbox': zero,
            'dn_loss_giou': zero,
        }

    num_tgt = outputs['num_tgt']
    src_boxes = outputs['dn_bbox_pred']
    tgt_boxes = outputs['dn_bbox_input']

    # NOTE(review): only the first image's (src, tgt) pair is counted here,
    # exactly as in the original code -- confirm this is intended.
    valid_num = sum(len(indice) for indice in indices[0])
    if valid_num == 0:
        return {
            'dn_loss_bbox': src_boxes.sum() * 0,
            'dn_loss_giou': src_boxes.sum() * 0,
        }

    return self.tgt_loss_boxes(src_boxes, tgt_boxes, num_tgt)
def loss_dn_labels(self, outputs, targets, indices, idx, num_boxes,
                   data_batch):
    """Compute the classification loss on the denoising (dn) queries.

    Delegates to ``self.tgt_loss_labels`` with ``dn_class_pred``,
    ``dn_class_input`` and ``num_tgt`` taken from ``outputs``. A zero
    ``dn_loss_ce`` derived from ``pred_logits`` is returned instead when
    ``num_tgt`` is absent or the first image's index pair is empty.
    """
    matched = indices[0]
    if 'num_tgt' in outputs:
        # Total length of the (src, tgt) pair of the first image only.
        first_pair_total = sum(len(part) for part in matched[0])
        if first_pair_total != 0:
            return self.tgt_loss_labels(outputs['dn_class_pred'],
                                        outputs['dn_class_input'],
                                        outputs['num_tgt'])
    return {'dn_loss_ce': outputs['pred_logits'].sum() * 0}
def loss_matching_cost(self, outputs, targets, indices, idx, num_boxes,
                       data_batch):
    """Expose the second element of ``indices`` as loss entries.

    ``indices[1]`` is treated as a mapping; each of its items is returned
    under the same key prefixed with ``set_``. All other arguments are
    ignored.
    """
    losses = {}
    for name, value in indices[1].items():
        losses['set_{}'.format(name)] = value
    return losses
def _get_src_permutation_idx(self, indices):
    """Flatten per-image ``(src, tgt)`` pairs into prediction indices.

    Returns ``(batch_idx, src_idx)``: the concatenated ``src`` tensors and,
    for each of their entries, the position of the pair it came from.
    """
    sources = [src for src, _ in indices]
    batch_ids = [
        torch.full_like(src, image_id)
        for image_id, src in enumerate(sources)
    ]
    return torch.cat(batch_ids), torch.cat(sources)
def _get_tgt_permutation_idx(self, indices):
    """Flatten per-image ``(src, tgt)`` pairs into target indices.

    Returns ``(batch_idx, tgt_idx)``: the concatenated ``tgt`` tensors and,
    for each of their entries, the position of the pair it came from.
    """
    destinations = [tgt for _, tgt in indices]
    batch_ids = [
        torch.full_like(tgt, image_id)
        for image_id, tgt in enumerate(destinations)
    ]
    return torch.cat(batch_ids), torch.cat(destinations)
def get_loss(self, loss, outputs, targets, data_batch, indices, num_boxes,
             **kwargs):
    """Dispatch to the loss method registered under the name ``loss``.

    The source permutation index is built from ``indices[0]`` and handed
    to the selected method together with the remaining arguments; extra
    keyword arguments are forwarded untouched. An unknown ``loss`` name
    trips an assertion.
    """
    handlers = dict(
        smpl_pose=self.loss_smpl_pose,
        smpl_beta=self.loss_smpl_beta,
        smpl_expr=self.loss_smpl_expr,
        smpl_kp2d=self.loss_smpl_kp2d,
        smpl_kp2d_ba=self.loss_smpl_kp2d_ba,
        smpl_kp3d_ra=self.loss_smpl_kp3d_ra,
        smpl_kp3d=self.loss_smpl_kp3d,
        labels=self.loss_labels,
        cardinality=self.loss_cardinality,
        boxes=self.loss_boxes,
        dn_label=self.loss_dn_labels,
        dn_bbox=self.loss_dn_boxes,
        matching=self.loss_matching_cost,
    )
    idx = self._get_src_permutation_idx(indices[0])
    assert loss in handlers, f'do you really want to compute {loss} loss?'
    handler = handlers[loss]
    return handler(outputs, targets, indices, idx, num_boxes, data_batch,
                   **kwargs)
def prep_for_dn2(self, mask_dict):
    """Unpack the denoising entries of ``mask_dict`` into a tuple.

    Returns ``(known_labels, known_bboxs, output_known_class,
    output_known_coord, pad_size)``, each read from the key of the same
    name in ``mask_dict``.
    """
    bboxs, labels, coord, class_out, pad_size = (
        mask_dict[key]
        for key in ('known_bboxs', 'known_labels', 'output_known_coord',
                    'output_known_class', 'pad_size'))
    return labels, bboxs, class_out, coord, pad_size
## SMPL losses | |
def forward(self, outputs, targets, data_batch, return_indices=False):
    """ This performs the loss computation.

    Parameters:
        outputs: dict of tensors, see the output specification of the model for the format.
            The optional entries 'aux_outputs', 'interm_outputs' and
            'query_expand' each contribute additional, suffixed loss terms.
        targets: list of dicts, such that len(targets) == batch_size.
            The expected keys in each dict depends on the losses applied, see each loss' doc
        data_batch: forwarded unchanged to every loss via ``get_loss``.
        return_indices: used for vis. if True, the layer0-5 indices will be returned as well.

    Returns:
        The dict of losses. Keys of auxiliary layers get the suffix
        ``_<layer idx>``, those of 'interm_outputs' get ``_interm`` and
        those of 'query_expand' get ``_query_expand``. When
        ``return_indices`` is True a tuple ``(losses, indices_list)`` is
        returned instead; the list holds the matcher results in the order
        they were computed for the extra heads, with the final-layer
        result appended last.
    """
    final_outputs = {
        name: value
        for name, value in outputs.items() if name != 'aux_outputs'
    }
    device = next(iter(outputs.values())).device

    # Average number of target boxes across all nodes, for normalization.
    total_boxes = sum(len(t['boxes']) for t in targets)
    num_boxes = torch.as_tensor([total_boxes],
                                dtype=torch.float,
                                device=device)
    if is_dist_avail_and_initialized():
        torch.distributed.all_reduce(num_boxes)
    num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

    # ---- final decoder layer ----
    indices = self.matcher(final_outputs, targets)
    if return_indices:
        final_indices = indices
        indices_list = []

    losses = {}
    smpl_loss = ('smpl_pose', 'smpl_beta', 'smpl_expr', 'smpl_kp2d',
                 'smpl_kp2d_ba', 'smpl_kp3d', 'smpl_kp3d_ra')
    for loss in self.losses:
        if loss == 'keypoints' or loss in smpl_loss:
            kwargs = {'face_hand_kpt': True}
        elif loss == 'boxes':
            kwargs = {'face_hand_box': True}
        else:
            kwargs = {}
        losses.update(
            self.get_loss(loss, outputs, targets, data_batch, indices,
                          num_boxes, **kwargs))

    # ---- auxiliary (intermediate decoder layer) outputs ----
    if 'aux_outputs' in outputs:
        for idx, aux_outputs in enumerate(outputs['aux_outputs']):
            indices = self.matcher(aux_outputs, targets)
            if return_indices:
                indices_list.append(indices)
            suffix = f'_{idx}'
            for loss in self.losses:
                if loss == 'masks':
                    continue
                kwargs = {}
                if loss == 'boxes':
                    # Hand/face boxes only count once past the box layers.
                    kwargs['face_hand_box'] = (
                        idx >= self.num_box_decoder_layers)
                elif loss == 'keypoints' or loss in smpl_loss:
                    # No keypoint/SMPL supervision on the box-only layers.
                    if idx < self.num_box_decoder_layers:
                        continue
                    kwargs['face_hand_kpt'] = (
                        idx >= self.num_hand_face_decoder_layers)
                elif loss == 'labels':
                    # Logging is enabled only for the last layer
                    kwargs['log'] = False
                l_dict = self.get_loss(loss, aux_outputs, targets,
                                       data_batch, indices, num_boxes,
                                       **kwargs)
                losses.update({k + suffix: v for k, v in l_dict.items()})

    # ---- extra heads: (outputs key, loss-name suffix, skipped losses) ----
    extra_heads = (
        ('interm_outputs', '_interm',
         ('dn_bbox', 'dn_label', 'keypoints') + smpl_loss),
        ('query_expand', '_query_expand', ('dn_bbox', 'dn_label')),
    )
    for head_key, suffix, skipped in extra_heads:
        if head_key not in outputs:
            continue
        head_outputs = outputs[head_key]
        indices = self.matcher(head_outputs, targets)
        if return_indices:
            indices_list.append(indices)
        for loss in self.losses:
            if loss in skipped:
                continue
            kwargs = {'log': False} if loss == 'labels' else {}
            l_dict = self.get_loss(loss, head_outputs, targets, data_batch,
                                   indices, num_boxes, **kwargs)
            losses.update({k + suffix: v for k, v in l_dict.items()})

    if return_indices:
        indices_list.append(final_indices)
        return losses, indices_list
    return losses
def tgt_loss_boxes(
    self,
    src_boxes,
    tgt_boxes,
    num_tgt,
):
    """Compute the L1 and GIoU losses for denoising boxes.

    Input:
        - src_boxes: bs, num_dn, 4
        - tgt_boxes: bs, num_dn, 4
        - num_tgt: divisor applied to both summed losses

    Returns:
        dict with ``'dn_loss_bbox'`` (summed element-wise L1 distance
        divided by ``num_tgt``) and ``'dn_loss_giou'`` (summed
        ``1 - GIoU`` of corresponding box pairs divided by ``num_tgt``).
    """
    l1_per_coord = F.l1_loss(src_boxes, tgt_boxes, reduction='none')

    # Both box sets go through box_cxcywh_to_xyxy after the first two
    # dimensions are flattened; the diagonal of the pairwise GIoU matrix
    # picks out the i-th source against the i-th target.
    src_xyxy = box_ops.box_cxcywh_to_xyxy(src_boxes.flatten(0, 1))
    tgt_xyxy = box_ops.box_cxcywh_to_xyxy(tgt_boxes.flatten(0, 1))
    giou_cost = 1 - torch.diag(
        box_ops.generalized_box_iou(src_xyxy, tgt_xyxy))

    return {
        'dn_loss_bbox': l1_per_coord.sum() / num_tgt,
        'dn_loss_giou': giou_cost.sum() / num_tgt,
    }
def tgt_loss_labels(self,
                    src_logits: Tensor,
                    tgt_labels: Tensor,
                    num_tgt: int,
                    log: bool = True):
    """Compute the focal classification loss for denoising queries.

    Input:
        - src_logits: bs, num_dn, num_classes
        - tgt_labels: bs, num_dn
        - num_tgt: normaliser handed to ``sigmoid_focal_loss``
        - log: accepted but not used by this method

    Returns:
        dict with the single key ``'dn_loss_ce'``.
    """
    batch, num_dn, num_cls = (src_logits.shape[0], src_logits.shape[1],
                              src_logits.shape[2])

    # One extra class column is allocated so that a label equal to
    # num_classes can be scattered; dropping that column afterwards leaves
    # such entries as all-zero rows.
    onehot = torch.zeros([batch, num_dn, num_cls + 1],
                         dtype=src_logits.dtype,
                         layout=src_logits.layout,
                         device=src_logits.device)
    onehot.scatter_(2, tgt_labels.unsqueeze(-1), 1)
    onehot = onehot[:, :, :-1]

    focal = sigmoid_focal_loss(src_logits,
                               onehot,
                               num_tgt,
                               alpha=self.focal_alpha,
                               gamma=2)
    # The focal loss value is rescaled by the number of queries (dim 1).
    return {'dn_loss_ce': focal * num_dn}