# AiOS / models / aios / aios_smplx copy.py
# ttxskk
# update
# d7e58f0
import copy
import pdb
import os
import math
from typing import List
import torch
import torch.nn.functional as F
from torch import nn
from torch import Tensor
from util import box_ops
from util.keypoint_ops import keypoint_xyzxyz_to_xyxyzz
from util.misc import (NestedTensor, nested_tensor_from_tensor_list, accuracy,
get_world_size, interpolate,
is_dist_avail_and_initialized, inverse_sigmoid)
from .backbones import build_backbone
from .matcher import build_matcher
from .transformer import build_transformer
from .utils import PoseProjector, sigmoid_focal_loss, MLP
from .postprocesses import PostProcess_SMPLX, PostProcess_aios
from .postprocesses import PostProcess_SMPLX_Multi as PostProcess_SMPLX
from .postprocesses import PostProcess_SMPLX_Multi_Box
from .postprocesses import PostProcess_SMPLX_Multi_Infer, PostProcess_SMPLX_Multi_Infer_Box
from .criterion_smplx import SetCriterion, SetCriterion_Box
from ..registry import MODULE_BUILD_FUNCS
from detrsmpl.core.conventions.keypoints_mapping import convert_kps
from detrsmpl.models.body_models.builder import build_body_model
from util.human_models import smpl_x
from detrsmpl.core.conventions.keypoints_mapping import get_keypoint_idxs_by_part
import numpy as np
import random
from detrsmpl.utils.geometry import (rot6d_to_rotmat)
from detrsmpl.utils.transforms import rotmat_to_aa
import cv2
from config.config import cfg
class AiOSSMPLX(nn.Module):
def __init__(
        self,
        backbone,
        transformer,
        num_classes,
        num_queries,
        aux_loss=False,
        iter_update=True,
        query_dim=4,
        random_refpoints_xy=False,
        fix_refpoints_hw=-1,
        num_feature_levels=1,
        nheads=8,
        two_stage_type='no',
        dec_pred_class_embed_share=False,
        dec_pred_bbox_embed_share=False,
        dec_pred_pose_embed_share=False,
        two_stage_class_embed_share=True,
        two_stage_bbox_embed_share=True,
        dn_number=100,
        dn_box_noise_scale=0.4,
        dn_label_noise_ratio=0.5,
        dn_batch_gt_fuse=False,
        dn_labelbook_size=100,
        dn_attn_mask_type_list=['group2group'],
        cls_no_bias=False,
        num_group=100,
        num_body_points=17,
        num_hand_points=10,
        num_face_points=10,
        num_box_decoder_layers=2,
        num_hand_face_decoder_layers=4,
        body_model=dict(
            type='smplx',
            keypoint_src='smplx',
            num_expression_coeffs=10,
            keypoint_dst='smplx_137',
            model_path='data/body_models/smplx',
            use_pca=False,
            use_face_contour=True),
        train=True,
        inference=False,
        focal_length=[5000., 5000.],
        camera_3d_size=2.5
):
    """Store the configuration and build every prediction head.

    The constructor records its arguments on ``self``, builds a frozen
    body model, creates one input projection per feature level and then
    creates the per-decoder-layer heads (class, boxes, 2D keypoints and
    SMPL/SMPL-X parameters). The box/keypoint/class heads are also
    attached to ``transformer.decoder``.

    Args:
        backbone: only ``backbone.num_channels`` is read here; the object
            is stored as ``self.backbone``.
        transformer: ``d_model``, ``num_decoder_layers`` and ``decoder``
            are read; heads and layer counts are attached to it.
        num_classes: output width of the classification head.
        num_queries: stored as ``self.num_queries``.
        aux_loss: stored as ``self.aux_loss``.
        iter_update: must be truthy (asserted).
        query_dim: must equal 4 (asserted).
        random_refpoints_xy, fix_refpoints_hw: stored only.
        num_feature_levels: number of input projections to create.
        nheads: stored as ``self.nheads``.
        two_stage_type: ``'no'`` or ``'standard'`` (asserted).
        dec_pred_class_embed_share: share one class head across layers.
        dec_pred_bbox_embed_share: share the body-box head and the
            SMPL/SMPL-X parameter heads across layers.
        dec_pred_pose_embed_share: share the part-box and keypoint heads
            across layers.
        two_stage_class_embed_share, two_stage_bbox_embed_share: when
            two-stage is enabled, reuse (instead of copy) the prototype
            heads for the encoder output.
        dn_number, dn_box_noise_scale, dn_label_noise_ratio,
        dn_batch_gt_fuse, dn_labelbook_size, dn_attn_mask_type_list:
            denoising settings, stored on ``self``;
            ``dn_batch_gt_fuse`` must be falsy (asserted).
        cls_no_bias: create the class head without a bias term.
        num_group: stored as ``self.num_group``.
        num_body_points, num_hand_points, num_face_points: keypoint
            counts; they size several head inputs.
        num_box_decoder_layers, num_hand_face_decoder_layers: subtracted
            from ``transformer.num_decoder_layers`` to size head lists.
        body_model: mapping passed to ``build_body_model``; its
            ``'type'`` selects the joint count (SMPL: 24, SMPLX: 22).
        train: selects ``self.smpl_convention`` (``'smplx'`` when true,
            ``'h36m'`` otherwise).
        inference: when true, ``body_model`` is replaced by a fixed
            SMPL-X configuration before it is built.
        focal_length, camera_3d_size: stored for ``get_camera_trans``.

    Raises:
        ValueError: if ``body_model['type']`` is neither SMPL nor SMPLX.
        AssertionError: if one of the configuration asserts fails.
    """
    super().__init__()

    def zero_last_layer(mlp):
        # Start the final layer of the head from all-zero weights/bias.
        nn.init.constant_(mlp.layers[-1].weight.data, 0)
        nn.init.constant_(mlp.layers[-1].bias.data, 0)
        return mlp

    def replicate(layer, count, share):
        # `count` entries: the same module when shared, otherwise
        # independent deep copies of the prototype.
        if share:
            return [layer for _ in range(count)]
        return [copy.deepcopy(layer) for _ in range(count)]

    def replicate_two_phase(early, late, count, share):
        # The first two entries use `early`, all later ones use `late`.
        chosen = [early if i < 2 else late for i in range(count)]
        if share:
            return chosen
        return [copy.deepcopy(layer) for layer in chosen]

    self.num_queries = num_queries
    self.transformer = transformer
    self.num_classes = num_classes
    self.hidden_dim = hidden_dim = transformer.d_model
    self.num_feature_levels = num_feature_levels
    self.nheads = nheads
    self.label_enc = nn.Embedding(dn_labelbook_size + 1, hidden_dim)
    self.num_body_points = num_body_points
    self.num_hand_points = num_hand_points
    self.num_face_points = num_face_points
    self.num_whole_body_points = (
        num_body_points + 2 * num_hand_points + num_face_points)
    self.num_box_decoder_layers = num_box_decoder_layers
    self.num_hand_face_decoder_layers = num_hand_face_decoder_layers
    self.focal_length = focal_length
    self.camera_3d_size = camera_3d_size
    self.inference = inference
    self.smpl_convention = 'smplx' if train else 'h36m'

    # setting query dim
    self.query_dim = query_dim
    assert query_dim == 4
    self.random_refpoints_xy = random_refpoints_xy
    self.fix_refpoints_hw = fix_refpoints_hw

    # for dn training
    self.dn_number = dn_number
    self.dn_box_noise_scale = dn_box_noise_scale
    self.dn_label_noise_ratio = dn_label_noise_ratio
    self.dn_batch_gt_fuse = dn_batch_gt_fuse
    self.dn_labelbook_size = dn_labelbook_size
    self.dn_attn_mask_type_list = dn_attn_mask_type_list
    assert all([
        i in ['match2dn', 'dn2dn', 'group2group']
        for i in dn_attn_mask_type_list
    ])
    assert not dn_batch_gt_fuse

    # build human body
    if inference:
        body_model = dict(
            type='smplx',
            keypoint_src='smplx',
            num_expression_coeffs=10,
            num_betas=10,
            keypoint_dst='smplx',
            model_path='data/body_models/smplx',
            use_pca=False,
            use_face_contour=True)
    # NOTE(review): the listing this was recovered from had lost its
    # indentation; the body model is assumed to be built in every mode,
    # with only the config override depending on `inference` - confirm.
    self.body_model = build_body_model(body_model)
    for param in self.body_model.parameters():
        param.requires_grad = False

    # prepare input projection layers
    if num_feature_levels > 1:
        input_proj_list = []
        for in_channels in backbone.num_channels:
            input_proj_list.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
                    nn.GroupNorm(32, hidden_dim),
                ))
        # Extra levels are produced by stride-2 convolutions; the first
        # one consumes the last backbone channel count, the rest consume
        # hidden_dim.
        for _ in range(num_feature_levels - len(backbone.num_channels)):
            input_proj_list.append(
                nn.Sequential(
                    nn.Conv2d(in_channels,
                              hidden_dim,
                              kernel_size=3,
                              stride=2,
                              padding=1),
                    nn.GroupNorm(32, hidden_dim),
                ))
            in_channels = hidden_dim
        self.input_proj = nn.ModuleList(input_proj_list)
    else:
        assert two_stage_type == 'no', 'two_stage_type should be no if num_feature_levels=1 !!!'
        self.input_proj = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(backbone.num_channels[-1],
                          hidden_dim,
                          kernel_size=1),
                nn.GroupNorm(32, hidden_dim),
            )
        ])
    self.backbone = backbone
    self.aux_loss = aux_loss
    self.box_pred_damping = None
    self.iter_update = iter_update
    assert iter_update, 'Why not iter_update?'

    # prepare pred layers
    self.dec_pred_class_embed_share = dec_pred_class_embed_share
    self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share

    num_dec = transformer.num_decoder_layers
    num_after_box = num_dec - num_box_decoder_layers
    num_after_hand_face = num_dec - num_hand_face_decoder_layers

    # class head
    _class_embed = nn.Linear(hidden_dim,
                             num_classes,
                             bias=(not cls_no_bias))
    if not cls_no_bias:
        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        _class_embed.bias.data = torch.ones(self.num_classes) * bias_value
    class_embed_layerlist = replicate(
        _class_embed, num_dec, dec_pred_class_embed_share)

    ###################################################################
    # body bbox + l/r hand box + face box
    ###################################################################
    _bbox_embed = zero_last_layer(MLP(hidden_dim, hidden_dim, 4, 3))
    self.num_group = num_group
    box_body_embed_layerlist = replicate(
        _bbox_embed, num_dec, dec_pred_bbox_embed_share)

    # hand box: 2-dim head plus a separate 2-dim "hw" head
    _bbox_hand_embed = zero_last_layer(MLP(hidden_dim, hidden_dim, 2, 3))
    _bbox_hand_hw_embed = zero_last_layer(
        MLP(hidden_dim, hidden_dim, 2, 3))
    box_hand_embed_layerlist = replicate(
        _bbox_hand_embed, num_after_box + 1, dec_pred_pose_embed_share)
    box_hand_hw_embed_layerlist = replicate(
        _bbox_hand_hw_embed, num_after_box, dec_pred_pose_embed_share)

    # face box
    _bbox_face_embed = zero_last_layer(MLP(hidden_dim, hidden_dim, 2, 3))
    _bbox_face_hw_embed = zero_last_layer(
        MLP(hidden_dim, hidden_dim, 2, 3))
    box_face_embed_layerlist = replicate(
        _bbox_face_embed, num_after_box + 1, dec_pred_pose_embed_share)
    box_face_hw_embed_layerlist = replicate(
        _bbox_face_hw_embed, num_after_box, dec_pred_pose_embed_share)

    ###################################################################
    # body kp2d + l/r hand kp2d + face kp2d
    ###################################################################
    # body keypoints: one extra entry only when there are 17 body points
    _pose_embed = zero_last_layer(MLP(hidden_dim, hidden_dim, 2, 3))
    num_pose_layers = (num_after_box + 1
                       if num_body_points == 17 else num_after_box)
    pose_embed_layerlist = replicate(
        _pose_embed, num_pose_layers, dec_pred_pose_embed_share)
    # The "hw" keypoint heads are always shared and are not zero-initialised.
    _pose_hw_embed = MLP(hidden_dim, hidden_dim, 2, 3)
    pose_hw_embed_layerlist = replicate(
        _pose_hw_embed, num_after_box, True)

    # hand keypoints
    _pose_hand_embed = zero_last_layer(MLP(hidden_dim, hidden_dim, 2, 3))
    pose_hand_embed_layerlist = replicate(
        _pose_hand_embed, num_after_hand_face + 1,
        dec_pred_pose_embed_share)
    _pose_hand_hw_embed = MLP(hidden_dim, hidden_dim, 2, 3)
    pose_hand_hw_embed_layerlist = replicate(
        _pose_hand_hw_embed, num_after_hand_face, True)

    # face keypoints
    _pose_face_embed = zero_last_layer(MLP(hidden_dim, hidden_dim, 2, 3))
    pose_face_embed_layerlist = replicate(
        _pose_face_embed, num_after_hand_face + 1,
        dec_pred_pose_embed_share)
    _pose_face_hw_embed = MLP(hidden_dim, hidden_dim, 2, 3)
    pose_face_hw_embed_layerlist = replicate(
        _pose_face_hw_embed, num_after_hand_face, True)

    ###################################################################
    # smpl pose + betas + cam
    ###################################################################
    body_model_type = body_model['type']
    if body_model_type.upper() == 'SMPL':
        self.body_model_joint_num = 24
    elif body_model_type.upper() == 'SMPLX':
        self.body_model_joint_num = 22
    else:
        # Index with ['type'] so a plain dict produces the intended
        # ValueError rather than an AttributeError.
        raise ValueError(
            f'Only supports SMPL or SMPLX, but get {body_model_type}')

    body_feat_dim = hidden_dim * (self.num_body_points + 4)
    _smpl_pose_embed = zero_last_layer(
        MLP(body_feat_dim, hidden_dim, self.body_model_joint_num * 6, 3))
    smpl_pose_embed_layerlist = replicate(
        _smpl_pose_embed, num_after_box, dec_pred_bbox_embed_share)

    _smpl_beta_embed = zero_last_layer(
        MLP(body_feat_dim, hidden_dim, 10, 3))
    smpl_beta_embed_layerlist = replicate(
        _smpl_beta_embed, num_after_box, dec_pred_bbox_embed_share)

    _cam_embed = zero_last_layer(MLP(body_feat_dim, hidden_dim, 3, 3))
    cam_embed_layerlist = replicate(
        _cam_embed, num_after_box, dec_pred_bbox_embed_share)

    ###################################################################
    # smplx hand pose + expression + jaw
    # Each has a narrow-input head for the first two list entries and a
    # wide-input head for the remaining ones.
    ###################################################################
    _smplx_hand_pose_embed_layer_2_3 = zero_last_layer(
        MLP(hidden_dim, hidden_dim, 15 * 6, 3))
    _smplx_hand_pose_embed_layer_4_5 = zero_last_layer(
        MLP(hidden_dim * (self.num_hand_points + 3), hidden_dim, 15 * 6, 3))
    smplx_hand_pose_embed_layerlist = replicate_two_phase(
        _smplx_hand_pose_embed_layer_2_3,
        _smplx_hand_pose_embed_layer_4_5,
        num_after_box, dec_pred_bbox_embed_share)

    _smplx_expression_embed_layer_2_3 = zero_last_layer(
        MLP(hidden_dim, hidden_dim, 10, 3))
    # NOTE(review): the input width uses num_hand_points although this is
    # the expression head (the jaw head below uses num_face_points) -
    # kept as in the original, confirm it is intended.
    _smplx_expression_embed_layer_4_5 = zero_last_layer(
        MLP(hidden_dim * (self.num_hand_points + 2), hidden_dim, 10, 3))
    smplx_expression_embed_layerlist = replicate_two_phase(
        _smplx_expression_embed_layer_2_3,
        _smplx_expression_embed_layer_4_5,
        num_after_box, dec_pred_bbox_embed_share)

    _smplx_jaw_embed_2_3 = zero_last_layer(
        MLP(hidden_dim * 1, hidden_dim, 6, 3))
    _smplx_jaw_embed_4_5 = zero_last_layer(
        MLP(hidden_dim * (self.num_face_points + 2), hidden_dim, 6, 3))
    smplx_jaw_embed_layerlist = replicate_two_phase(
        _smplx_jaw_embed_2_3, _smplx_jaw_embed_4_5,
        num_after_box, dec_pred_bbox_embed_share)

    ###################################################################
    # register heads (registration order is kept as in the original)
    ###################################################################
    self.bbox_embed = nn.ModuleList(box_body_embed_layerlist)
    self.class_embed = nn.ModuleList(class_embed_layerlist)
    self.pose_embed = nn.ModuleList(pose_embed_layerlist)
    self.pose_hw_embed = nn.ModuleList(pose_hw_embed_layerlist)
    self.transformer.decoder.bbox_embed = self.bbox_embed
    self.transformer.decoder.pose_embed = self.pose_embed
    self.transformer.decoder.pose_hw_embed = self.pose_hw_embed
    self.transformer.decoder.class_embed = self.class_embed

    # smpl
    self.smpl_pose_embed = nn.ModuleList(smpl_pose_embed_layerlist)
    self.smpl_beta_embed = nn.ModuleList(smpl_beta_embed_layerlist)
    self.smpl_cam_embed = nn.ModuleList(cam_embed_layerlist)

    # hand boxes / keypoints
    self.bbox_hand_embed = nn.ModuleList(box_hand_embed_layerlist)
    self.bbox_hand_hw_embed = nn.ModuleList(box_hand_hw_embed_layerlist)
    self.pose_hand_embed = nn.ModuleList(pose_hand_embed_layerlist)
    self.pose_hand_hw_embed = nn.ModuleList(pose_hand_hw_embed_layerlist)
    self.transformer.decoder.bbox_hand_embed = self.bbox_hand_embed
    self.transformer.decoder.bbox_hand_hw_embed = self.bbox_hand_hw_embed
    self.transformer.decoder.pose_hand_embed = self.pose_hand_embed
    self.transformer.decoder.pose_hand_hw_embed = self.pose_hand_hw_embed

    # face boxes / keypoints
    self.bbox_face_embed = nn.ModuleList(box_face_embed_layerlist)
    self.bbox_face_hw_embed = nn.ModuleList(box_face_hw_embed_layerlist)
    self.pose_face_embed = nn.ModuleList(pose_face_embed_layerlist)
    self.pose_face_hw_embed = nn.ModuleList(pose_face_hw_embed_layerlist)
    self.transformer.decoder.bbox_face_embed = self.bbox_face_embed
    self.transformer.decoder.bbox_face_hw_embed = self.bbox_face_hw_embed
    self.transformer.decoder.pose_face_embed = self.pose_face_embed
    self.transformer.decoder.pose_face_hw_embed = self.pose_face_hw_embed

    # smplx
    self.smpl_hand_pose_embed = nn.ModuleList(
        smplx_hand_pose_embed_layerlist)
    self.smpl_expr_embed = nn.ModuleList(smplx_expression_embed_layerlist)
    self.smpl_jaw_embed = nn.ModuleList(smplx_jaw_embed_layerlist)

    # layer / keypoint counts the decoder reads
    self.transformer.decoder.num_hand_face_decoder_layers = num_hand_face_decoder_layers
    self.transformer.decoder.num_box_decoder_layers = num_box_decoder_layers
    self.transformer.decoder.num_body_points = num_body_points
    self.transformer.decoder.num_hand_points = num_hand_points
    self.transformer.decoder.num_face_points = num_face_points

    # two stage
    self.two_stage_type = two_stage_type
    assert two_stage_type in [
        'no', 'standard'
    ], 'unknown param {} of two_stage_type'.format(two_stage_type)
    if two_stage_type != 'no':
        if two_stage_bbox_embed_share:
            assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
            self.transformer.enc_out_bbox_embed = _bbox_embed
        else:
            self.transformer.enc_out_bbox_embed = copy.deepcopy(
                _bbox_embed)
        if two_stage_class_embed_share:
            assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
            self.transformer.enc_out_class_embed = _class_embed
        else:
            self.transformer.enc_out_class_embed = copy.deepcopy(
                _class_embed)
    # NOTE(review): the flattened listing does not show whether this
    # assignment was nested under the two-stage branch; it is set in
    # every case here - confirm.
    self.refpoint_embed = None
    self._reset_parameters()
def get_camera_trans(self, cam_param, input_body_shape):
    """Convert predicted camera parameters into a camera translation.

    Args:
        cam_param: 2-D tensor; columns 0-1 are used unchanged as the x/y
            translation and column 2 is squashed with a sigmoid.
        input_body_shape: indexable with two numeric entries whose product
            normalises the depth scale.

    Returns:
        Tensor with three columns: the x/y translation followed by the
        depth ``k * sigmoid(cam_param[:, 2])``, where
        ``k = sqrt(fx * fy * camera_3d_size**2 / (shape[0] * shape[1]))``.
    """
    t_xy = cam_param[:, :2]
    # sigmoid keeps the depth factor positive
    gamma = torch.sigmoid(cam_param[:, 2])
    depth_scale = math.sqrt(
        self.focal_length[0] * self.focal_length[1]
        * self.camera_3d_size * self.camera_3d_size
        / (input_body_shape[0] * input_body_shape[1]))
    # Create the constant on the input's device instead of forcing CUDA,
    # so the method also works on CPU and on non-default GPUs.
    k_value = torch.tensor([depth_scale],
                           dtype=torch.float32,
                           device=cam_param.device)
    t_z = k_value * gamma
    cam_trans = torch.cat((t_xy, t_z[:, None]), 1)
    return cam_trans
def _reset_parameters(self):
# init input_proj
for proj in self.input_proj:
nn.init.xavier_uniform_(proj[0].weight, gain=1)
nn.init.constant_(proj[0].bias, 0)
def prepare_for_dn2(self, targets):
    """Build denoising queries and the boolean attention masks.

    Args:
        targets: list with one mapping per image. ``'boxes'`` is read in
            both modes; ``'labels'`` and ``'keypoints'`` are read in
            training mode only.

    Returns:
        A 6-tuple ``(input_query_label, input_query_bbox, attn_mask,
        attn_mask2, attn_mask3, mask_dict)``. When ``self.training`` is
        false only ``attn_mask2`` and ``attn_mask3`` are built and the
        other four entries are ``None``. All masks are flattened over
        their first two dimensions (batch, heads) before being returned.

    NOTE(review): presumably ``True`` marks a query/key pair that must
    not attend (torch attention-mask convention) - confirm in the decoder.
    NOTE(review): indentation was reconstructed from a flattened listing;
    sequences of consecutive ``if`` statements are treated as independent.
    """
    if not self.training:
        # Evaluation: no denoising queries, only the group masks.
        device = targets[0]['boxes'].device
        bs = len(targets)
        # attn_mask2: each group holds num_body_points + 4 tokens.
        num_points = self.num_body_points + 4
        attn_mask2 = torch.zeros(
            bs,
            self.nheads,
            self.num_group * num_points,
            self.num_group * num_points,
            device=device,
            dtype=torch.bool)
        group_bbox_kpt = num_points
        group_nobbox_kpt = self.num_body_points
        # Global indices of the tokens at in-group offsets 0 and
        # num_body_points+1..+3, collected over every group.
        kpt_index = [
            x for x in range(self.num_group * num_points)
            if x % num_points in [
                0,
                self.num_body_points+1,
                self.num_body_points+2,
                self.num_body_points+3
            ]
        ]
        # Every row starts with all columns outside its own group set.
        for matchj in range(self.num_group * num_points):
            sj = (matchj // group_bbox_kpt) * group_bbox_kpt
            ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt
            if sj > 0:
                attn_mask2[:, :, matchj, :sj] = True
            if ej < self.num_group * num_points:
                attn_mask2[:, :, matchj, ej:] = True
        # Rows at the special offsets are cleared for the special columns
        # of all groups.
        for match_x in range(self.num_group * num_points):
            if match_x % group_bbox_kpt in [0,
                                            self.num_body_points+1,
                                            self.num_body_points+2,
                                            self.num_body_points+3]:
                attn_mask2[:,:,match_x,kpt_index]=False
        # attn_mask3: same construction with num_whole_body_points + 4
        # tokens per group and special offsets spaced by the hand count.
        num_points = self.num_whole_body_points + 4
        attn_mask3 = torch.zeros(
            bs,
            self.nheads,
            self.num_group * (num_points),
            self.num_group * (num_points),
            device=device,
            dtype=torch.bool)
        group_bbox_kpt = (num_points)
        # group_nobbox_kpt = self.num_body_points
        kpt_index = [
            x for x in range(self.num_group * (num_points)) if x % (num_points) in
            [0,
             1+self.num_body_points,
             2+self.num_body_points+self.num_hand_points,
             3+self.num_body_points+self.num_hand_points*2
             ]
        ]
        for matchj in range(self.num_group * num_points):
            sj = (matchj // group_bbox_kpt) * group_bbox_kpt
            ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt
            if sj > 0:
                attn_mask3[:, :, matchj, :sj] = True
            if ej < self.num_group * num_points:
                attn_mask3[:, :, matchj, ej:] = True
        for match_x in range(self.num_group * num_points):
            if match_x % group_bbox_kpt in [
                    0,
                    1 + self.num_body_points,
                    2 + self.num_body_points + self.num_hand_points,
                    3 + self.num_body_points + self.num_hand_points * 2]:
                attn_mask3[:, :, match_x, kpt_index] = False
        # Merge the batch and head dimensions.
        attn_mask2 = attn_mask2.flatten(0, 1)
        attn_mask3 = attn_mask3.flatten(0, 1)
        return None, None, None, attn_mask2, attn_mask3, None
    # targets, dn_scalar, noise_scale = dn_args
    device = targets[0]['boxes'].device
    bs = len(targets)
    dn_number = self.dn_number  # 100
    dn_box_noise_scale = self.dn_box_noise_scale  # 0.4
    dn_label_noise_ratio = self.dn_label_noise_ratio  # 0.5
    # gather gt boxes and labels
    gt_boxes = [t['boxes'] for t in targets]
    gt_labels = [t['labels'] for t in targets]
    gt_keypoints = [t['keypoints'] for t in targets]
    # repeat them
    def get_indices_for_repeat(now_num, target_num, device='cuda'):
        """Return ``target_num`` indices into a sequence of ``now_num`` items.

        The full index range is tiled ``target_num // now_num`` times and
        the remainder is filled with randomly drawn indices.

        Input:
            - now_num: int
            - target_num: int
        Output:
            - indices: tensor[target_num]
        """
        out_indice = []
        base_indice = torch.arange(now_num).to(device)
        multiplier = target_num // now_num
        out_indice.append(base_indice.repeat(multiplier))
        residue = target_num % now_num
        out_indice.append(base_indice[torch.randint(0,
                                                    now_num, (residue, ),
                                                    device=device)])
        return torch.cat(out_indice)
    if self.dn_batch_gt_fuse:
        # Not supported; the statements after the raise are unreachable.
        raise NotImplementedError
        gt_boxes_bsall = torch.cat(gt_boxes)  # num_boxes, 4
        gt_labels_bsall = torch.cat(gt_labels)
        num_gt_bsall = gt_boxes_bsall.shape[0]
        if num_gt_bsall > 0:
            indices = get_indices_for_repeat(num_gt_bsall, dn_number,
                                             device)
            gt_boxes_expand = gt_boxes_bsall[indices][None].repeat(
                bs, 1, 1)  # bs, num_dn, 4
            gt_labels_expand = gt_labels_bsall[indices][None].repeat(
                bs, 1)  # bs, num_dn
        else:
            # all negative samples when no gt boxes
            gt_boxes_expand = torch.rand(bs, dn_number, 4, device=device)
            gt_labels_expand = torch.ones(
                bs, dn_number, dtype=torch.int64, device=device) * int(
                    self.num_classes)
    else:
        # Expand every image's ground truth to exactly dn_number entries.
        gt_boxes_expand = []
        gt_labels_expand = []
        gt_keypoints_expand = []  # here
        for idx, (gt_boxes_i, gt_labels_i, gt_keypoint_i) in enumerate(
                zip(gt_boxes, gt_labels, gt_keypoints)):  # idx -> batch id
            num_gt_i = gt_boxes_i.shape[0]  # instance num
            if num_gt_i > 0:
                indices = get_indices_for_repeat(num_gt_i, dn_number,
                                                 device)
                gt_boxes_expand_i = gt_boxes_i[indices]  # num_dn, 4
                gt_labels_expand_i = gt_labels_i[indices]  # add smpl
                gt_keypoints_expand_i = gt_keypoint_i[indices]
            else:
                # all negative samples when no gt boxes
                gt_boxes_expand_i = torch.rand(dn_number, 4, device=device)
                gt_labels_expand_i = torch.ones(
                    dn_number, dtype=torch.int64, device=device) * int(
                        self.num_classes)
                gt_keypoints_expand_i = torch.rand(dn_number,
                                                   self.num_body_points *
                                                   3,
                                                   device=device)
            gt_boxes_expand.append(gt_boxes_expand_i)  # add smpl
            gt_labels_expand.append(gt_labels_expand_i)
            gt_keypoints_expand.append(gt_keypoints_expand_i)
        gt_boxes_expand = torch.stack(gt_boxes_expand)
        gt_labels_expand = torch.stack(gt_labels_expand)
        gt_keypoints_expand = torch.stack(gt_keypoints_expand)
    # Noise is applied to clones so the clean targets stay available for
    # mask_dict.
    knwon_boxes_expand = gt_boxes_expand.clone()
    knwon_labels_expand = gt_labels_expand.clone()
    # add noise
    if dn_label_noise_ratio > 0:
        # Entries whose uniform draw falls below the ratio get a random
        # label from [0, dn_labelbook_size).
        prob = torch.rand_like(knwon_labels_expand.float())
        chosen_indice = prob < dn_label_noise_ratio
        new_label = torch.randint_like(
            knwon_labels_expand[chosen_indice], 0,
            self.dn_labelbook_size)  # randomly put a new one here
        knwon_labels_expand[chosen_indice] = new_label
    if dn_box_noise_scale > 0:
        # Per-coordinate noise range: half of the last two box values for
        # the first two coordinates, the last two values themselves for
        # the rest; then uniform noise in [-1, 1) scaled by it.
        diff = torch.zeros_like(knwon_boxes_expand)
        diff[..., :2] = knwon_boxes_expand[..., 2:] / 2
        diff[..., 2:] = knwon_boxes_expand[..., 2:]
        knwon_boxes_expand += torch.mul(
            (torch.rand_like(knwon_boxes_expand) * 2 - 1.0),
            diff) * dn_box_noise_scale
        knwon_boxes_expand = knwon_boxes_expand.clamp(min=0.0, max=1.0)
    input_query_label = self.label_enc(knwon_labels_expand)
    input_query_bbox = inverse_sigmoid(knwon_boxes_expand)
    # prepare mask
    # Eight independent random flags decide which extra in-group columns
    # are blocked below. The draw order matters for reproducibility under
    # a fixed `random` seed.
    body_mask, body_kps_mask, lhand_mask, lhand_kps_mask, rhand_mask, \
        rhand_kps_mask, face_mask, face_kps_mask = \
            False, False, False, False, False, False, False, False
    if random.random() < 0.2:
        body_mask = True
    if random.random() < 0.5:
        body_kps_mask = True
    if random.random() < 0.2:
        lhand_mask = True
    if random.random() < 0.5:
        lhand_kps_mask = True
    if random.random() < 0.2:
        rhand_mask = True
    if random.random() < 0.5:
        rhand_kps_mask = True
    if random.random() < 0.2:
        face_mask = True
    if random.random() < 0.5:
        face_kps_mask = True
    if 'group2group' in self.dn_attn_mask_type_list:
        # attn_mask covers dn_number denoising queries followed by
        # num_queries regular queries.
        attn_mask = torch.zeros(bs,
                                self.nheads,
                                dn_number + self.num_queries,
                                dn_number + self.num_queries,
                                device=device,
                                dtype=torch.bool)
        # Regular queries get the denoising columns set.
        attn_mask[:, :, dn_number:, :dn_number] = True
        for idx, (gt_boxes_i,
                  gt_labels_i) in enumerate(zip(gt_boxes,
                                                gt_labels)):  # for batch
            num_gt_i = gt_boxes_i.shape[0]
            if num_gt_i == 0:
                continue
            # Denoising queries form consecutive blocks of num_gt_i; each
            # row keeps only the columns of its own block clear.
            for matchi in range(dn_number):
                si = (matchi // num_gt_i) * num_gt_i
                ei = (matchi // num_gt_i + 1) * num_gt_i
                if si > 0:
                    attn_mask[idx, :, matchi, :si] = True
                if ei < dn_number:
                    attn_mask[idx, :, matchi, ei:dn_number] = True
        attn_mask = attn_mask.flatten(0, 1)
    if 'group2group' in self.dn_attn_mask_type_list:
        # self.num_body_points = self.num_body_points +3
        # In-group offsets that get blocked for the first group's rows.
        inter_body_mask = []
        if body_mask:
            inter_body_mask.append(0)
        if body_kps_mask:
            # random.sample raises ValueError if num_body_points < 6.
            indices = sorted(random.sample(range(1, self.num_body_points+1), k=6))
            inter_body_mask.extend(indices)
        if lhand_mask:
            inter_body_mask.append(self.num_body_points+1)
        if rhand_mask:
            inter_body_mask.append(self.num_body_points+2)
        if face_mask:
            inter_body_mask.append(self.num_body_points+3)
        num_points = self.num_body_points + 4
        attn_mask2 = torch.zeros(
            bs,
            self.nheads,
            dn_number + self.num_group * num_points,
            dn_number + self.num_group * num_points,
            device=device,
            dtype=torch.bool)
        attn_mask2[:, :, dn_number:, :dn_number] = True
        group_bbox_kpt = num_points
        # group_nobbox_kpt = self.num_body_points
        kpt_index = [x for x in range(self.num_group * num_points)
                     if x % num_points in [
                         0, self.num_body_points+1, self.num_body_points+2, self.num_body_points+3]]
        # The chained indexing below writes through views of attn_mask2
        # restricted to the non-denoising rows and columns.
        for matchj in range(self.num_group * num_points):
            sj = (matchj // group_bbox_kpt) * group_bbox_kpt
            ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt
            if sj > 0:
                attn_mask2[:, :, dn_number:, dn_number:][:, :, matchj, :sj] = True
            if ej < self.num_group * num_points:
                attn_mask2[:, :, dn_number:, dn_number:][:, :, matchj, ej:] = True
            # Only rows of the first group receive the random blocking.
            if (matchj // group_bbox_kpt) == 0:
                attn_mask2[:, :, dn_number:, dn_number:][:, :, matchj, sj:ej][..., inter_body_mask] = True
        # NOTE(review): the rows use in-group offsets num_body_points+1..+3
        # but the cleared columns start at offsets 1, 2 and 3 - confirm the
        # column strides are intended.
        for match_x in range(self.num_group * num_points):
            if match_x % group_bbox_kpt == 0 and body_mask != False:
                attn_mask2[:, :, dn_number:, dn_number:][:, :, match_x, ::num_points]=False
            if match_x % group_bbox_kpt == self.num_body_points + 1 and lhand_mask != False:
                attn_mask2[:, :, dn_number:, dn_number:][:, :, match_x, 1::num_points]=False
            if match_x % group_bbox_kpt == self.num_body_points + 2 and rhand_mask != False:
                attn_mask2[:, :, dn_number:, dn_number:][:, :, match_x, 2::num_points]=False
            if match_x % group_bbox_kpt == self.num_body_points + 3 and face_mask != False:
                attn_mask2[:, :, dn_number:, dn_number:][:, :, match_x, 3::num_points]=False
            # if match_x % group_bbox_kpt in [0,
            #                                 self.num_body_points+1,
            #                                 self.num_body_points+2,
            #                                 self.num_body_points+3]:
            #     attn_mask2[:, :, dn_number:, dn_number:][:, :, match_x, kpt_index]=False
        # Same per-image block structure for the denoising rows as in
        # attn_mask.
        for idx, (gt_boxes_i, gt_labels_i) in enumerate(zip(gt_boxes, gt_labels)):
            num_gt_i = gt_boxes_i.shape[0]
            if num_gt_i == 0:
                continue
            for matchi in range(dn_number):
                si = (matchi // num_gt_i) * num_gt_i
                ei = (matchi // num_gt_i + 1) * num_gt_i
                if si > 0:
                    attn_mask2[idx, :, matchi, :si] = True
                if ei < dn_number:
                    attn_mask2[idx, :, matchi, ei:dn_number] = True
        attn_mask2 = attn_mask2.flatten(0, 1)
    if 'group2group' in self.dn_attn_mask_type_list:
        # NOTE(review): the literal offsets +8, +15 and the 6-wide sample
        # ranges below coincide with the num_hand_points-based offsets
        # used further down only when num_hand_points == 6 - confirm.
        inter_body_mask = []
        if body_mask:
            inter_body_mask.append(0)
        if body_kps_mask:
            indices = sorted(random.sample(range(1, self.num_body_points+1), k=6))
            inter_body_mask.extend(indices)
        if lhand_mask:
            inter_body_mask.append(self.num_body_points+1)
        if lhand_kps_mask:
            indices = sorted(random.sample(range(self.num_body_points+2, self.num_body_points+8), k=3))
            inter_body_mask.extend(indices)
        if rhand_mask:
            inter_body_mask.append(self.num_body_points+8)
        if rhand_kps_mask:
            indices = sorted(random.sample(range(self.num_body_points+9, self.num_body_points+15), k=3))
            inter_body_mask.extend(indices)
        if face_mask:
            inter_body_mask.append(self.num_body_points+15)
        if face_kps_mask:
            indices = sorted(random.sample(range(self.num_body_points+16, self.num_body_points+22), k=3) )
            inter_body_mask.extend(indices)
        # self.num_body_points = self.num_body_points +3
        num_points = self.num_whole_body_points + 4
        attn_mask3 = torch.zeros(
            bs,
            self.nheads,
            dn_number + self.num_group * (num_points), dn_number + self.num_group * (num_points),
            device=device, dtype=torch.bool)
        attn_mask3[:, :, dn_number:, :dn_number] = True
        group_bbox_kpt = (num_points)
        # group_nobbox_kpt = self.num_body_points
        kpt_index = [
            x for x in range(self.num_group * (num_points)) if x % (num_points) in
            [0,
             1+self.num_body_points,
             2+self.num_body_points+self.num_hand_points,
             3+self.num_body_points+self.num_hand_points*2
             ]
        ]
        for matchj in range(self.num_group * num_points):
            sj = (matchj // group_bbox_kpt) * group_bbox_kpt
            ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt
            if sj > 0:
                attn_mask3[:, :, dn_number:, dn_number:][:, :, matchj, :sj] = True
            if ej < self.num_group * num_points:
                attn_mask3[:, :, dn_number:, dn_number:][:, :, matchj, ej:] = True
            if (matchj // group_bbox_kpt) == 0:
                attn_mask3[:, :, dn_number:, dn_number:][:, :, matchj, sj:ej][..., inter_body_mask] = True
        # NOTE(review): as for attn_mask2, the cleared columns start at
        # offsets 1, 2 and 3 rather than at the row offsets - confirm.
        for match_x in range(self.num_group * num_points):
            if match_x % group_bbox_kpt == 0 and body_mask != False:
                attn_mask3[:, :, dn_number:, dn_number:][:, :, match_x, ::num_points]=False
            if match_x % group_bbox_kpt == 1 + self.num_body_points and lhand_mask != False:
                attn_mask3[:, :, dn_number:, dn_number:][:, :, match_x, 1::num_points]=False
            if match_x % group_bbox_kpt == 2 + self.num_body_points + self.num_hand_points and rhand_mask != False:
                attn_mask3[:, :, dn_number:, dn_number:][:, :, match_x, 2::num_points]=False
            if match_x % group_bbox_kpt == 3 + self.num_body_points + self.num_hand_points * 2 and face_mask != False:
                attn_mask3[:, :, dn_number:, dn_number:][:, :, match_x, 3::num_points]=False
            # if match_x % group_bbox_kpt in [0,
            #                                 1 + self.num_body_points,
            #                                 2 + self.num_body_points + self.num_hand_points,
            #                                 3 + self.num_body_points + self.num_hand_points * 2]:
            #     attn_mask3[:, :, dn_number:, dn_number:][:,:,match_x,kpt_index]=False
        for idx, (gt_boxes_i, gt_labels_i) in enumerate(zip(gt_boxes, gt_labels)):
            num_gt_i = gt_boxes_i.shape[0]
            if num_gt_i == 0:
                continue
            for matchi in range(dn_number):
                si = (matchi // num_gt_i) * num_gt_i
                ei = (matchi // num_gt_i + 1) * num_gt_i
                if si > 0:
                    attn_mask3[idx, :, matchi, :si] = True
                if ei < dn_number:
                    attn_mask3[idx, :, matchi, ei:dn_number] = True
        attn_mask3 = attn_mask3.flatten(0, 1)
    # Clean (un-noised) targets and the number of denoising queries.
    mask_dict = {
        'pad_size': dn_number,
        'known_bboxs': gt_boxes_expand,
        'known_labels': gt_labels_expand,
        'known_keypoints': gt_keypoints_expand
    }
    return input_query_label, input_query_bbox, attn_mask, attn_mask2, attn_mask3, mask_dict
def dn_post_process2(self, outputs_class, outputs_coord,
                     outputs_body_keypoints_list, mask_dict):
    """Separate denoising (DN) predictions from the regular query predictions.

    Every element of ``outputs_class`` and ``outputs_coord`` is indexed as
    ``x[:, :pad_size, :]`` / ``x[:, pad_size:, :]``, i.e. the first
    ``pad_size`` entries along dim 1 are treated as the DN part and the rest
    as the regular part.  The DN part is stored back into ``mask_dict``; the
    regular part is returned.

    Args:
        outputs_class: iterable of per-layer class outputs, each sliceable
            with a 3-axis index.
        outputs_coord: iterable of per-layer box outputs, sliced the same way.
        outputs_body_keypoints_list: keypoint outputs; returned unchanged
            (no slicing is applied to them here).
        mask_dict: dict with a ``'pad_size'`` entry, or a falsy value
            (``None`` / empty dict) when no DN queries are in use.  When DN
            is active it is updated in place with ``'output_known_class'``
            and ``'output_known_coord'``.

    Returns:
        Tuple ``(outputs_class, outputs_coord, outputs_keypoint)``.  When
        ``mask_dict`` is falsy or ``pad_size`` is not positive, the inputs
        are returned as-is.
    """
    # Bind the keypoint result up front so the return statement is valid on
    # every path, including the "no DN" path where the branch below is skipped.
    outputs_keypoint = outputs_body_keypoints_list

    if mask_dict and mask_dict['pad_size'] > 0:
        pad_size = mask_dict['pad_size']

        # Leading `pad_size` queries of every layer: the DN predictions.
        output_known_class = [
            layer_cls[:, :pad_size, :] for layer_cls in outputs_class
        ]
        output_known_coord = [
            layer_box[:, :pad_size, :] for layer_box in outputs_coord
        ]

        # Remaining queries of every layer: the regular predictions.
        outputs_class = [
            layer_cls[:, pad_size:, :] for layer_cls in outputs_class
        ]
        outputs_coord = [
            layer_box[:, pad_size:, :] for layer_box in outputs_coord
        ]

        # Expose the DN predictions to the caller through the shared dict.
        mask_dict.update({
            'output_known_coord': output_known_coord,
            'output_known_class': output_known_class
        })
    return outputs_class, outputs_coord, outputs_keypoint
def forward(self, data_batch: NestedTensor, targets: List = None):
"""The forward expects a NestedTensor, which consists of:
- samples.tensor: batched images, of shape [batch_size x 3 x H x W]
- samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
It returns a dict with the following elements:
- "pred_logits": the classification logits (including no-object) for all queries.
Shape= [batch_size x num_queries x num_classes]
- "pred_boxes": The normalized boxes coordinates for all queries, represented as
(center_x, center_y, width, height). These values are normalized in [0, 1],
relative to the size of each individual image (disregarding possible padding).
See PostProcess for information on how to retrieve the unnormalized bounding box.
- "aux_outputs": Optional, only returned when auxiliary losses are activated. It is a list of
dictionaries containing the two above keys for each decoder layer.
"""
if isinstance(data_batch, dict):
samples, targets = self.prepare_targets(data_batch)
# import pdb; pdb.set_trace()
elif isinstance(data_batch, (list, torch.Tensor)):
samples = nested_tensor_from_tensor_list(data_batch)
else:
samples = data_batch
# print(samples.data['img'].shape)
# exit()
features, poss = self.backbone(samples)
srcs = []
masks = []
for l, feat in enumerate(features): # len(features=3)
src, mask = feat.decompose()
srcs.append(self.input_proj[l](src))
masks.append(mask)
assert mask is not None
if self.num_feature_levels > len(srcs):
_len_srcs = len(srcs)
for l in range(_len_srcs, self.num_feature_levels):
if l == _len_srcs:
src = self.input_proj[l](features[-1].tensors)
else:
src = self.input_proj[l](srcs[-1])
m = samples.mask
mask = F.interpolate(m[None].float(),
size=src.shape[-2:]).to(torch.bool)[0]
pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
srcs.append(src)
masks.append(mask)
poss.append(pos_l)
if self.dn_number > 0 or targets is not None:
input_query_label, input_query_bbox, attn_mask,attn_mask2, attn_mask3, mask_dict =\
self.prepare_for_dn2(targets)
else:
assert targets is None
input_query_bbox = input_query_label = attn_mask = attn_mask2 = attn_mask3 = mask_dict = None
hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(
srcs, masks, input_query_bbox, poss, input_query_label, attn_mask,
attn_mask2, attn_mask3)
# update human boxes
effective_dn_number = self.dn_number if self.training else 0
outputs_body_bbox_list = []
outputs_class = []
for dec_lid, (layer_ref_sig, layer_body_bbox_embed, layer_cls_embed,
layer_hs) in enumerate(
zip(reference[:-1], self.bbox_embed,
self.class_embed, hs)):
if dec_lid < self.num_box_decoder_layers:
# human det
layer_delta_unsig = layer_body_bbox_embed(layer_hs)
layer_body_box_outputs_unsig = \
layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
layer_body_box_outputs_unsig = layer_body_box_outputs_unsig.sigmoid()
layer_cls = layer_cls_embed(layer_hs)
# import mmcv
# import cv2
# img = (data_batch['img'][0]*255).permute(1,2,0).int().detach().cpu().numpy()
# bbox = (box_ops.box_cxcywh_to_xyxy(layer_body_box_outputs_unsig[0][0]).reshape(2,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[0, ::-1]).reshape(1,4)
# img = mmcv.imshow_bboxes(img.copy(), bbox, show=False)
# cv2.imwrite('test.png',img)
outputs_body_bbox_list.append(layer_body_box_outputs_unsig)
outputs_class.append(layer_cls)
elif dec_lid < self.num_box_decoder_layers + 2:
bs = layer_ref_sig.shape[0]
# dn body bbox
layer_hs_body_bbox_dn = layer_hs[:, :effective_dn_number, :] # dn content query
reference_before_sigmoid_body_bbox_dn = layer_ref_sig[:, :effective_dn_number, :] # dn position query
layer_body_box_delta_unsig_dn = layer_body_bbox_embed(layer_hs_body_bbox_dn)
layer_body_box_outputs_unsig_dn = layer_body_box_delta_unsig_dn + inverse_sigmoid(
reference_before_sigmoid_body_bbox_dn)
layer_body_box_outputs_unsig_dn = layer_body_box_outputs_unsig_dn.sigmoid()
# norm body bbox
layer_hs_body_bbox_norm = layer_hs[:, effective_dn_number:, :][
:, 0::(self.num_body_points + 4), :] # norm content query
reference_before_sigmoid_body_bbox_norm = layer_ref_sig[:, effective_dn_number:, :][
:, 0::(self.num_body_points+ 4), :] # norm position query
layer_body_box_delta_unsig_norm = layer_body_bbox_embed(layer_hs_body_bbox_norm)
layer_body_box_outputs_unsig_norm = layer_body_box_delta_unsig_norm + inverse_sigmoid(
reference_before_sigmoid_body_bbox_norm)
layer_body_box_outputs_unsig_norm = layer_body_box_outputs_unsig_norm.sigmoid()
layer_body_box_outputs_unsig = torch.cat(
(layer_body_box_outputs_unsig_dn, layer_body_box_outputs_unsig_norm), dim=1)
# classification
layer_cls_dn = layer_cls_embed(layer_hs_body_bbox_dn)
layer_cls_norm = layer_cls_embed(layer_hs_body_bbox_norm)
layer_cls = torch.cat((layer_cls_dn, layer_cls_norm), dim=1)
outputs_class.append(layer_cls)
outputs_body_bbox_list.append(layer_body_box_outputs_unsig)
else:
bs = layer_ref_sig.shape[0]
# dn body bbox
layer_hs_body_bbox_dn = layer_hs[:, :effective_dn_number, :] # dn content query
reference_before_sigmoid_body_bbox_dn = layer_ref_sig[:, :effective_dn_number, :] # dn position query
layer_body_box_delta_unsig_dn = layer_body_bbox_embed(layer_hs_body_bbox_dn)
layer_body_box_outputs_unsig_dn = layer_body_box_delta_unsig_dn + inverse_sigmoid(
reference_before_sigmoid_body_bbox_dn)
layer_body_box_outputs_unsig_dn = layer_body_box_outputs_unsig_dn.sigmoid()
# norm body bbox
layer_hs_body_bbox_norm = layer_hs[:, effective_dn_number:, :][
:, 0::(self.num_whole_body_points + 4), :] # norm content query
reference_before_sigmoid_body_bbox_norm = layer_ref_sig[:,effective_dn_number:, :][
:, 0::(self.num_whole_body_points + 4), :] # norm position query
layer_body_box_delta_unsig_norm = layer_body_bbox_embed(layer_hs_body_bbox_norm)
layer_body_box_outputs_unsig_norm = layer_body_box_delta_unsig_norm + inverse_sigmoid(
reference_before_sigmoid_body_bbox_norm)
layer_body_box_outputs_unsig_norm = layer_body_box_outputs_unsig_norm.sigmoid()
layer_body_box_outputs_unsig = torch.cat(
(layer_body_box_outputs_unsig_dn, layer_body_box_outputs_unsig_norm), dim=1)
# classification
layer_cls_dn = layer_cls_embed(layer_hs_body_bbox_dn)
layer_cls_norm = layer_cls_embed(layer_hs_body_bbox_norm)
layer_cls = torch.cat((layer_cls_dn, layer_cls_norm), dim=1)
outputs_class.append(layer_cls)
outputs_body_bbox_list.append(layer_body_box_outputs_unsig)
# find queries: indices and scores of the 100 highest-scoring regular queries
q_index = torch.topk(layer_cls_norm.max(-1)[0], 100, dim=1)[1]
q_value = torch.topk(layer_cls_norm.max(-1)[0], 100, dim=1)[0]
# update hand and face boxes
outputs_lhand_bbox_list = []
outputs_rhand_bbox_list = []
outputs_face_bbox_list = []
# update keypoints boxes
outputs_body_keypoints_list = []
outputs_body_keypoints_hw = []
outputs_lhand_keypoints_list = []
outputs_lhand_keypoints_hw = []
outputs_rhand_keypoints_list = []
outputs_rhand_keypoints_hw = []
outputs_face_keypoints_list = []
outputs_face_keypoints_hw = []
outputs_smpl_pose_list = []
outputs_smpl_lhand_pose_list = []
outputs_smpl_rhand_pose_list = []
outputs_smpl_expr_list = []
outputs_smpl_jaw_pose_list = []
outputs_smpl_beta_list = []
outputs_smpl_cam_list = []
# outputs_smpl_cam_f_list = []
outputs_smpl_kp2d_list = []
outputs_smpl_kp3d_list = []
outputs_smpl_verts_list = []
body_kpt_index = [
x for x in range(self.num_group * (self.num_body_points + 4))
if x % (self.num_body_points + 4) in range(1,self.num_body_points+1)
]
body_kpt_index_2 = [
x for x in range(self.num_group * (self.num_whole_body_points + 4))
if (x % (self.num_whole_body_points + 4) in range(1,self.num_body_points+1))
]
lhand_bbox_index = [
x for x in range(self.num_group * (self.num_body_points + 4))
if x % (self.num_body_points + 4) != 1
]
lhand_kpt_index = [
x for x in range(self.num_group * (self.num_whole_body_points + 4))
if (x % (self.num_whole_body_points + 4) in range(
self.num_body_points+2,
self.num_body_points+self.num_hand_points+2))]
rhand_bbox_index = [
x for x in range(self.num_group * (self.num_body_points + 4))
if x % (self.num_body_points + 4) != 2
]
rhand_kpt_index = [
x for x in range(self.num_group * (self.num_whole_body_points + 4))
if (x % (self.num_whole_body_points + 4) in range(
self.num_body_points+self.num_hand_points+3,
self.num_body_points+self.num_hand_points*2+3))
]
face_bbox_index = [
x for x in range(self.num_group * (self.num_body_points + 4))
if x % (self.num_body_points + 4) != 3
]
face_kpt_index = [
x for x in range(self.num_group * (self.num_whole_body_points + 4))
if (x % (self.num_whole_body_points + 4) in range(
self.num_body_points+self.num_hand_points*2+4,
self.num_body_points+self.num_hand_points*2+self.num_face_points+4))
]
# smpl pose
# body box, kps, lhand box
body_index = list(range(0,self.num_body_points+2))
# rhand box and face box
body_index.extend(
[self.num_body_points + self.num_hand_points + 2, self.num_body_points + 2 * self.num_hand_points + 3]
)
smpl_pose_index = [
x for x in range(self.num_group * (self.num_whole_body_points + 4))
if (x % (self.num_whole_body_points + 4) in body_index)
]
# smpl lhand
lhand_index = list(range(self.num_body_points+1, self.num_body_points+self.num_hand_points+3))
# body box
lhand_index.insert(0, 0)
smpl_lhand_pose_index = [
x for x in range(self.num_group * (self.num_whole_body_points + 4))
if (x % (self.num_whole_body_points + 4) in lhand_index)]
# smpl rhand
rhand_index = list(range(self.num_body_points + self.num_hand_points + 2, self.num_body_points + self.num_hand_points * 2 +3))
rhand_index.insert(0,self.num_body_points+1)
rhand_index.insert(0,0)
smpl_rhand_pose_index = [
x for x in range(self.num_group * (self.num_whole_body_points + 4))
if (x % (self.num_whole_body_points + 4) in rhand_index)]
# smpl face
face_index = list(range(self.num_body_points + self.num_hand_points * 2 + 3, self.num_body_points + self.num_hand_points * 2 + self.num_face_points + 4))
face_index.insert(0,0)
smpl_face_pose_index = [
x for x in range(self.num_group * (self.num_whole_body_points + 4))
if (x % (self.num_whole_body_points + 4) in face_index)]
for dec_lid, (layer_ref_sig, layer_hs) in enumerate(zip(reference[:-1], hs)):
if dec_lid < self.num_box_decoder_layers:
assert isinstance(layer_hs, torch.Tensor)
bs = layer_hs.shape[0]
layer_body_kps_res = layer_hs.new_zeros(
(bs, self.num_queries,
self.num_body_points * 3)) # [-, 900, 42]
outputs_body_keypoints_list.append(layer_body_kps_res)
# lhand
layer_lhand_bbox_res = layer_hs.new_zeros(
(bs, self.num_queries, 4)) # [-, 900, 42]
outputs_lhand_bbox_list.append(layer_lhand_bbox_res)
layer_lhand_kps_res = layer_hs.new_zeros(
(bs, self.num_queries,
self.num_hand_points * 3)) # [-, 900, 42]
outputs_lhand_keypoints_list.append(layer_lhand_kps_res)
# rhand
layer_rhand_bbox_res = layer_hs.new_zeros(
(bs, self.num_queries, 4)) # [-, 900, 42]
outputs_rhand_bbox_list.append(layer_rhand_bbox_res)
layer_rhand_kps_res = layer_hs.new_zeros(
(bs, self.num_queries,
self.num_hand_points * 3)) # [-, 900, 42]
outputs_rhand_keypoints_list.append(layer_rhand_kps_res)
# face
layer_face_bbox_res = layer_hs.new_zeros(
(bs, self.num_queries, 4)) # [-, 900, 42]
outputs_face_bbox_list.append(layer_face_bbox_res)
layer_face_kps_res = layer_hs.new_zeros(
(bs, self.num_queries,
self.num_face_points * 3)) # [-, 900, 42]
outputs_face_keypoints_list.append(layer_face_kps_res)
# smpl or smplx
smpl_pose = layer_hs.new_zeros((bs, self.num_queries, self.body_model_joint_num * 3))
smpl_rhand_pose = layer_hs.new_zeros(
(bs, self.num_queries, 15 * 3))
smpl_lhand_pose = layer_hs.new_zeros(
(bs, self.num_queries, 15 * 3))
smpl_expr = layer_hs.new_zeros((bs, self.num_queries, 10))
smpl_jaw_pose = layer_hs.new_zeros((bs, self.num_queries, 6))
smpl_beta = layer_hs.new_zeros((bs, self.num_queries, 10))
smpl_cam = layer_hs.new_zeros((bs, self.num_queries, 3))
# smpl_cam_f = layer_hs.new_zeros((bs, self.num_queries, 1))
# smpl_kp2d = layer_hs.new_zeros((bs, self.num_queries, self.num_body_points,3))
smpl_kp3d = layer_hs.new_zeros(
(bs, self.num_queries, self.num_body_points, 4))
outputs_smpl_pose_list.append(smpl_pose)
outputs_smpl_rhand_pose_list.append(smpl_rhand_pose)
outputs_smpl_lhand_pose_list.append(smpl_lhand_pose)
outputs_smpl_expr_list.append(smpl_expr)
outputs_smpl_jaw_pose_list.append(smpl_jaw_pose)
outputs_smpl_beta_list.append(smpl_beta)
outputs_smpl_cam_list.append(smpl_cam)
# outputs_smpl_cam_f_list.append(smpl_cam_f)
# outputs_smpl_kp2d_list.append(smpl_kp2d)
outputs_smpl_kp3d_list.append(smpl_kp3d)
elif dec_lid < self.num_box_decoder_layers +2:
bs = layer_ref_sig.shape[0]
layer_hs_body_kpt = \
layer_hs[:, effective_dn_number:, :].index_select(
1, torch.tensor(body_kpt_index, device=layer_hs.device))
# body kp2d
delta_body_kp_xy_unsig = \
self.pose_embed[dec_lid - self.num_box_decoder_layers](layer_hs_body_kpt)
layer_ref_sig_body_kpt = \
layer_ref_sig[:,effective_dn_number:, :].index_select(1,torch.tensor(body_kpt_index,device=layer_hs.device))
layer_outputs_unsig_body_keypoints = delta_body_kp_xy_unsig + inverse_sigmoid(
layer_ref_sig_body_kpt[..., :2])
vis_xy_unsig = torch.ones_like(
layer_outputs_unsig_body_keypoints,
device=layer_outputs_unsig_body_keypoints.device)
xyv = torch.cat((layer_outputs_unsig_body_keypoints,
vis_xy_unsig[:, :, 0].unsqueeze(-1)),
dim=-1)
xyv = xyv.sigmoid()
# from detrsmpl.core.visualization.visualize_keypoints2d import visualize_kp2d
# img =(data_batch['img'][0].permute(1,2,0)*255).int().cpu().numpy()
# gt_kp2d = xyv[0,:17]
# coco_kps = gt_kp2d[:,:2].reshape(17,2).detach().cpu().numpy() * data_batch['img_shape'].cpu().numpy()[0,None,None,::-1]
# visualize_kp2d(
# coco_kps,
# output_path='.',
# image_array=img.copy()[None],
# data_source='coco',
# overwrite=True)
layer_res = xyv.reshape(
(bs, self.num_group, self.num_body_points,
3)).flatten(2, 3)
layer_hw = layer_ref_sig_body_kpt[..., 2:].reshape(
bs, self.num_group, self.num_body_points, 2).flatten(2, 3)
layer_res = keypoint_xyzxyz_to_xyxyzz(layer_res)
outputs_body_keypoints_list.append(layer_res)
outputs_body_keypoints_hw.append(layer_hw)
# lhand bbox
layer_hs_lhand_bbox = \
layer_hs[:, effective_dn_number:, :][:, (self.num_body_points + 1)::(self.num_body_points + 4), :]
delta_lhand_bbox_xy_unsig = self.bbox_hand_embed[dec_lid - self.num_box_decoder_layers](layer_hs_lhand_bbox)
layer_ref_sig_lhand_bbox = \
layer_ref_sig[:,effective_dn_number:, :][
:, (self.num_body_points + 1)::(self.num_body_points + 4), :].clone()
layer_ref_unsig_lhand_bbox = inverse_sigmoid(layer_ref_sig_lhand_bbox)
delta_lhand_bbox_hw_unsig = self.bbox_hand_hw_embed[
dec_lid-self.num_box_decoder_layers](layer_hs_lhand_bbox)
layer_ref_unsig_lhand_bbox[..., :2] +=delta_lhand_bbox_xy_unsig[..., :2]
layer_ref_unsig_lhand_bbox[..., 2:] +=delta_lhand_bbox_hw_unsig
layer_ref_sig_lhand_bbox = layer_ref_unsig_lhand_bbox.sigmoid()
outputs_lhand_bbox_list.append(layer_ref_sig_lhand_bbox)
layer_lhand_kps_res = layer_hs.new_zeros(
(bs, self.num_queries,
self.num_hand_points * 3)) # [-, 900, 42]
outputs_lhand_keypoints_list.append(layer_lhand_kps_res)
# rhand bbox
layer_hs_rhand_bbox = \
layer_hs[:, effective_dn_number:, :][
:, (self.num_body_points + 2)::(self.num_body_points + 4), :]
delta_rhand_bbox_xy_unsig = self.bbox_hand_embed[
dec_lid - self.num_box_decoder_layers](layer_hs_rhand_bbox)
layer_ref_sig_rhand_bbox = \
layer_ref_sig[:,effective_dn_number:, :][
:, (self.num_body_points + 2)::(self.num_body_points + 4), :].clone()
layer_ref_unsig_rhand_bbox = inverse_sigmoid(layer_ref_sig_rhand_bbox)
delta_rhand_bbox_hw_unsig = self.bbox_hand_hw_embed[
dec_lid-self.num_box_decoder_layers](layer_hs_rhand_bbox)
layer_ref_unsig_rhand_bbox[..., :2] +=delta_rhand_bbox_xy_unsig[..., :2]
layer_ref_unsig_rhand_bbox[..., 2:] +=delta_rhand_bbox_hw_unsig
layer_ref_sig_rhand_bbox = layer_ref_unsig_rhand_bbox.sigmoid()
outputs_rhand_bbox_list.append(layer_ref_sig_rhand_bbox)
# rhand kps
layer_rhand_kps_res = layer_hs.new_zeros(
(bs, self.num_queries,
self.num_hand_points * 3)) # [-, 900, 42]
outputs_rhand_keypoints_list.append(layer_rhand_kps_res)
# face bbox
layer_hs_face_bbox = \
layer_hs[:, effective_dn_number:, :][
:, (self.num_body_points + 3)::(self.num_body_points + 4), :]
delta_face_bbox_xy_unsig = self.bbox_face_embed[
dec_lid - self.num_box_decoder_layers](layer_hs_face_bbox)
layer_ref_sig_face_bbox = \
layer_ref_sig[:,effective_dn_number:, :][
:, (self.num_body_points + 3)::(self.num_body_points + 4), :].clone()
layer_ref_unsig_face_bbox = inverse_sigmoid(layer_ref_sig_face_bbox)
delta_face_bbox_hw_unsig = self.bbox_face_hw_embed[
dec_lid-self.num_box_decoder_layers](layer_hs_face_bbox)
layer_ref_unsig_face_bbox[..., :2] +=delta_face_bbox_xy_unsig[..., :2]
layer_ref_unsig_face_bbox[..., 2:] +=delta_face_bbox_hw_unsig
layer_ref_sig_face_bbox = layer_ref_unsig_face_bbox.sigmoid()
outputs_face_bbox_list.append(layer_ref_sig_face_bbox)
# face kps
layer_face_kps_res = layer_hs.new_zeros(
(bs, self.num_queries,
self.num_face_points * 3)) # [-, 900, 42]
outputs_face_keypoints_list.append(layer_face_kps_res)
# smpl or smplx
bs, _, feat_dim = layer_hs.shape
smpl_feats = layer_hs[:, effective_dn_number:, :].reshape(
bs, -1, feat_dim * (self.num_body_points + 4))
smpl_lhand_pose_feats = layer_hs[:, effective_dn_number:, :][
:, (self.num_body_points + 1):: (self.num_body_points + 4), :].reshape(
bs, -1, feat_dim)
smpl_rhand_pose_feats = layer_hs[:, effective_dn_number:, :][
:, (self.num_body_points + 2):: (self.num_body_points + 4), :].reshape(
bs, -1, feat_dim)
smpl_face_pose_feats = layer_hs[:, effective_dn_number:, :][
:, (self.num_body_points + 3):: (self.num_body_points + 4), :].reshape(
bs, -1, feat_dim)
smpl_pose = self.smpl_pose_embed[
dec_lid - self.num_box_decoder_layers](smpl_feats)
smpl_pose = rot6d_to_rotmat(smpl_pose.reshape(-1, 6)).reshape(
bs, self.num_group, self.body_model_joint_num, 3, 3)
smpl_lhand_pose = self.smpl_hand_pose_embed[
dec_lid - self.num_box_decoder_layers](smpl_lhand_pose_feats)
smpl_lhand_pose = rot6d_to_rotmat(smpl_lhand_pose.reshape(
-1, 6)).reshape(bs, self.num_group, 15, 3, 3)
smpl_rhand_pose = self.smpl_hand_pose_embed[
dec_lid - self.num_box_decoder_layers](smpl_rhand_pose_feats)
smpl_rhand_pose = rot6d_to_rotmat(smpl_rhand_pose.reshape(
-1, 6)).reshape(bs, self.num_group, 15, 3, 3)
smpl_jaw_pose = self.smpl_jaw_embed[
dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats)
smpl_jaw_pose = rot6d_to_rotmat(smpl_jaw_pose.reshape(-1, 6)).reshape(
bs, self.num_group, 1, 3, 3)
smpl_beta = self.smpl_beta_embed[
dec_lid - self.num_box_decoder_layers](smpl_feats)
smpl_cam = self.smpl_cam_embed[
dec_lid - self.num_box_decoder_layers](smpl_feats)
# smpl_cam_f = self.smpl_cam_f_embed[
# dec_lid - self.num_box_decoder_layers](smpl_feats)
# zero
# smpl_lhand_pose = layer_hs.new_zeros(bs, self.num_group, 15, 3, 3)
# smpl_rhand_pose = layer_hs.new_zeros(bs, self.num_group, 15, 3, 3)
# smpl_expr = layer_hs.new_zeros(bs, self.num_group, 10)
smpl_expr = self.smpl_expr_embed[
dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats)
# smpl_jaw_pose = layer_hs.new_zeros(bs, self.num_group, 3)
leye_pose = torch.zeros_like(smpl_jaw_pose)
reye_pose = torch.zeros_like(smpl_jaw_pose)
if self.body_model is not None:
smpl_pose_ = rotmat_to_aa(smpl_pose)
# smpl_lhand_pose_ = rotmat_to_aa(smpl_lhand_pose)
# smpl_rhand_pose_ = rotmat_to_aa(smpl_rhand_pose)
smpl_lhand_pose_ = layer_hs.new_zeros(bs, self.num_group, 15, 3)
smpl_rhand_pose_ = layer_hs.new_zeros(bs, self.num_group, 15, 3)
smpl_jaw_pose_ = rotmat_to_aa(smpl_jaw_pose)
leye_pose_ = rotmat_to_aa(leye_pose)
reye_pose_ = rotmat_to_aa(reye_pose)
pred_output = self.body_model(
betas=smpl_beta.reshape(-1, 10),
body_pose=smpl_pose_[:, :, 1:].reshape(-1, 21 * 3),
global_orient=smpl_pose_[:, :, 0].reshape(
-1, 3).unsqueeze(1),
left_hand_pose=smpl_lhand_pose_.reshape(-1, 15 * 3),
right_hand_pose=smpl_rhand_pose_.reshape(-1, 15 * 3),
leye_pose=leye_pose_,
reye_pose=reye_pose_,
jaw_pose=smpl_jaw_pose_.reshape(-1, 3),
# expression=smpl_expr.reshape(-1, 10),
expression=layer_hs.new_zeros(bs, self.num_group, 10).reshape(-1, 10)
)
smpl_kp3d = pred_output['joints'].reshape(
bs, self.num_group, -1, 3)
smpl_verts = pred_output['vertices'].reshape(
bs, self.num_group, -1, 3)
# pred_vertices = pred_output['vertices'].reshape(bs, -1, 6890, 3)
outputs_smpl_pose_list.append(smpl_pose)
outputs_smpl_rhand_pose_list.append(smpl_rhand_pose)
outputs_smpl_lhand_pose_list.append(smpl_lhand_pose)
outputs_smpl_expr_list.append(smpl_expr)
outputs_smpl_jaw_pose_list.append(smpl_jaw_pose)
outputs_smpl_beta_list.append(smpl_beta)
outputs_smpl_cam_list.append(smpl_cam)
# outputs_smpl_cam_f_list.append(smpl_cam_f)
outputs_smpl_kp3d_list.append(smpl_kp3d)
else:
bs = layer_ref_sig.shape[0]
layer_hs_body_kpt = \
layer_hs[:, effective_dn_number:, :].index_select(
1, torch.tensor(body_kpt_index_2, device=layer_hs.device))
# body kp2d
delta_body_kp_xy_unsig = \
self.pose_embed[
dec_lid - self.num_box_decoder_layers](layer_hs_body_kpt)
layer_ref_sig_body_kpt = \
layer_ref_sig[:,effective_dn_number:, :].index_select(
1,torch.tensor(body_kpt_index_2,device=layer_hs.device))
layer_outputs_unsig_body_keypoints = \
delta_body_kp_xy_unsig + inverse_sigmoid(
layer_ref_sig_body_kpt[..., :2])
vis_xy_unsig = torch.ones_like(
layer_outputs_unsig_body_keypoints,
device=layer_outputs_unsig_body_keypoints.device)
xyv = torch.cat((layer_outputs_unsig_body_keypoints,
vis_xy_unsig[:, :, 0].unsqueeze(-1)),
dim=-1)
xyv = xyv.sigmoid()
layer_res = xyv.reshape(
(bs, self.num_group, self.num_body_points,
3)).flatten(2, 3)
layer_hw = layer_ref_sig_body_kpt[..., 2:].reshape(
bs, self.num_group, self.num_body_points, 2).flatten(2, 3)
layer_res = keypoint_xyzxyz_to_xyxyzz(layer_res)
outputs_body_keypoints_list.append(layer_res)
outputs_body_keypoints_hw.append(layer_hw)
# lhand bbox
layer_hs_lhand_bbox = \
layer_hs[:, effective_dn_number:, :][
:, (self.num_body_points + 1)::(self.num_whole_body_points + 4), :]
delta_lhand_bbox_xy_unsig = self.bbox_hand_embed[
dec_lid - self.num_box_decoder_layers](layer_hs_lhand_bbox)
layer_ref_sig_lhand_bbox = \
layer_ref_sig[:,effective_dn_number:, :][
:, (self.num_body_points + 1)::(self.num_whole_body_points + 4), :].clone()
layer_ref_unsig_lhand_bbox = inverse_sigmoid(layer_ref_sig_lhand_bbox)
delta_lhand_bbox_hw_unsig = self.bbox_hand_hw_embed[
dec_lid-self.num_box_decoder_layers](layer_hs_lhand_bbox)
layer_ref_unsig_lhand_bbox[..., :2] +=delta_lhand_bbox_xy_unsig[..., :2]
layer_ref_unsig_lhand_bbox[..., 2:] +=delta_lhand_bbox_hw_unsig
layer_ref_sig_lhand_bbox = layer_ref_unsig_lhand_bbox.sigmoid()
outputs_lhand_bbox_list.append(layer_ref_sig_lhand_bbox)
# lhand kps
layer_hs_lhand_kps_res = \
layer_hs[:, effective_dn_number:, :].index_select(
1, torch.tensor(lhand_kpt_index, device=layer_hs.device))
delta_lhand_kp_xy_unsig = \
self.pose_hand_embed[
dec_lid - self.num_hand_face_decoder_layers](layer_hs_lhand_kps_res)
layer_ref_sig_lhand_kpt = \
layer_ref_sig[:,effective_dn_number:, :].index_select(
1,torch.tensor(lhand_kpt_index,device=layer_hs.device))
layer_outputs_unsig_lhand_keypoints = delta_lhand_kp_xy_unsig + inverse_sigmoid(
layer_ref_sig_lhand_kpt[..., :2])
lhand_vis_xy_unsig = torch.ones_like(
layer_outputs_unsig_lhand_keypoints,
device=layer_outputs_unsig_lhand_keypoints.device)
lhand_xyv = torch.cat((layer_outputs_unsig_lhand_keypoints,
lhand_vis_xy_unsig[:, :, 0].unsqueeze(-1)),
dim=-1)
lhand_xyv = lhand_xyv.sigmoid()
layer_lhand_kps_res = lhand_xyv.reshape(
(bs, self.num_group, self.num_hand_points,
3)).flatten(2, 3)
layer_lhand_hw = layer_ref_sig_lhand_kpt[..., 2:].reshape(
bs, self.num_group, self.num_hand_points, 2).flatten(2, 3)
layer_lhand_kps_res = keypoint_xyzxyz_to_xyxyzz(layer_lhand_kps_res)
outputs_lhand_keypoints_list.append(layer_lhand_kps_res)
outputs_lhand_keypoints_hw.append(layer_lhand_hw)
# rhand bbox
layer_hs_rhand_bbox = \
layer_hs[:, effective_dn_number:, :][
:, (self.num_body_points + self.num_hand_points + 2)::(self.num_whole_body_points + 4), :]
delta_rhand_bbox_xy_unsig = self.bbox_hand_embed[
dec_lid - self.num_box_decoder_layers](layer_hs_rhand_bbox)
layer_ref_sig_rhand_bbox = \
layer_ref_sig[:,effective_dn_number:, :][
:, (self.num_body_points + self.num_hand_points + 2)::(self.num_whole_body_points + 4), :].clone()
layer_ref_unsig_rhand_bbox = inverse_sigmoid(layer_ref_sig_rhand_bbox)
delta_rhand_bbox_hw_unsig = self.bbox_hand_hw_embed[
dec_lid-self.num_box_decoder_layers](layer_hs_rhand_bbox)
layer_ref_unsig_rhand_bbox[..., :2] +=delta_rhand_bbox_xy_unsig[..., :2]
layer_ref_unsig_rhand_bbox[..., 2:] +=delta_rhand_bbox_hw_unsig
layer_ref_sig_rhand_bbox = layer_ref_unsig_rhand_bbox.sigmoid()
outputs_rhand_bbox_list.append(layer_ref_sig_rhand_bbox)
# rhand kps
layer_hs_rhand_kps_res = \
layer_hs[:, effective_dn_number:, :].index_select(
1, torch.tensor(rhand_kpt_index, device=layer_hs.device))
delta_rhand_kp_xy_unsig = \
self.pose_hand_embed[
dec_lid - self.num_hand_face_decoder_layers](layer_hs_rhand_kps_res)
layer_ref_sig_rhand_kpt = \
layer_ref_sig[:,effective_dn_number:, :].index_select(
1,torch.tensor(rhand_kpt_index,device=layer_hs.device))
layer_outputs_unsig_rhand_keypoints = delta_rhand_kp_xy_unsig + inverse_sigmoid(
layer_ref_sig_rhand_kpt[..., :2])
rhand_vis_xy_unsig = torch.ones_like(
layer_outputs_unsig_rhand_keypoints,
device=layer_outputs_unsig_rhand_keypoints.device)
rhand_xyv = torch.cat((layer_outputs_unsig_rhand_keypoints,
rhand_vis_xy_unsig[:, :, 0].unsqueeze(-1)),
dim=-1)
rhand_xyv = rhand_xyv.sigmoid()
layer_rhand_kps_res = rhand_xyv.reshape(
(bs, self.num_group, self.num_hand_points,
3)).flatten(2, 3)
layer_rhand_hw = layer_ref_sig_rhand_kpt[..., 2:].reshape(
bs, self.num_group, self.num_hand_points, 2).flatten(2, 3)
layer_rhand_kps_res = keypoint_xyzxyz_to_xyxyzz(layer_rhand_kps_res)
outputs_rhand_keypoints_list.append(layer_rhand_kps_res)
outputs_rhand_keypoints_hw.append(layer_rhand_hw)
# face bbox
layer_hs_face_bbox = \
layer_hs[:, effective_dn_number:, :][
:, (self.num_body_points + 2 * self.num_hand_points + 3)::(self.num_whole_body_points + 4), :]
delta_face_bbox_xy_unsig = self.bbox_face_embed[dec_lid - self.num_box_decoder_layers](layer_hs_face_bbox)
layer_ref_sig_face_bbox = \
layer_ref_sig[:,effective_dn_number:, :][
:, (self.num_body_points + 2 * self.num_hand_points + 3)::(self.num_whole_body_points + 4), :].clone()
layer_ref_unsig_face_bbox = inverse_sigmoid(layer_ref_sig_face_bbox)
delta_face_bbox_hw_unsig = self.bbox_face_hw_embed[
dec_lid-self.num_box_decoder_layers](layer_hs_face_bbox)
layer_ref_unsig_face_bbox[..., :2] +=delta_face_bbox_xy_unsig[..., :2]
layer_ref_unsig_face_bbox[..., 2:] +=delta_face_bbox_hw_unsig
layer_ref_sig_face_bbox = layer_ref_unsig_face_bbox.sigmoid()
outputs_face_bbox_list.append(layer_ref_sig_face_bbox)
# face kps
layer_hs_face_kps_res = \
layer_hs[:, effective_dn_number:, :].index_select(
1, torch.tensor(face_kpt_index, device=layer_hs.device))
delta_face_kp_xy_unsig = \
self.pose_face_embed[
dec_lid - self.num_hand_face_decoder_layers](layer_hs_face_kps_res)
layer_ref_sig_face_kpt = \
layer_ref_sig[:,effective_dn_number:, :].index_select(
1,torch.tensor(face_kpt_index,device=layer_hs.device))
layer_outputs_unsig_face_keypoints = delta_face_kp_xy_unsig + inverse_sigmoid(
layer_ref_sig_face_kpt[..., :2])
face_vis_xy_unsig = torch.ones_like(
layer_outputs_unsig_face_keypoints,
device=layer_outputs_unsig_face_keypoints.device)
face_xyv = torch.cat((layer_outputs_unsig_face_keypoints,
face_vis_xy_unsig[:, :, 0].unsqueeze(-1)),
dim=-1)
face_xyv = face_xyv.sigmoid()
layer_face_kps_res = face_xyv.reshape(
(bs, self.num_group, self.num_face_points,
3)).flatten(2, 3)
layer_face_hw = layer_ref_sig_face_kpt[..., 2:].reshape(
bs, self.num_group, self.num_face_points, 2).flatten(2, 3)
layer_face_kps_res = keypoint_xyzxyz_to_xyxyzz(layer_face_kps_res)
outputs_face_keypoints_list.append(layer_face_kps_res)
outputs_face_keypoints_hw.append(layer_face_hw)
# pdb.set_trace()
bs, _, feat_dim = layer_hs.shape
smpl_body_pose_feats = layer_hs[:, effective_dn_number:, :].index_select(
1, torch.tensor(smpl_pose_index, device=layer_hs.device)
).reshape(bs, -1, feat_dim * (self.num_body_points + 4))
smpl_lhand_pose_feats = layer_hs[:, effective_dn_number:, :].index_select(
1, torch.tensor(smpl_lhand_pose_index, device=layer_hs.device)
).reshape(bs, -1, feat_dim * (self.num_hand_points + 3))
smpl_rhand_pose_feats = layer_hs[:, effective_dn_number:, :].index_select(
1, torch.tensor(smpl_rhand_pose_index, device=layer_hs.device)
).reshape(bs, -1, feat_dim * (self.num_hand_points + 3))
smpl_face_pose_feats = layer_hs[:, effective_dn_number:, :].index_select(
1, torch.tensor(smpl_face_pose_index, device=layer_hs.device)
).reshape(bs, -1, feat_dim * (self.num_face_points + 2))
smpl_pose = self.smpl_pose_embed[
dec_lid - self.num_box_decoder_layers](smpl_body_pose_feats)
smpl_pose = rot6d_to_rotmat(smpl_pose.reshape(-1, 6)).reshape(
bs, self.num_group, self.body_model_joint_num, 3, 3)
smpl_lhand_pose = self.smpl_hand_pose_embed[
dec_lid - self.num_box_decoder_layers](smpl_lhand_pose_feats)
smpl_lhand_pose = rot6d_to_rotmat(smpl_lhand_pose.reshape(
-1, 6)).reshape(bs, self.num_group, 15, 3, 3)
smpl_rhand_pose = self.smpl_hand_pose_embed[
dec_lid - self.num_box_decoder_layers](smpl_rhand_pose_feats)
smpl_rhand_pose = rot6d_to_rotmat(smpl_rhand_pose.reshape(
-1, 6)).reshape(bs, self.num_group, 15, 3, 3)
smpl_expr = self.smpl_expr_embed[
dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats)
smpl_jaw_pose = self.smpl_jaw_embed[
dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats)
smpl_jaw_pose = rot6d_to_rotmat(smpl_jaw_pose.reshape(-1, 6)).reshape(
bs, self.num_group, 1, 3, 3)
smpl_beta = self.smpl_beta_embed[
dec_lid - self.num_box_decoder_layers](smpl_body_pose_feats)
smpl_cam = self.smpl_cam_embed[
dec_lid - self.num_box_decoder_layers](smpl_body_pose_feats)
# smpl_cam_f = self.smpl_cam_f_embed[
# dec_lid - self.num_box_decoder_layers](smpl_body_pose_feats)
num_samples = smpl_beta.reshape(-1, 10).shape[0]
device = smpl_beta.device
leye_pose = torch.zeros_like(smpl_jaw_pose)
reye_pose = torch.zeros_like(smpl_jaw_pose)
if self.body_model is not None:
# print(smpl_pose)
# exit()
smpl_pose_ = rotmat_to_aa(smpl_pose)
smpl_lhand_pose_ = rotmat_to_aa(smpl_lhand_pose)
smpl_rhand_pose_ = rotmat_to_aa(smpl_rhand_pose)
smpl_jaw_pose_ = rotmat_to_aa(smpl_jaw_pose)
leye_pose_ = rotmat_to_aa(leye_pose)
reye_pose_ = rotmat_to_aa(reye_pose)
pred_output = self.body_model(
betas=smpl_beta.reshape(-1, 10),
body_pose=smpl_pose_[:, :, 1:].reshape(-1, 21 * 3),
global_orient=smpl_pose_[:, :, 0].reshape(
-1, 3).unsqueeze(1),
left_hand_pose=smpl_lhand_pose_.reshape(-1, 15 * 3),
right_hand_pose=smpl_rhand_pose_.reshape(-1, 15 * 3),
leye_pose=leye_pose_,
reye_pose=reye_pose_,
jaw_pose=smpl_jaw_pose_.reshape(-1, 3),
expression=smpl_expr.reshape(-1, 10),
# expression=layer_hs.new_zeros(bs, self.num_group, 10).reshape(-1, 10),
)
smpl_kp3d = pred_output['joints'].reshape(
bs, self.num_group, -1, 3)
smpl_verts = pred_output['vertices'].reshape(
bs, self.num_group, -1, 3)
# pred_vertices = pred_output['vertices'].reshape(bs, -1, 6890, 3)
# from detrsmpl.core.visualization.visualize_keypoints3d import visualize_kp3d
# visualize_kp3d(smpl_kp3d[0,:100].detach().cpu().numpy(),
# output_path='./figs/pred3d',
# data_source='smplx_137')
# import numpy as np
# from pytorch3d.io import save_obj
# save_obj(
# '1.obj',
# torch.tensor(pred_output['vertices'][0]),
# torch.tensor(self.body_model.faces.astype(np.float)))
# exit()
outputs_smpl_pose_list.append(smpl_pose)
outputs_smpl_rhand_pose_list.append(smpl_rhand_pose)
outputs_smpl_lhand_pose_list.append(smpl_lhand_pose)
outputs_smpl_expr_list.append(smpl_expr)
outputs_smpl_jaw_pose_list.append(smpl_jaw_pose)
outputs_smpl_beta_list.append(smpl_beta)
outputs_smpl_cam_list.append(smpl_cam)
# outputs_smpl_cam_f_list.append(smpl_cam_f)
outputs_smpl_kp3d_list.append(smpl_kp3d)
if not self.training:
outputs_smpl_verts_list.append(smpl_verts)
dn_mask_dict = mask_dict
if self.dn_number > 0 and dn_mask_dict is not None:
outputs_class, outputs_body_bbox_list, outputs_body_keypoints_list = self.dn_post_process2(
outputs_class, outputs_body_bbox_list, outputs_body_keypoints_list,
dn_mask_dict)
dn_class_input = dn_mask_dict['known_labels']
dn_bbox_input = dn_mask_dict['known_bboxs']
dn_class_pred = dn_mask_dict['output_known_class']
dn_bbox_pred = dn_mask_dict['output_known_coord']
for idx, (_out_class, _out_bbox, _out_keypoint) in enumerate(
zip(outputs_class, outputs_body_bbox_list,
outputs_body_keypoints_list)):
assert _out_class.shape[1] == _out_bbox.shape[
1] == _out_keypoint.shape[1]
out = {
'pred_logits': outputs_class[-1],
'pred_boxes': outputs_body_bbox_list[-1],
'pred_lhand_boxes': outputs_lhand_bbox_list[-1],
'pred_rhand_boxes': outputs_rhand_bbox_list[-1],
'pred_face_boxes': outputs_face_bbox_list[-1],
'pred_keypoints': outputs_body_keypoints_list[-1],
'pred_lhand_keypoints': outputs_lhand_keypoints_list[-1],
'pred_rhand_keypoints': outputs_rhand_keypoints_list[-1],
'pred_face_keypoints': outputs_face_keypoints_list[-1],
'pred_smpl_pose': outputs_smpl_pose_list[-1],
'pred_smpl_rhand_pose': outputs_smpl_rhand_pose_list[-1],
'pred_smpl_lhand_pose': outputs_smpl_lhand_pose_list[-1],
'pred_smpl_jaw_pose': outputs_smpl_jaw_pose_list[-1],
'pred_smpl_expr': outputs_smpl_expr_list[-1],
'pred_smpl_beta': outputs_smpl_beta_list[-1], # [B, 100, 10]
'pred_smpl_cam': outputs_smpl_cam_list[-1],
# 'pred_smpl_cam_f': outputs_smpl_cam_f_list[-1],
'pred_smpl_kp3d': outputs_smpl_kp3d_list[-1]
}
if not self.training:
full_pose = torch.cat((outputs_smpl_pose_list[-1],
outputs_smpl_lhand_pose_list[-1],
outputs_smpl_rhand_pose_list[-1],
outputs_smpl_jaw_pose_list[-1]),dim=2)
bs,num_q,_,_,_ = full_pose.shape
full_pose = rotmat_to_aa(full_pose).reshape(bs,num_q,53*3)
out = {
'pred_logits': outputs_class[-1],
'pred_boxes': outputs_body_bbox_list[-1],
'pred_lhand_boxes': outputs_lhand_bbox_list[-1],
'pred_rhand_boxes': outputs_rhand_bbox_list[-1],
'pred_face_boxes': outputs_face_bbox_list[-1],
'pred_keypoints': outputs_body_keypoints_list[-1],
'pred_lhand_keypoints': outputs_lhand_keypoints_list[-1],
'pred_rhand_keypoints': outputs_rhand_keypoints_list[-1],
'pred_face_keypoints': outputs_face_keypoints_list[-1],
'pred_smpl_pose': outputs_smpl_pose_list[-1],
'pred_smpl_rhand_pose': outputs_smpl_rhand_pose_list[-1],
'pred_smpl_lhand_pose': outputs_smpl_lhand_pose_list[-1],
'pred_smpl_jaw_pose': outputs_smpl_jaw_pose_list[-1],
'pred_smpl_expr': outputs_smpl_expr_list[-1],
'pred_smpl_beta': outputs_smpl_beta_list[-1], # [B, 100, 10]
'pred_smpl_cam': outputs_smpl_cam_list[-1],
# 'pred_smpl_cam_f': outputs_smpl_cam_f_list[-1],
'pred_smpl_kp3d': outputs_smpl_kp3d_list[-1],
'pred_smpl_verts': outputs_smpl_verts_list[-1],
'pred_smpl_fullpose': full_pose
}
if self.dn_number > 0 and dn_mask_dict is not None:
out.update({
'dn_class_input': dn_class_input,
'dn_bbox_input': dn_bbox_input,
'dn_class_pred': dn_class_pred[-1],
'dn_bbox_pred': dn_bbox_pred[-1],
'num_tgt': dn_mask_dict['pad_size']
})
if self.aux_loss:
out['aux_outputs'] = \
self._set_aux_loss(
outputs_class,
outputs_body_bbox_list,
outputs_lhand_bbox_list,
outputs_rhand_bbox_list,
outputs_face_bbox_list,
outputs_body_keypoints_list,
outputs_lhand_keypoints_list,
outputs_rhand_keypoints_list,
outputs_face_keypoints_list,
outputs_smpl_pose_list,
outputs_smpl_rhand_pose_list,
outputs_smpl_lhand_pose_list,
outputs_smpl_jaw_pose_list,
outputs_smpl_expr_list,
outputs_smpl_beta_list,
outputs_smpl_cam_list,
# outputs_smpl_cam_f_list,
outputs_smpl_kp3d_list
) # with key pred_logits, pred_bbox, pred_keypoints
if self.dn_number > 0 and dn_mask_dict is not None:
assert len(dn_class_pred[:-1]) == len(
dn_bbox_pred[:-1]) == len(out['aux_outputs'])
for aux_out, dn_class_pred_i, dn_bbox_pred_i in zip(
out['aux_outputs'], dn_class_pred, dn_bbox_pred):
aux_out.update({
'dn_class_input': dn_class_input,
'dn_bbox_input': dn_bbox_input,
'dn_class_pred': dn_class_pred_i,
'dn_bbox_pred': dn_bbox_pred_i,
'num_tgt': dn_mask_dict['pad_size']
})
# for encoder output
if hs_enc is not None:
interm_coord = ref_enc[-1]
interm_class = self.transformer.enc_out_class_embed(hs_enc[-1])
interm_pose = torch.zeros_like(outputs_body_keypoints_list[0])
out['interm_outputs'] = {
'pred_logits': interm_class,
'pred_boxes': interm_coord,
'pred_keypoints': interm_pose
}
return out, targets, data_batch
@torch.jit.unused
def _set_aux_loss(self,
                  outputs_class,
                  outputs_body_coord,
                  outputs_lhand_coord,
                  outputs_rhand_coord,
                  outputs_face_coord,
                  outputs_body_keypoints,
                  outputs_lhand_keypoints,
                  outputs_rhand_keypoints,
                  outputs_face_keypoints,
                  outputs_smpl_pose,
                  outputs_smpl_rhand_pose,
                  outputs_smpl_lhand_pose,
                  outputs_smpl_jaw_pose,
                  outputs_smpl_expr,
                  outputs_smpl_beta,
                  outputs_smpl_cam,
                  # outputs_smpl_cam_f,
                  outputs_smpl_kp3d):
    """Bundle the auxiliary (all-but-last entry) predictions into dicts.

    Every argument is a sequence indexed along its first axis. The final
    entry of each sequence is dropped and the remaining entries are
    grouped position by position.

    Returns:
        list[dict]: one dict per remaining position, keyed
        ``pred_logits``, ``pred_boxes``, ..., ``pred_smpl_kp3d``.
    """
    # Keys are listed in the same order as the sequences they label below.
    keys = (
        'pred_logits',
        'pred_boxes',
        'pred_lhand_boxes',
        'pred_rhand_boxes',
        'pred_face_boxes',
        'pred_keypoints',
        'pred_lhand_keypoints',
        'pred_rhand_keypoints',
        'pred_face_keypoints',
        'pred_smpl_pose',
        'pred_smpl_rhand_pose',
        'pred_smpl_lhand_pose',
        'pred_smpl_jaw_pose',
        'pred_smpl_expr',
        'pred_smpl_beta',
        'pred_smpl_cam',
        # 'pred_smpl_cam_f',
        'pred_smpl_kp3d',
    )
    per_layer_outputs = (
        outputs_class,
        outputs_body_coord,
        outputs_lhand_coord,
        outputs_rhand_coord,
        outputs_face_coord,
        outputs_body_keypoints,
        outputs_lhand_keypoints,
        outputs_rhand_keypoints,
        outputs_face_keypoints,
        outputs_smpl_pose,
        outputs_smpl_rhand_pose,
        outputs_smpl_lhand_pose,
        outputs_smpl_jaw_pose,
        outputs_smpl_expr,
        outputs_smpl_beta,
        outputs_smpl_cam,
        outputs_smpl_kp3d,
    )
    trimmed = [seq[:-1] for seq in per_layer_outputs]
    return [dict(zip(keys, layer_preds)) for layer_preds in zip(*trimmed)]
def prepare_targets(self, data_batch):
    """Turn a collated batch into a masked image tensor and target dicts.

    Args:
        data_batch (dict): reads ``'img'``, ``'img_shape'``,
            ``'body_bbox_center'`` and ``'body_bbox_size'``. When
            ``self.inference`` is false it additionally reads the
            face / lhand / rhand ``*_bbox_center`` and ``*_bbox_size``
            entries, ``'joint_img'`` and ``'joint_trunc'``.

    Returns:
        tuple: ``(input_img, data_batch_coco)``. ``input_img`` is a
        ``NestedTensor`` built from the float images and their masks;
        ``data_batch_coco`` holds one target dict per image.
    """
    images = data_batch['img'].float()
    num_imgs, _, padded_h, padded_w = images.shape
    device = images.device
    # The mask starts fully set; the region covered by each image's
    # ``img_shape`` is cleared inside the loop.
    masks = torch.ones((num_imgs, padded_h, padded_w),
                       dtype=torch.bool,
                       device=device)

    def _center_size_box(center_key, size_key, idx):
        # Concatenate centre and size along the last axis.
        return torch.cat(
            [data_batch[center_key][idx], data_batch[size_key][idx]],
            dim=-1)

    def _normalize_and_flatten(kp2d):
        # Channel 0 is divided by cfg.output_hm_shape[2], channel 1 by
        # cfg.output_hm_shape[1]; the first two channels are then
        # flattened per instance and channel 2 is appended after them.
        kp2d[:, :, 0] = kp2d[:, :, 0] / cfg.output_hm_shape[2]
        kp2d[:, :, 1] = kp2d[:, :, 1] / cfg.output_hm_shape[1]
        return torch.cat([kp2d[:, :, :2].flatten(1), kp2d[:, :, 2]], dim=-1)

    data_batch_coco = []
    for img_id in range(num_imgs):
        img_h, img_w = data_batch['img_shape'][img_id]
        masks[img_id, :img_h, :img_w] = 0
        body_box = _center_size_box(
            'body_bbox_center', 'body_bbox_size', img_id)

        if self.inference:
            # Inference targets only carry the image size and body boxes.
            data_batch_coco.append({
                'size': data_batch['img_shape'][img_id],  # after augmentation
                'boxes': body_box.float(),
            })
            continue

        face_box = _center_size_box(
            'face_bbox_center', 'face_bbox_size', img_id)
        lhand_box = _center_size_box(
            'lhand_bbox_center', 'lhand_bbox_size', img_id)
        rhand_box = _center_size_box(
            'rhand_bbox_center', 'rhand_bbox_size', img_id)

        # Overwrite channels 2+ of the 2D joints with the truncation values.
        joints_2d = data_batch['joint_img'][img_id].clone().float()
        joints_2d[:, :, 2:] = data_batch['joint_trunc'][img_id].clone().float()

        # Run every convention conversion before any in-place normalisation.
        converted = [
            convert_kps(joints_2d, 'smplx_137', dst, approximate=True)[0]
            for dst in ('coco', 'smplx_lhand', 'smplx_rhand', 'smplx_face')
        ]
        body_kp2d, lhand_kp2d, rhand_kp2d, face_kp2d = [
            _normalize_and_flatten(part_kp) for part_kp in converted
        ]

        data_batch_coco.append({
            'boxes': body_box.float(),
            'face_boxes': face_box.float(),
            'lhand_boxes': lhand_box.float(),
            'rhand_boxes': rhand_box.float(),
            'keypoints': body_kp2d.float(),
            'lhand_keypoints': lhand_kp2d.float(),
            'rhand_keypoints': rhand_kp2d.float(),
            'face_keypoints': face_kp2d.float(),
            'size': data_batch['img_shape'][img_id],  # after augmentation
            # Areas are the product of box columns 2 and 3.
            'area': body_box[:, 2] * body_box[:, 3],
            'lhand_area': lhand_box[:, 2] * lhand_box[:, 3],
            'rhand_area': rhand_box[:, 2] * rhand_box[:, 3],
            'face_area': face_box[:, 2] * face_box[:, 3],
            'labels': torch.ones(body_box.shape[0],
                                 dtype=torch.long,
                                 device=device),
        })

    input_img = NestedTensor(images, masks)
    return input_img, data_batch_coco
def keypoints_to_scaled_bbox_bfh(
        self, keypoints, occ=None,
        body_scale=1.0, fh_scale=1.0,
        convention='smplx'):
    """Compute scaled boxes for the body, head and both hands.

    Args:
        keypoints (np.ndarray): Keypoints shaped ``(1, n, k)`` or
            ``(n, k)``; only the first two columns (x, y) are used.
        occ (np.ndarray, optional): Per-keypoint occlusion values indexed
            like ``keypoints`` (presumably 0/1 flags -- confirm with the
            caller). When given, a head/hand box gets confidence 0 if the
            summed occlusion over its keypoints is at least 10% of the
            keypoint count. The body box is never affected.
        body_scale (float): Factor applied to the body box width/height.
        fh_scale (float): Factor applied to the head and hand boxes.
        convention (str): Keypoint convention passed to
            ``get_keypoint_idxs_by_part``.

    Returns:
        list[np.ndarray]: Four float32 arrays
        ``[xmin, ymin, xmax, ymax, conf]`` ordered body, head,
        left hand, right hand.
    """
    bboxs = []

    # supported kps.shape: (1, n, k) or (n, k), k = 2 or 3
    if keypoints.ndim == 3:
        keypoints = keypoints[0]
    if keypoints.shape[-1] != 2:
        keypoints = keypoints[:, :2]

    for body_part in ('body', 'head', 'left_hand', 'right_hand'):
        conf = 1
        if body_part == 'body':
            # The body box spans every keypoint and always keeps conf 1.
            # It has no part index list, so the occlusion test must not
            # run here (it previously read an unassigned ``kp_id``).
            scale = body_scale
            kps = keypoints
        else:
            scale = fh_scale
            kp_id = get_keypoint_idxs_by_part(body_part,
                                              convention=convention)
            kps = keypoints[kp_id]
            if occ is not None and np.sum(occ[kp_id]) / len(kp_id) >= 0.1:
                conf = 0

        xmin, ymin = np.amin(kps, axis=0)
        xmax, ymax = np.amax(kps, axis=0)

        # Scale width/height about the centre of the tight box.
        width = (xmax - xmin) * scale
        height = (ymax - ymin) * scale
        x_center = 0.5 * (xmax + xmin)
        y_center = 0.5 * (ymax + ymin)
        xmin = x_center - 0.5 * width
        xmax = x_center + 0.5 * width
        ymin = y_center - 0.5 * height
        ymax = y_center + 0.5 * height

        bbox = np.stack([xmin, ymin, xmax, ymax, conf],
                        axis=0).astype(np.float32)
        bboxs.append(bbox)
    return bboxs
@MODULE_BUILD_FUNCS.registe_with_name(module_name='aios_smplx')
def build_aios_smplx(args, cfg):
    """Build the AiOS SMPL-X model, its criterion and its post-processors.

    Args:
        args: Option object read attribute-wise for the model
            hyper-parameters, loss coefficients and mode flags
            (``eval``, ``inference``, ``use_dn``, ``aux_loss``, ...).
            ``no_interm_box_loss`` and ``interm_loss_coef`` are optional
            and default to ``False`` and ``1.0``.
        cfg: Config object; only ``cfg.losses`` is read. It is copied,
            not modified.

    Returns:
        tuple: ``(model, criterion, postprocessors, postprocessors_aios)``.
    """
    num_classes = args.num_classes  # 2
    device = torch.device(args.device)
    backbone = build_backbone(args)
    transformer = build_transformer(args)

    dn_labelbook_size = args.dn_labelbook_size
    dec_pred_class_embed_share = args.dec_pred_class_embed_share
    dec_pred_bbox_embed_share = args.dec_pred_bbox_embed_share

    # Evaluation and training use different body-model configurations.
    if args.eval:
        body_model = args.body_model_test
        train = False
    else:
        body_model = args.body_model_train
        train = True

    model = AiOSSMPLX(
        backbone,
        transformer,
        num_classes=num_classes,  # 2
        num_queries=args.num_queries,  # 900
        aux_loss=True,
        iter_update=True,
        query_dim=4,
        random_refpoints_xy=args.random_refpoints_xy,  # False
        fix_refpoints_hw=args.fix_refpoints_hw,  # -1
        num_feature_levels=args.num_feature_levels,  # 4
        nheads=args.nheads,  # 8
        dec_pred_class_embed_share=dec_pred_class_embed_share,  # false
        dec_pred_bbox_embed_share=dec_pred_bbox_embed_share,  # False
        # two stage
        two_stage_type=args.two_stage_type,
        # box_share
        two_stage_bbox_embed_share=args.two_stage_bbox_embed_share,  # False
        two_stage_class_embed_share=args.two_stage_class_embed_share,  # False
        dn_number=args.dn_number if args.use_dn else 0,  # 100
        dn_box_noise_scale=args.dn_box_noise_scale,  # 0.4
        dn_label_noise_ratio=args.dn_label_noise_ratio,  # 0.5
        dn_batch_gt_fuse=args.dn_batch_gt_fuse,  # false
        dn_attn_mask_type_list=args.dn_attn_mask_type_list,
        dn_labelbook_size=dn_labelbook_size,  # 100
        cls_no_bias=args.cls_no_bias,  # False
        num_group=args.num_group,  # 100
        num_body_points=args.num_body_points,  # 17
        num_hand_points=args.num_hand_points,  # 17
        num_face_points=args.num_face_points,  # 17
        num_box_decoder_layers=args.num_box_decoder_layers,  # 2
        num_hand_face_decoder_layers=args.num_hand_face_decoder_layers,
        # smpl_convention=convention
        body_model=body_model,
        train=train,
        inference=args.inference)
    matcher = build_matcher(args)

    # prepare weight dict
    weight_dict = {
        'loss_ce': args.cls_loss_coef,  # 2
        # bbox
        'loss_body_bbox': args.body_bbox_loss_coef,  # 5
        'loss_rhand_bbox': args.rhand_bbox_loss_coef,  # 5
        'loss_lhand_bbox': args.lhand_bbox_loss_coef,  # 5
        'loss_face_bbox': args.face_bbox_loss_coef,  # 5
        # bbox giou
        'loss_body_giou': args.body_giou_loss_coef,  # 2
        'loss_rhand_giou': args.rhand_giou_loss_coef,  # 2
        'loss_lhand_giou': args.lhand_giou_loss_coef,  # 2
        'loss_face_giou': args.face_giou_loss_coef,  # 2
        # 2d kp
        'loss_keypoints': args.keypoints_loss_coef,  # 10
        'loss_rhand_keypoints': args.rhand_keypoints_loss_coef,  # 10
        'loss_lhand_keypoints': args.lhand_keypoints_loss_coef,  # 10
        'loss_face_keypoints': args.face_keypoints_loss_coef,  # 10
        # 2d kp oks
        'loss_oks': args.oks_loss_coef,  # 4
        'loss_rhand_oks': args.rhand_oks_loss_coef,  # 4
        'loss_lhand_oks': args.lhand_oks_loss_coef,  # 4
        'loss_face_oks': args.face_oks_loss_coef,  # 4
        # smpl param
        'loss_smpl_pose_root': args.smpl_pose_loss_root_coef,  # 0
        'loss_smpl_pose_body': args.smpl_pose_loss_body_coef,  # 0
        'loss_smpl_pose_lhand': args.smpl_pose_loss_lhand_coef,  # 0
        'loss_smpl_pose_rhand': args.smpl_pose_loss_rhand_coef,  # 0
        'loss_smpl_pose_jaw': args.smpl_pose_loss_jaw_coef,  # 0
        'loss_smpl_beta': args.smpl_beta_loss_coef,  # 0
        'loss_smpl_expr': args.smpl_expr_loss_coef,
        # smpl kp3d ra
        'loss_smpl_body_kp3d_ra': args.smpl_body_kp3d_ra_loss_coef,  # 0
        'loss_smpl_lhand_kp3d_ra': args.smpl_lhand_kp3d_ra_loss_coef,  # 0
        'loss_smpl_rhand_kp3d_ra': args.smpl_rhand_kp3d_ra_loss_coef,  # 0
        'loss_smpl_face_kp3d_ra': args.smpl_face_kp3d_ra_loss_coef,  # 0
        # smpl kp3d
        'loss_smpl_body_kp3d': args.smpl_body_kp3d_loss_coef,  # 0
        'loss_smpl_face_kp3d': args.smpl_face_kp3d_loss_coef,  # 0
        'loss_smpl_lhand_kp3d': args.smpl_lhand_kp3d_loss_coef,  # 0
        'loss_smpl_rhand_kp3d': args.smpl_rhand_kp3d_loss_coef,  # 0
        # smpl kp2d
        'loss_smpl_body_kp2d': args.smpl_body_kp2d_loss_coef,  # 0
        'loss_smpl_lhand_kp2d': args.smpl_lhand_kp2d_loss_coef,  # 0
        'loss_smpl_rhand_kp2d': args.smpl_rhand_kp2d_loss_coef,  # 0
        'loss_smpl_face_kp2d': args.smpl_face_kp2d_loss_coef,  # 0
        # smpl kp2d ba
        'loss_smpl_body_kp2d_ba': args.smpl_body_kp2d_ba_loss_coef,
        'loss_smpl_face_kp2d_ba': args.smpl_face_kp2d_ba_loss_coef,
        'loss_smpl_lhand_kp2d_ba': args.smpl_lhand_kp2d_ba_loss_coef,
        'loss_smpl_rhand_kp2d_ba': args.smpl_rhand_kp2d_ba_loss_coef,
    }

    # Snapshot taken before the denoising terms are added; it drives the
    # intermediate / query-expand weights further down.
    clean_weight_dict_wo_dn = copy.deepcopy(weight_dict)
    if args.use_dn:
        weight_dict.update({
            'dn_loss_ce':
            args.dn_label_coef,  # 0.3
            'dn_loss_bbox':
            args.bbox_loss_coef * args.dn_bbox_coef,  # 5 * 0.5
            'dn_loss_giou':
            args.giou_loss_coef * args.dn_bbox_coef,  # 2 * 0.5
        })
    clean_weight_dict = copy.deepcopy(weight_dict)

    if args.aux_loss:
        hand_face_box_losses = {
            'loss_rhand_bbox', 'loss_lhand_bbox', 'loss_face_bbox',
            'loss_rhand_giou', 'loss_lhand_giou', 'loss_face_giou'}
        hand_face_kp_losses = {
            'loss_rhand_keypoints', 'loss_lhand_keypoints',
            'loss_face_keypoints', 'loss_rhand_oks',
            'loss_lhand_oks', 'loss_face_oks'}
        aux_weight_dict = {}
        for i in range(args.dec_layers - 1):
            in_box_layers = i < args.num_box_decoder_layers
            in_hand_face_layers = i < args.num_hand_face_decoder_layers
            for k, v in clean_weight_dict.items():
                # Layers below num_box_decoder_layers get no keypoint, OKS,
                # hand/face box or SMPL terms.
                if in_box_layers and ('keypoints' in k or 'oks' in k
                                      or k in hand_face_box_losses
                                      or 'smpl' in k):
                    continue
                # Layers below num_hand_face_decoder_layers get no
                # hand/face keypoint or OKS terms.
                if in_hand_face_layers and k in hand_face_kp_losses:
                    continue
                aux_weight_dict[f'{k}_{i}'] = v
        weight_dict.update(aux_weight_dict)

    if args.two_stage_type != 'no':
        # Both options are optional on ``args``.
        no_interm_box_loss = getattr(args, 'no_interm_box_loss', False)
        interm_loss_coef = getattr(args, 'interm_loss_coef', 1.0)

        # Every term except the classification loss is switched off when
        # no_interm_box_loss is set.
        interm_scale = 0.0 if no_interm_box_loss else 1.0
        _coeff_weight_dict = {
            k: interm_scale for k in clean_weight_dict_wo_dn}
        _coeff_weight_dict['loss_ce'] = 1.0

        interm_weight_dict = {
            k + '_interm': v * interm_loss_coef * _coeff_weight_dict[k]
            for k, v in clean_weight_dict_wo_dn.items()
            if 'keypoints' not in k
        }
        interm_weight_dict.update({
            k + '_query_expand': v * interm_loss_coef * _coeff_weight_dict[k]
            for k, v in clean_weight_dict_wo_dn.items()
        })
        weight_dict.update(interm_weight_dict)

    # Copy so the shared config list is not extended in place; otherwise a
    # second build would see the entries appended below duplicated.
    losses = list(cfg.losses)
    if args.dn_number > 0:
        losses += ['dn_label', 'dn_bbox']
    losses += ['matching']

    criterion = SetCriterion(
        num_classes,
        matcher=matcher,
        weight_dict=weight_dict,
        focal_alpha=args.focal_alpha,
        losses=losses,
        num_box_decoder_layers=args.num_box_decoder_layers,
        num_hand_face_decoder_layers=args.num_hand_face_decoder_layers,
        num_body_points=args.num_body_points,
        num_hand_points=args.num_hand_points,
        num_face_points=args.num_face_points,
    )
    criterion.to(device)

    if args.inference:
        postprocessors = {
            'bbox':
            PostProcess_SMPLX_Multi_Infer(
                num_select=args.num_select,
                nms_iou_threshold=args.nms_iou_threshold,
                num_body_points=args.num_body_points),
        }
    else:
        postprocessors = {
            'bbox':
            PostProcess_SMPLX(
                num_select=args.num_select,
                nms_iou_threshold=args.nms_iou_threshold,
                num_body_points=args.num_body_points),
        }
    # Built unconditionally so the name is always bound at the return.
    postprocessors_aios = {
        'bbox':
        PostProcess_aios(num_select=args.num_select,
                         nms_iou_threshold=args.nms_iou_threshold,
                         num_body_points=args.num_body_points),
    }

    return model, criterion, postprocessors, postprocessors_aios
class AiOSSMPLX_Box(nn.Module):
def __init__(
self,
backbone,
transformer,
num_classes,
num_queries,
aux_loss=False,
iter_update=True,
query_dim=4,
random_refpoints_xy=False,
fix_refpoints_hw=-1,
num_feature_levels=1,
nheads=8,
two_stage_type='no',
dec_pred_class_embed_share=False,
dec_pred_bbox_embed_share=False,
dec_pred_pose_embed_share=False,
two_stage_class_embed_share=True,
two_stage_bbox_embed_share=True,
dn_number=100,
dn_box_noise_scale=0.4,
dn_label_noise_ratio=0.5,
dn_batch_gt_fuse=False,
dn_labelbook_size=100,
dn_attn_mask_type_list=['group2group'],
cls_no_bias=False,
num_group=100,
num_body_points=0,
num_hand_points=0,
num_face_points=0,
num_box_decoder_layers=2,
num_hand_face_decoder_layers=4,
body_model=dict(
type='smplx',
keypoint_src='smplx',
num_expression_coeffs=10,
keypoint_dst='smplx_137',
model_path='data/body_models/smplx',
use_pca=False,
use_face_contour=True),
train=True,
inference=False,
focal_length=[5000., 5000.],
camera_3d_size=2.5
):
super().__init__()
self.num_queries = num_queries
self.transformer = transformer
self.num_classes = num_classes
self.hidden_dim = hidden_dim = transformer.d_model
self.num_feature_levels = num_feature_levels
self.nheads = nheads
self.label_enc = nn.Embedding(dn_labelbook_size + 1, hidden_dim)
self.num_body_points = num_body_points
self.num_hand_points = num_hand_points
self.num_face_points = num_face_points
self.num_whole_body_points = num_body_points + 2*num_hand_points + num_face_points
self.num_box_decoder_layers = num_box_decoder_layers
self.num_hand_face_decoder_layers = num_hand_face_decoder_layers
self.focal_length = focal_length
self.camera_3d_size=camera_3d_size
self.inference = inference
if train:
self.smpl_convention = 'smplx'
else:
self.smpl_convention = 'h36m'
# setting query dim
self.query_dim = query_dim
assert query_dim == 4
self.random_refpoints_xy = random_refpoints_xy # False
self.fix_refpoints_hw = fix_refpoints_hw # -1
# for dn training
self.dn_number = dn_number
self.dn_box_noise_scale = dn_box_noise_scale
self.dn_label_noise_ratio = dn_label_noise_ratio
self.dn_batch_gt_fuse = dn_batch_gt_fuse
self.dn_labelbook_size = dn_labelbook_size
self.dn_attn_mask_type_list = dn_attn_mask_type_list
assert all([
i in ['match2dn', 'dn2dn', 'group2group']
for i in dn_attn_mask_type_list
])
assert not dn_batch_gt_fuse
# build human body
# if train:
# self.body_model = build_body_model(body_model)
if inference:
body_model=dict(
type='smplx',
keypoint_src='smplx',
num_expression_coeffs=10,
num_betas=10,
keypoint_dst='smplx',
model_path='data/body_models/smplx',
use_pca=False,
use_face_contour=True)
self.body_model = build_body_model(body_model)
for param in self.body_model.parameters():
param.requires_grad = False
# prepare input projection layers
if num_feature_levels > 1:
num_backbone_outs = len(backbone.num_channels) # 3
input_proj_list = []
for _ in range(num_backbone_outs):
in_channels = backbone.num_channels[_]
input_proj_list.append(
nn.Sequential(
nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
nn.GroupNorm(32, hidden_dim),
))
for _ in range(num_feature_levels - num_backbone_outs):
input_proj_list.append(
nn.Sequential(
nn.Conv2d(in_channels,
hidden_dim,
kernel_size=3,
stride=2,
padding=1),
nn.GroupNorm(32, hidden_dim),
))
in_channels = hidden_dim
self.input_proj = nn.ModuleList(input_proj_list)
else:
assert two_stage_type == 'no', 'two_stage_type should be no if num_feature_levels=1 !!!'
self.input_proj = nn.ModuleList([
nn.Sequential(
nn.Conv2d(backbone.num_channels[-1],
hidden_dim,
kernel_size=1),
nn.GroupNorm(32, hidden_dim),
)
])
self.backbone = backbone
self.aux_loss = aux_loss
self.box_pred_damping = box_pred_damping = None
self.iter_update = iter_update
assert iter_update, 'Why not iter_update?'
# prepare pred layers
self.dec_pred_class_embed_share = dec_pred_class_embed_share # false
self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share # false
# 1.1 prepare class & box embed
_class_embed = nn.Linear(hidden_dim,
num_classes,
bias=(not cls_no_bias))
if not cls_no_bias:
prior_prob = 0.01
bias_value = -math.log((1 - prior_prob) / prior_prob)
_class_embed.bias.data = torch.ones(self.num_classes) * bias_value
# 1.2 box embed layer list
if dec_pred_class_embed_share:
class_embed_layerlist = [
_class_embed for i in range(transformer.num_decoder_layers)
]
else:
class_embed_layerlist = [
copy.deepcopy(_class_embed)
for i in range(transformer.num_decoder_layers)
]
###########################################################################
# body bbox + l/r hand box + face box
###########################################################################
# 1.1 body bbox embed
_bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)
# 1.2 body bbox embed layer list
self.num_group = num_group
if dec_pred_bbox_embed_share:
box_body_embed_layerlist = [
_bbox_embed for i in range(transformer.num_decoder_layers)
]
else:
box_body_embed_layerlist = [
copy.deepcopy(_bbox_embed)
for i in range(transformer.num_decoder_layers)
]
# 2.1 lhand bbox embed
_bbox_hand_embed = MLP(hidden_dim, hidden_dim, 2, 3) # TODO: the out shape should be 2 not 4
nn.init.constant_(_bbox_hand_embed.layers[-1].weight.data, 0)
nn.init.constant_(_bbox_hand_embed.layers[-1].bias.data, 0)
_bbox_hand_hw_embed = MLP(hidden_dim, hidden_dim, 2, 3)
nn.init.constant_(_bbox_hand_hw_embed.layers[-1].weight.data, 0)
nn.init.constant_(_bbox_hand_hw_embed.layers[-1].bias.data, 0)
# 2.2 lhand bbox embed layer list
if dec_pred_pose_embed_share:
box_hand_embed_layerlist = \
[_bbox_hand_embed for i in range(transformer.num_decoder_layers - num_box_decoder_layers+1)]
else:
box_hand_embed_layerlist = [
copy.deepcopy(_bbox_hand_embed)
for i in range(transformer.num_decoder_layers -
num_box_decoder_layers + 1)
]
if dec_pred_pose_embed_share:
box_hand_hw_embed_layerlist = [
_bbox_hand_hw_embed for i in range(
transformer.num_decoder_layers - num_box_decoder_layers)
]
else:
box_hand_hw_embed_layerlist = [
copy.deepcopy(_bbox_hand_hw_embed)
for i in range(transformer.num_decoder_layers -
num_box_decoder_layers)
]
# 4.1 face bbox embed
_bbox_face_embed = MLP(hidden_dim, hidden_dim, 2, 3)
nn.init.constant_(_bbox_face_embed.layers[-1].weight.data, 0)
nn.init.constant_(_bbox_face_embed.layers[-1].bias.data, 0)
_bbox_face_hw_embed = MLP(hidden_dim, hidden_dim, 2, 3)
nn.init.constant_(_bbox_face_hw_embed.layers[-1].weight.data, 0)
nn.init.constant_(_bbox_face_hw_embed.layers[-1].bias.data, 0)
# 4.2 face bbox embed layer list
if dec_pred_pose_embed_share:
box_face_embed_layerlist = [
_bbox_face_embed for i in range(
transformer.num_decoder_layers - num_box_decoder_layers + 1)
]
else:
box_face_embed_layerlist = [
copy.deepcopy(_bbox_face_embed)
for i in range(transformer.num_decoder_layers -
num_box_decoder_layers + 1)
]
if dec_pred_pose_embed_share:
box_face_hw_embed_layerlist = [
_bbox_face_hw_embed for i in range(
transformer.num_decoder_layers - num_box_decoder_layers)]
else:
box_face_hw_embed_layerlist = [
copy.deepcopy(_bbox_face_hw_embed)
for i in range(transformer.num_decoder_layers -
num_box_decoder_layers)
]
# 1. smpl pose embed
if body_model['type'].upper()=='SMPL':
self.body_model_joint_num = 24
elif body_model['type'].upper()=='SMPLX':
self.body_model_joint_num = 22
else:
raise ValueError(
f'Only supports SMPL or SMPLX, but get {body_model.type}')
#TODO:
_smpl_pose_embed = MLP(hidden_dim * 4, hidden_dim, self.body_model_joint_num * 6, 3)
nn.init.constant_(_smpl_pose_embed.layers[-1].weight.data, 0)
nn.init.constant_(_smpl_pose_embed.layers[-1].bias.data, 0)
if dec_pred_bbox_embed_share:
smpl_pose_embed_layerlist = [
_smpl_pose_embed
for i in range(transformer.num_decoder_layers -
num_box_decoder_layers)
]
else:
smpl_pose_embed_layerlist = [
copy.deepcopy(_smpl_pose_embed)
for i in range(transformer.num_decoder_layers -
num_box_decoder_layers)
]
# 2. smpl betas embed
_smpl_beta_embed = MLP(hidden_dim * 4, hidden_dim, 10, 3)
nn.init.constant_(_smpl_beta_embed.layers[-1].weight.data, 0)
nn.init.constant_(_smpl_beta_embed.layers[-1].bias.data, 0)
if dec_pred_bbox_embed_share:
smpl_beta_embed_layerlist = [
_smpl_beta_embed
for i in range(transformer.num_decoder_layers -
num_box_decoder_layers)
]
else:
smpl_beta_embed_layerlist = [
copy.deepcopy(_smpl_beta_embed)
for i in range(transformer.num_decoder_layers -
num_box_decoder_layers)
]
# 3. smpl cam embed
_cam_embed = MLP(hidden_dim * 4, hidden_dim, 3, 3)
nn.init.constant_(_cam_embed.layers[-1].weight.data, 0)
nn.init.constant_(_cam_embed.layers[-1].bias.data, 0)
if dec_pred_bbox_embed_share:
cam_embed_layerlist = [
_cam_embed for i in range(transformer.num_decoder_layers -
num_box_decoder_layers)
]
else:
cam_embed_layerlist = [
copy.deepcopy(_cam_embed)
for i in range(transformer.num_decoder_layers -
num_box_decoder_layers)
]
###########################################################################
# smplx body pose + hand pose + expression + betas + kp2d + kp3d + cam
###########################################################################
# 2. smplx hand pose embed
_smplx_hand_pose_embed_layer_2_3 = \
MLP(hidden_dim * 2, hidden_dim, 15 * 6, 3)
nn.init.constant_(_smplx_hand_pose_embed_layer_2_3.layers[-1].weight.data, 0)
nn.init.constant_(_smplx_hand_pose_embed_layer_2_3.layers[-1].bias.data, 0)
_smplx_hand_pose_embed_layer_4_5 = \
MLP(hidden_dim * 2, hidden_dim, 15 * 6, 3)
nn.init.constant_(_smplx_hand_pose_embed_layer_4_5.layers[-1].weight.data, 0)
nn.init.constant_(_smplx_hand_pose_embed_layer_4_5.layers[-1].bias.data, 0)
if dec_pred_bbox_embed_share:
smplx_hand_pose_embed_layerlist = [
_smplx_hand_pose_embed_layer_2_3
if i<2 else _smplx_hand_pose_embed_layer_4_5
for i in range(transformer.num_decoder_layers -
num_box_decoder_layers)
]
else:
smplx_hand_pose_embed_layerlist = [
copy.deepcopy(_smplx_hand_pose_embed_layer_2_3)
if i<2 else copy.deepcopy(_smplx_hand_pose_embed_layer_4_5)
for i in range(transformer.num_decoder_layers -
num_box_decoder_layers)
]
# 3. smplx face expression
_smplx_expression_embed_layer_2_3 = \
MLP(hidden_dim*2, hidden_dim, 10, 3)
nn.init.constant_(_smplx_expression_embed_layer_2_3.layers[-1].weight.data, 0)
nn.init.constant_(_smplx_expression_embed_layer_2_3.layers[-1].bias.data, 0)
_smplx_expression_embed_layer_4_5 = \
MLP(hidden_dim * 2, hidden_dim, 10, 3)
nn.init.constant_(_smplx_expression_embed_layer_4_5.layers[-1].weight.data, 0)
nn.init.constant_(_smplx_expression_embed_layer_4_5.layers[-1].bias.data, 0)
if dec_pred_bbox_embed_share:
smplx_expression_embed_layerlist = [
_smplx_expression_embed_layer_2_3
if i<2 else _smplx_expression_embed_layer_4_5
for i in range(transformer.num_decoder_layers -
num_box_decoder_layers)
]
else:
smplx_expression_embed_layerlist = [
copy.deepcopy(_smplx_expression_embed_layer_2_3)
if i<2 else copy.deepcopy(_smplx_expression_embed_layer_4_5)
for i in range(transformer.num_decoder_layers -
num_box_decoder_layers)
]
# 4. smplx jaw pose embed
_smplx_jaw_embed_2_3 = MLP(hidden_dim * 2, hidden_dim, 6, 3)
nn.init.constant_(_smplx_jaw_embed_2_3.layers[-1].weight.data, 0)
nn.init.constant_(_smplx_jaw_embed_2_3.layers[-1].bias.data, 0)
_smplx_jaw_embed_4_5 = MLP(hidden_dim * 2, hidden_dim, 6, 3)
nn.init.constant_(_smplx_jaw_embed_4_5.layers[-1].weight.data, 0)
nn.init.constant_(_smplx_jaw_embed_4_5.layers[-1].bias.data, 0)
if dec_pred_bbox_embed_share:
smplx_jaw_embed_layerlist = [
_smplx_jaw_embed_2_3 if i<2 else _smplx_jaw_embed_4_5
for i in range(
transformer.num_decoder_layers - num_box_decoder_layers)
]
else:
smplx_jaw_embed_layerlist = [
copy.deepcopy(_smplx_jaw_embed_2_3)
if i<2 else copy.deepcopy(_smplx_jaw_embed_4_5)
for i in range(
transformer.num_decoder_layers - num_box_decoder_layers)
]
self.bbox_embed = nn.ModuleList(box_body_embed_layerlist)
self.class_embed = nn.ModuleList(class_embed_layerlist)
self.transformer.decoder.bbox_embed = self.bbox_embed
self.transformer.decoder.class_embed = self.class_embed
# smpl
self.smpl_pose_embed = nn.ModuleList(smpl_pose_embed_layerlist)
self.smpl_beta_embed = nn.ModuleList(smpl_beta_embed_layerlist)
self.smpl_cam_embed = nn.ModuleList(cam_embed_layerlist)
# smplx lhand kp
self.bbox_hand_embed = nn.ModuleList(box_hand_embed_layerlist)
self.bbox_hand_hw_embed = nn.ModuleList(box_hand_hw_embed_layerlist)
self.transformer.decoder.bbox_hand_embed = self.bbox_hand_embed
self.transformer.decoder.bbox_hand_hw_embed = self.bbox_hand_hw_embed
# smplx face kp
self.bbox_face_embed = nn.ModuleList(box_face_embed_layerlist)
self.bbox_face_hw_embed = nn.ModuleList(box_face_hw_embed_layerlist)
self.transformer.decoder.bbox_face_embed = self.bbox_face_embed
self.transformer.decoder.bbox_face_hw_embed = self.bbox_face_hw_embed
# smplx
self.smpl_hand_pose_embed = nn.ModuleList(smplx_hand_pose_embed_layerlist)
self.smpl_expr_embed = nn.ModuleList(smplx_expression_embed_layerlist)
self.smpl_jaw_embed = nn.ModuleList(smplx_jaw_embed_layerlist)
self.transformer.decoder.num_hand_face_decoder_layers = num_hand_face_decoder_layers
self.transformer.decoder.num_box_decoder_layers = num_box_decoder_layers
self.transformer.decoder.num_body_points = num_body_points
self.transformer.decoder.num_hand_points = num_hand_points
self.transformer.decoder.num_face_points = num_face_points
# two stage
self.two_stage_type = two_stage_type
assert two_stage_type in [
'no', 'standard'
], 'unknown param {} of two_stage_type'.format(two_stage_type)
if two_stage_type != 'no':
if two_stage_bbox_embed_share:
assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
self.transformer.enc_out_bbox_embed = _bbox_embed
else:
self.transformer.enc_out_bbox_embed = copy.deepcopy(
_bbox_embed)
if two_stage_class_embed_share:
assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
self.transformer.enc_out_class_embed = _class_embed
else:
self.transformer.enc_out_class_embed = copy.deepcopy(
_class_embed)
self.refpoint_embed = None
self._reset_parameters()
def get_camera_trans(self, cam_param, input_body_shape):
    """Turn raw camera predictions into a 3-D camera translation.

    Args:
        cam_param: Tensor indexed as ``[:, :2]`` (x/y translation, passed
            through unchanged) and ``[:, 2]`` (raw depth value that is
            squashed with a sigmoid).
        input_body_shape: Two-element sequence; only the product of its
            two entries is used, as the normalising area of the depth
            scale.

    Returns:
        Tensor of shape ``(N, 3)`` holding ``[t_x, t_y, t_z]`` per row,
        where ``t_z = k * sigmoid(cam_param[:, 2])`` and
        ``k = sqrt(fx * fy * camera_3d_size**2 / (h * w))``.

    Note:
        Reads ``self.focal_length`` (two scalars) and
        ``self.camera_3d_size``, which are set outside this method.
    """
    # Camera translation in the image plane is used as predicted.
    t_xy = cam_param[:, :2]
    # Sigmoid keeps the depth factor positive.
    gamma = torch.sigmoid(cam_param[:, 2])
    depth_scale = math.sqrt(
        self.focal_length[0] * self.focal_length[1] *
        self.camera_3d_size * self.camera_3d_size /
        (input_body_shape[0] * input_body_shape[1]))
    # Build the constant on the same device as the predictions instead of
    # hard-coding CUDA, so the method also works on CPU and on non-default
    # GPUs. float32 matches the original FloatTensor.
    k_value = torch.tensor([depth_scale],
                           dtype=torch.float32,
                           device=cam_param.device)
    t_z = k_value * gamma
    cam_trans = torch.cat((t_xy, t_z[:, None]), 1)
    return cam_trans
def _reset_parameters(self):
# init input_proj
for proj in self.input_proj:
nn.init.xavier_uniform_(proj[0].weight, gain=1)
nn.init.constant_(proj[0].bias, 0)
def prepare_for_dn2(self, targets):
    """Build denoising (DN) queries and decoder self-attention masks.

    All masks are boolean. Following the in-code intent ("should
    associate with" == ``False``), ``True`` marks a (query, key) pair that
    is blocked.

    Args:
        targets: Non-empty list with one dict per image. ``'boxes'`` is
            always read (for the device); in training mode ``'labels'``
            and ``'keypoints'`` are read as well.

    Returns:
        A 6-tuple ``(input_query_label, input_query_bbox, attn_mask,
        attn_mask2, attn_mask3, mask_dict)``.

        * Eval mode (``self.training`` is false): only ``attn_mask2`` and
          ``attn_mask3`` are tensors, each of shape
          ``(bs * nheads, num_group * 4, num_group * 4)``; the other four
          entries are ``None``.
        * Training mode: noisy label embeddings, inverse-sigmoid noisy
          boxes, three masks whose first ``dn_number`` rows/columns belong
          to the DN queries, and a dict with ``pad_size``,
          ``known_bboxs``, ``known_labels`` and ``known_keypoints``.
          If ``'group2group'`` is not in ``self.dn_attn_mask_type_list``
          the three masks are returned as ``None`` (the previous code
          raised ``UnboundLocalError`` in that case).

    Raises:
        NotImplementedError: if ``self.dn_batch_gt_fuse`` is set.
    """
    device = targets[0]['boxes'].device
    bs = len(targets)
    # Every person group owns four consecutive box queries; index 0 of a
    # group is the body query (1/2/3 are used as lhand/rhand/face later).
    group_size = 4
    num_group_queries = self.num_group * group_size

    def build_group_block():
        """Return the (Q, Q) mask between the grouped box queries."""
        group_id = torch.arange(num_group_queries,
                                device=device) // group_size
        # A query may only see queries of its own person group ...
        block = group_id[:, None] != group_id[None, :]
        # ... plus the body query of every group.
        block[:, ::group_size] = False
        return block

    if not self.training:
        attn_mask2 = build_group_block()[None].repeat(
            bs * self.nheads, 1, 1)
        # Keep the two masks as independent tensors, as before.
        attn_mask3 = attn_mask2.clone()
        return None, None, None, attn_mask2, attn_mask3, None

    dn_number = self.dn_number
    dn_box_noise_scale = self.dn_box_noise_scale
    dn_label_noise_ratio = self.dn_label_noise_ratio

    # gather gt boxes, labels and keypoints
    gt_boxes = [t['boxes'] for t in targets]
    gt_labels = [t['labels'] for t in targets]
    gt_keypoints = [t['keypoints'] for t in targets]

    def repeat_indices(now_num):
        """Indices that tile ``now_num`` GTs up to ``dn_number`` entries.

        Whole repetitions come first; the remainder is filled with
        randomly drawn GT indices.
        """
        base = torch.arange(now_num).to(device)
        tiled = base.repeat(dn_number // now_num)
        extra = base[torch.randint(0,
                                   now_num, (dn_number % now_num, ),
                                   device=device)]
        return torch.cat([tiled, extra])

    if self.dn_batch_gt_fuse:
        raise NotImplementedError

    gt_boxes_expand = []
    gt_labels_expand = []
    gt_keypoints_expand = []
    for gt_boxes_i, gt_labels_i, gt_keypoint_i in zip(
            gt_boxes, gt_labels, gt_keypoints):
        if gt_boxes_i.shape[0] > 0:
            indices = repeat_indices(gt_boxes_i.shape[0])
            gt_boxes_expand_i = gt_boxes_i[indices]  # num_dn, 4
            gt_labels_expand_i = gt_labels_i[indices]
            gt_keypoints_expand_i = gt_keypoint_i[indices]
        else:
            # all negative samples when the image has no gt boxes
            gt_boxes_expand_i = torch.rand(dn_number, 4, device=device)
            gt_labels_expand_i = torch.ones(
                dn_number, dtype=torch.int64, device=device) * int(
                    self.num_classes)
            gt_keypoints_expand_i = torch.rand(dn_number,
                                               self.num_body_points * 3,
                                               device=device)
        gt_boxes_expand.append(gt_boxes_expand_i)
        gt_labels_expand.append(gt_labels_expand_i)
        gt_keypoints_expand.append(gt_keypoints_expand_i)
    gt_boxes_expand = torch.stack(gt_boxes_expand)
    gt_labels_expand = torch.stack(gt_labels_expand)
    gt_keypoints_expand = torch.stack(gt_keypoints_expand)

    known_boxes_expand = gt_boxes_expand.clone()
    known_labels_expand = gt_labels_expand.clone()

    # add noise
    if dn_label_noise_ratio > 0:
        prob = torch.rand_like(known_labels_expand.float())
        chosen_indice = prob < dn_label_noise_ratio
        # randomly put a new label here
        new_label = torch.randint_like(known_labels_expand[chosen_indice],
                                       0, self.dn_labelbook_size)
        known_labels_expand[chosen_indice] = new_label
    if dn_box_noise_scale > 0:
        # Jitter the first two box entries by up to half of the last two,
        # and the last two by up to their own magnitude.
        diff = torch.zeros_like(known_boxes_expand)
        diff[..., :2] = known_boxes_expand[..., 2:] / 2
        diff[..., 2:] = known_boxes_expand[..., 2:]
        known_boxes_expand += torch.mul(
            (torch.rand_like(known_boxes_expand) * 2 - 1.0),
            diff) * dn_box_noise_scale
        known_boxes_expand = known_boxes_expand.clamp(min=0.0, max=1.0)

    input_query_label = self.label_enc(known_labels_expand)
    input_query_bbox = inverse_sigmoid(known_boxes_expand)

    # prepare masks
    attn_mask = attn_mask2 = attn_mask3 = None
    if 'group2group' in self.dn_attn_mask_type_list:
        # DN-to-DN block, one per image: a DN query may only see DN
        # queries that belong to the same repetition of the GT set.
        dn_block = torch.zeros(bs,
                               dn_number,
                               dn_number,
                               device=device,
                               dtype=torch.bool)
        dn_position = torch.arange(dn_number, device=device)
        for idx, gt_boxes_i in enumerate(gt_boxes):
            num_gt_i = gt_boxes_i.shape[0]
            if num_gt_i == 0:
                continue
            repetition = dn_position // num_gt_i
            dn_block[idx] = repetition[:, None] != repetition[None, :]

        def assemble(num_matching, matching_block=None):
            """Combine DN and matching parts into one flattened mask."""
            total = dn_number + num_matching
            mask = torch.zeros(bs,
                               self.nheads,
                               total,
                               total,
                               device=device,
                               dtype=torch.bool)
            # matching queries must not see the DN queries
            mask[:, :, dn_number:, :dn_number] = True
            if matching_block is not None:
                mask[:, :, dn_number:, dn_number:] = matching_block
            mask[:, :, :dn_number, :dn_number] = dn_block[:, None]
            return mask.flatten(0, 1)

        attn_mask = assemble(self.num_queries)
        attn_mask2 = assemble(num_group_queries, build_group_block())
        # attn_mask3 is built from exactly the same rules as attn_mask2.
        attn_mask3 = attn_mask2.clone()

    mask_dict = {
        'pad_size': dn_number,
        'known_bboxs': gt_boxes_expand,
        'known_labels': gt_labels_expand,
        'known_keypoints': gt_keypoints_expand
    }
    return input_query_label, input_query_bbox, attn_mask, attn_mask2, attn_mask3, mask_dict
def dn_post_process2(self, outputs_class, outputs_coord, mask_dict):
    """Separate denoising outputs from the regular query outputs.

    When ``mask_dict`` is truthy and its ``'pad_size'`` is positive, the
    first ``pad_size`` entries along dim 1 of every per-layer tensor are
    stored in ``mask_dict`` (keys ``'output_known_coord'`` and
    ``'output_known_class'``) and the remaining entries are returned.
    Otherwise both inputs are returned untouched.

    Returns:
        Tuple ``(outputs_class, outputs_coord)``.
    """
    if not mask_dict or not mask_dict['pad_size'] > 0:
        return outputs_class, outputs_coord

    pad = mask_dict['pad_size']

    def split(per_layer):
        # One pass: leading `pad` queries vs. everything after them.
        known, rest = [], []
        for layer_out in per_layer:
            known.append(layer_out[:, :pad, :])
            rest.append(layer_out[:, pad:, :])
        return known, rest

    known_class, remaining_class = split(outputs_class)
    known_coord, remaining_coord = split(outputs_coord)
    mask_dict.update({
        'output_known_coord': known_coord,
        'output_known_class': known_class
    })
    return remaining_class, remaining_coord
def forward(self, data_batch: NestedTensor, targets: List = None):
"""The forward expects a NestedTensor, which consists of:
- samples.tensor: batched images, of shape [batch_size x 3 x H x W]
- samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
It returns a dict with the following elements:
- "pred_logits": the classification logits (including no-object) for all queries.
Shape= [batch_size x num_queries x num_classes]
- "pred_boxes": The normalized boxes coordinates for all queries, represented as
(center_x, center_y, width, height). These values are normalized in [0, 1],
relative to the size of each individual image (disregarding possible padding).
See PostProcess for information on how to retrieve the unnormalized bounding box.
- "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
dictionaries containing the two above keys for each decoder layer.
"""
if isinstance(data_batch, dict):
samples, targets = self.prepare_targets(data_batch)
# import pdb; pdb.set_trace()
elif isinstance(data_batch, (list, torch.Tensor)):
samples = nested_tensor_from_tensor_list(data_batch)
else:
samples = data_batch
features, poss = self.backbone(samples)
srcs = []
masks = []
for l, feat in enumerate(features): # len(features=3)
src, mask = feat.decompose()
srcs.append(self.input_proj[l](src))
masks.append(mask)
assert mask is not None
if self.num_feature_levels > len(srcs):
_len_srcs = len(srcs)
for l in range(_len_srcs, self.num_feature_levels):
if l == _len_srcs:
src = self.input_proj[l](features[-1].tensors)
else:
src = self.input_proj[l](srcs[-1])
m = samples.mask
mask = F.interpolate(m[None].float(),
size=src.shape[-2:]).to(torch.bool)[0]
pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
srcs.append(src)
masks.append(mask)
poss.append(pos_l)
if self.dn_number > 0 or targets is not None:
input_query_label, input_query_bbox, attn_mask,attn_mask2, attn_mask3, mask_dict =\
self.prepare_for_dn2(targets)
else:
assert targets is None
input_query_bbox = input_query_label = attn_mask = attn_mask2 = attn_mask3 = mask_dict = None
hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(
srcs, masks, input_query_bbox, poss, input_query_label, attn_mask,
attn_mask2, attn_mask3)
# update human boxes
effective_dn_number = self.dn_number if self.training else 0
outputs_body_bbox_list = []
outputs_class = []
for dec_lid, (layer_ref_sig, layer_body_bbox_embed, layer_cls_embed,
layer_hs) in enumerate(
zip(reference[:-1], self.bbox_embed,
self.class_embed, hs)):
if dec_lid < self.num_box_decoder_layers:
# human det
layer_delta_unsig = layer_body_bbox_embed(layer_hs)
layer_body_box_outputs_unsig = \
layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
layer_body_box_outputs_unsig = layer_body_box_outputs_unsig.sigmoid()
layer_cls = layer_cls_embed(layer_hs)
outputs_body_bbox_list.append(layer_body_box_outputs_unsig)
outputs_class.append(layer_cls)
elif dec_lid < self.num_box_decoder_layers + 2:
bs = layer_ref_sig.shape[0]
# dn body bbox
layer_hs_body_bbox_dn = layer_hs[:, :effective_dn_number, :] # dn content query
reference_before_sigmoid_body_bbox_dn = layer_ref_sig[:, :effective_dn_number, :] # dn position query
layer_body_box_delta_unsig_dn = layer_body_bbox_embed(layer_hs_body_bbox_dn)
layer_body_box_outputs_unsig_dn = layer_body_box_delta_unsig_dn + inverse_sigmoid(
reference_before_sigmoid_body_bbox_dn)
layer_body_box_outputs_unsig_dn = layer_body_box_outputs_unsig_dn.sigmoid()
# norm body bbox
layer_hs_body_bbox_norm = layer_hs[:, effective_dn_number:, :][
:, 0::(self.num_body_points + 4), :] # norm content query
reference_before_sigmoid_body_bbox_norm = layer_ref_sig[:, effective_dn_number:, :][
:, 0::(self.num_body_points+ 4), :] # norm position query
layer_body_box_delta_unsig_norm = layer_body_bbox_embed(layer_hs_body_bbox_norm)
layer_body_box_outputs_unsig_norm = layer_body_box_delta_unsig_norm + inverse_sigmoid(
reference_before_sigmoid_body_bbox_norm)
layer_body_box_outputs_unsig_norm = layer_body_box_outputs_unsig_norm.sigmoid()
layer_body_box_outputs_unsig = torch.cat(
(layer_body_box_outputs_unsig_dn, layer_body_box_outputs_unsig_norm), dim=1)
# classification
layer_cls_dn = layer_cls_embed(layer_hs_body_bbox_dn)
layer_cls_norm = layer_cls_embed(layer_hs_body_bbox_norm)
layer_cls = torch.cat((layer_cls_dn, layer_cls_norm), dim=1)
outputs_class.append(layer_cls)
outputs_body_bbox_list.append(layer_body_box_outputs_unsig)
else:
bs = layer_ref_sig.shape[0]
# dn body bbox
layer_hs_body_bbox_dn = layer_hs[:, :effective_dn_number, :] # dn content query
reference_before_sigmoid_body_bbox_dn = layer_ref_sig[:, :effective_dn_number, :] # dn position query
layer_body_box_delta_unsig_dn = layer_body_bbox_embed(layer_hs_body_bbox_dn)
layer_body_box_outputs_unsig_dn = layer_body_box_delta_unsig_dn + inverse_sigmoid(
reference_before_sigmoid_body_bbox_dn)
layer_body_box_outputs_unsig_dn = layer_body_box_outputs_unsig_dn.sigmoid()
# norm body bbox
layer_hs_body_bbox_norm = layer_hs[:, effective_dn_number:, :][
:, 0::(self.num_whole_body_points + 4), :] # norm content query
reference_before_sigmoid_body_bbox_norm = layer_ref_sig[:,effective_dn_number:, :][
:, 0::(self.num_whole_body_points + 4), :] # norm position query
layer_body_box_delta_unsig_norm = layer_body_bbox_embed(layer_hs_body_bbox_norm)
layer_body_box_outputs_unsig_norm = layer_body_box_delta_unsig_norm + inverse_sigmoid(
reference_before_sigmoid_body_bbox_norm)
layer_body_box_outputs_unsig_norm = layer_body_box_outputs_unsig_norm.sigmoid()
layer_body_box_outputs_unsig = torch.cat(
(layer_body_box_outputs_unsig_dn, layer_body_box_outputs_unsig_norm), dim=1)
# classification
layer_cls_dn = layer_cls_embed(layer_hs_body_bbox_dn)
layer_cls_norm = layer_cls_embed(layer_hs_body_bbox_norm)
layer_cls = torch.cat((layer_cls_dn, layer_cls_norm), dim=1)
outputs_class.append(layer_cls)
outputs_body_bbox_list.append(layer_body_box_outputs_unsig)
# update hand and face boxes
outputs_lhand_bbox_list = []
outputs_rhand_bbox_list = []
outputs_face_bbox_list = []
# update keypoints boxes
outputs_body_keypoints_list = []
outputs_body_keypoints_hw = []
outputs_lhand_keypoints_list = []
outputs_lhand_keypoints_hw = []
outputs_rhand_keypoints_list = []
outputs_rhand_keypoints_hw = []
outputs_face_keypoints_list = []
outputs_face_keypoints_hw = []
outputs_smpl_pose_list = []
outputs_smpl_lhand_pose_list = []
outputs_smpl_rhand_pose_list = []
outputs_smpl_expr_list = []
outputs_smpl_jaw_pose_list = []
outputs_smpl_beta_list = []
outputs_smpl_cam_list = []
outputs_smpl_kp2d_list = []
outputs_smpl_kp3d_list = []
outputs_smpl_verts_list = []
# smpl pose
# body box, kps, lhand box
body_index = [0, 1, 2, 3]
smpl_pose_index = [
x for x in range(self.num_group * 4) if (x % 4 in body_index)]
# smpl lhand
lhand_index = [0, 1]
smpl_lhand_pose_index = [
x for x in range(self.num_group * 4) if (x % 4 in lhand_index)]
# smpl rhand
rhand_index = [0, 2]
smpl_rhand_pose_index = [
x for x in range(self.num_group * 4) if (x % 4 in rhand_index)]
# smpl face
face_index = [0, 3]
smpl_face_pose_index = [
x for x in range(self.num_group * 4) if (x % 4 in face_index)]
for dec_lid, (layer_ref_sig, layer_hs) in enumerate(zip(reference[:-1], hs)):
if dec_lid < self.num_box_decoder_layers:
assert isinstance(layer_hs, torch.Tensor)
bs = layer_hs.shape[0]
layer_body_kps_res = layer_hs.new_zeros(
(bs, self.num_queries,
self.num_body_points * 3)) # [-, 900, 42]
outputs_body_keypoints_list.append(layer_body_kps_res)
# lhand
layer_lhand_bbox_res = layer_hs.new_zeros(
(bs, self.num_queries, 4)) # [-, 900, 42]
outputs_lhand_bbox_list.append(layer_lhand_bbox_res)
layer_lhand_kps_res = layer_hs.new_zeros(
(bs, self.num_queries,
self.num_hand_points * 3)) # [-, 900, 42]
outputs_lhand_keypoints_list.append(layer_lhand_kps_res)
# rhand
layer_rhand_bbox_res = layer_hs.new_zeros(
(bs, self.num_queries, 4)) # [-, 900, 42]
outputs_rhand_bbox_list.append(layer_rhand_bbox_res)
layer_rhand_kps_res = layer_hs.new_zeros(
(bs, self.num_queries,
self.num_hand_points * 3)) # [-, 900, 42]
outputs_rhand_keypoints_list.append(layer_rhand_kps_res)
# face
layer_face_bbox_res = layer_hs.new_zeros(
(bs, self.num_queries, 4)) # [-, 900, 42]
outputs_face_bbox_list.append(layer_face_bbox_res)
layer_face_kps_res = layer_hs.new_zeros(
(bs, self.num_queries,
self.num_face_points * 3)) # [-, 900, 42]
outputs_face_keypoints_list.append(layer_face_kps_res)
# smpl or smplx
smpl_pose = layer_hs.new_zeros((bs, self.num_queries, self.body_model_joint_num * 3))
smpl_rhand_pose = layer_hs.new_zeros(
(bs, self.num_queries, 15 * 3))
smpl_lhand_pose = layer_hs.new_zeros(
(bs, self.num_queries, 15 * 3))
smpl_expr = layer_hs.new_zeros((bs, self.num_queries, 10))
smpl_jaw_pose = layer_hs.new_zeros((bs, self.num_queries, 6))
smpl_beta = layer_hs.new_zeros((bs, self.num_queries, 10))
smpl_cam = layer_hs.new_zeros((bs, self.num_queries, 3))
# smpl_kp2d = layer_hs.new_zeros((bs, self.num_queries, self.num_body_points,3))
smpl_kp3d = layer_hs.new_zeros(
(bs, self.num_queries, self.num_body_points, 4))
outputs_smpl_pose_list.append(smpl_pose)
outputs_smpl_rhand_pose_list.append(smpl_rhand_pose)
outputs_smpl_lhand_pose_list.append(smpl_lhand_pose)
outputs_smpl_expr_list.append(smpl_expr)
outputs_smpl_jaw_pose_list.append(smpl_jaw_pose)
outputs_smpl_beta_list.append(smpl_beta)
outputs_smpl_cam_list.append(smpl_cam)
# outputs_smpl_kp2d_list.append(smpl_kp2d)
outputs_smpl_kp3d_list.append(smpl_kp3d)
elif dec_lid < self.num_box_decoder_layers +2:
bs = layer_ref_sig.shape[0]
# lhand bbox
layer_hs_lhand_bbox = \
layer_hs[:, effective_dn_number:, :][:, 1::4, :]
delta_lhand_bbox_xy_unsig = self.bbox_hand_embed[dec_lid - self.num_box_decoder_layers](layer_hs_lhand_bbox)
layer_ref_sig_lhand_bbox = \
layer_ref_sig[:,effective_dn_number:, :][:, 1::4, :].clone()
layer_ref_unsig_lhand_bbox = inverse_sigmoid(layer_ref_sig_lhand_bbox)
delta_lhand_bbox_hw_unsig = self.bbox_hand_hw_embed[
dec_lid-self.num_box_decoder_layers](layer_hs_lhand_bbox)
layer_ref_unsig_lhand_bbox[..., :2] +=delta_lhand_bbox_xy_unsig[..., :2]
layer_ref_unsig_lhand_bbox[..., 2:] +=delta_lhand_bbox_hw_unsig
layer_ref_sig_lhand_bbox = layer_ref_unsig_lhand_bbox.sigmoid()
outputs_lhand_bbox_list.append(layer_ref_sig_lhand_bbox)
# rhand bbox
layer_hs_rhand_bbox = \
layer_hs[:, effective_dn_number:, :][:, 2::4, :]
delta_rhand_bbox_xy_unsig = self.bbox_hand_embed[
dec_lid - self.num_box_decoder_layers](layer_hs_rhand_bbox)
layer_ref_sig_rhand_bbox = \
layer_ref_sig[:,effective_dn_number:, :][:, 2::4, :].clone()
layer_ref_unsig_rhand_bbox = inverse_sigmoid(layer_ref_sig_rhand_bbox)
delta_rhand_bbox_hw_unsig = self.bbox_hand_hw_embed[
dec_lid-self.num_box_decoder_layers](layer_hs_rhand_bbox)
layer_ref_unsig_rhand_bbox[..., :2] +=delta_rhand_bbox_xy_unsig[..., :2]
layer_ref_unsig_rhand_bbox[..., 2:] +=delta_rhand_bbox_hw_unsig
layer_ref_sig_rhand_bbox = layer_ref_unsig_rhand_bbox.sigmoid()
outputs_rhand_bbox_list.append(layer_ref_sig_rhand_bbox)
# face bbox
layer_hs_face_bbox = \
layer_hs[:, effective_dn_number:, :][:, 3::4, :]
delta_face_bbox_xy_unsig = self.bbox_face_embed[
dec_lid - self.num_box_decoder_layers](layer_hs_face_bbox)
layer_ref_sig_face_bbox = \
layer_ref_sig[:,effective_dn_number:, :][:, 3::4, :].clone()
layer_ref_unsig_face_bbox = inverse_sigmoid(layer_ref_sig_face_bbox)
delta_face_bbox_hw_unsig = self.bbox_face_hw_embed[
dec_lid-self.num_box_decoder_layers](layer_hs_face_bbox)
layer_ref_unsig_face_bbox[..., :2] +=delta_face_bbox_xy_unsig[..., :2]
layer_ref_unsig_face_bbox[..., 2:] +=delta_face_bbox_hw_unsig
layer_ref_sig_face_bbox = layer_ref_unsig_face_bbox.sigmoid()
outputs_face_bbox_list.append(layer_ref_sig_face_bbox)
# smpl or smplx
bs, _, feat_dim = layer_hs.shape
smpl_feats = layer_hs[:, effective_dn_number:, :].index_select(
1, torch.tensor(smpl_pose_index, device=layer_hs.device)
).reshape(bs, -1, feat_dim * 4)
smpl_lhand_pose_feats = \
layer_hs[:, effective_dn_number:, :].index_select(
1, torch.tensor(smpl_lhand_pose_index, device=layer_hs.device)
).reshape(bs, -1, feat_dim * 2)
smpl_rhand_pose_feats = layer_hs[:, effective_dn_number:, :].index_select(
1, torch.tensor(smpl_rhand_pose_index, device=layer_hs.device)
).reshape(bs, -1, feat_dim * 2)
smpl_face_pose_feats = layer_hs[:, effective_dn_number:, :].index_select(
1, torch.tensor(smpl_face_pose_index, device=layer_hs.device)
).reshape(bs, -1, feat_dim * 2)
smpl_pose = self.smpl_pose_embed[
dec_lid - self.num_box_decoder_layers](smpl_feats)
smpl_pose = rot6d_to_rotmat(smpl_pose.reshape(-1, 6)).reshape(
bs, self.num_group, self.body_model_joint_num, 3, 3)
smpl_lhand_pose = self.smpl_hand_pose_embed[
dec_lid - self.num_box_decoder_layers](smpl_lhand_pose_feats)
smpl_lhand_pose = rot6d_to_rotmat(smpl_lhand_pose.reshape(
-1, 6)).reshape(bs, self.num_group, 15, 3, 3)
smpl_rhand_pose = self.smpl_hand_pose_embed[
dec_lid - self.num_box_decoder_layers](smpl_rhand_pose_feats)
smpl_rhand_pose = rot6d_to_rotmat(smpl_rhand_pose.reshape(
-1, 6)).reshape(bs, self.num_group, 15, 3, 3)
smpl_jaw_pose = self.smpl_jaw_embed[
dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats)
smpl_jaw_pose = rot6d_to_rotmat(smpl_jaw_pose.reshape(-1, 6)).reshape(
bs, self.num_group, 1, 3, 3)
smpl_beta = self.smpl_beta_embed[
dec_lid - self.num_box_decoder_layers](smpl_feats)
smpl_cam = self.smpl_cam_embed[
dec_lid - self.num_box_decoder_layers](smpl_feats)
smpl_expr = self.smpl_expr_embed[
dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats)
# smpl_jaw_pose = layer_hs.new_zeros(bs, self.num_group, 3)
leye_pose = torch.zeros_like(smpl_jaw_pose)
reye_pose = torch.zeros_like(smpl_jaw_pose)
if self.body_model is not None:
smpl_pose_ = rotmat_to_aa(smpl_pose)
# smpl_lhand_pose_ = rotmat_to_aa(smpl_lhand_pose)
# smpl_rhand_pose_ = rotmat_to_aa(smpl_rhand_pose)
smpl_lhand_pose_ = layer_hs.new_zeros(bs, self.num_group, 15, 3)
smpl_rhand_pose_ = layer_hs.new_zeros(bs, self.num_group, 15, 3)
smpl_jaw_pose_ = rotmat_to_aa(smpl_jaw_pose)
leye_pose_ = rotmat_to_aa(leye_pose)
reye_pose_ = rotmat_to_aa(reye_pose)
pred_output = self.body_model(
betas=smpl_beta.reshape(-1, 10),
body_pose=smpl_pose_[:, :, 1:].reshape(-1, 21 * 3),
global_orient=smpl_pose_[:, :, 0].reshape(
-1, 3).unsqueeze(1),
left_hand_pose=smpl_lhand_pose_.reshape(-1, 15 * 3),
right_hand_pose=smpl_rhand_pose_.reshape(-1, 15 * 3),
leye_pose=leye_pose_,
reye_pose=reye_pose_,
jaw_pose=smpl_jaw_pose_.reshape(-1, 3),
# expression=smpl_expr.reshape(-1, 10),
expression=layer_hs.new_zeros(bs, self.num_group, 10).reshape(-1, 10)
)
smpl_kp3d = pred_output['joints'].reshape(
bs, self.num_group, -1, 3)
smpl_verts = pred_output['vertices'].reshape(
bs, self.num_group, -1, 3)
# pred_vertices = pred_output['vertices'].reshape(bs, -1, 6890, 3)
outputs_smpl_pose_list.append(smpl_pose)
outputs_smpl_rhand_pose_list.append(smpl_rhand_pose)
outputs_smpl_lhand_pose_list.append(smpl_lhand_pose)
outputs_smpl_expr_list.append(smpl_expr)
outputs_smpl_jaw_pose_list.append(smpl_jaw_pose)
outputs_smpl_beta_list.append(smpl_beta)
outputs_smpl_cam_list.append(smpl_cam)
outputs_smpl_kp3d_list.append(smpl_kp3d)
else:
bs = layer_ref_sig.shape[0]
# lhand bbox
layer_hs_lhand_bbox = \
layer_hs[:, effective_dn_number:, :][:, 1::4, :]
delta_lhand_bbox_xy_unsig = self.bbox_hand_embed[
dec_lid - self.num_box_decoder_layers](layer_hs_lhand_bbox)
layer_ref_sig_lhand_bbox = \
layer_ref_sig[:,effective_dn_number:, :][:, 1::4, :].clone()
layer_ref_unsig_lhand_bbox = inverse_sigmoid(layer_ref_sig_lhand_bbox)
delta_lhand_bbox_hw_unsig = self.bbox_hand_hw_embed[
dec_lid-self.num_box_decoder_layers](layer_hs_lhand_bbox)
layer_ref_unsig_lhand_bbox[..., :2] +=delta_lhand_bbox_xy_unsig[..., :2]
layer_ref_unsig_lhand_bbox[..., 2:] +=delta_lhand_bbox_hw_unsig
layer_ref_sig_lhand_bbox = layer_ref_unsig_lhand_bbox.sigmoid()
outputs_lhand_bbox_list.append(layer_ref_sig_lhand_bbox)
# rhand bbox
layer_hs_rhand_bbox = \
layer_hs[:, effective_dn_number:, :][:, 2::4, :]
delta_rhand_bbox_xy_unsig = self.bbox_hand_embed[
dec_lid - self.num_box_decoder_layers](layer_hs_rhand_bbox)
layer_ref_sig_rhand_bbox = \
layer_ref_sig[:,effective_dn_number:, :][:, 2::4, :].clone()
layer_ref_unsig_rhand_bbox = inverse_sigmoid(layer_ref_sig_rhand_bbox)
delta_rhand_bbox_hw_unsig = self.bbox_hand_hw_embed[
dec_lid-self.num_box_decoder_layers](layer_hs_rhand_bbox)
layer_ref_unsig_rhand_bbox[..., :2] +=delta_rhand_bbox_xy_unsig[..., :2]
layer_ref_unsig_rhand_bbox[..., 2:] +=delta_rhand_bbox_hw_unsig
layer_ref_sig_rhand_bbox = layer_ref_unsig_rhand_bbox.sigmoid()
outputs_rhand_bbox_list.append(layer_ref_sig_rhand_bbox)
# face bbox
layer_hs_face_bbox = \
layer_hs[:, effective_dn_number:, :][:, 3::4, :]
delta_face_bbox_xy_unsig = \
self.bbox_face_embed[dec_lid - self.num_box_decoder_layers](layer_hs_face_bbox)
layer_ref_sig_face_bbox = \
layer_ref_sig[:,effective_dn_number:, :][:, 3::4, :].clone()
layer_ref_unsig_face_bbox = inverse_sigmoid(layer_ref_sig_face_bbox)
delta_face_bbox_hw_unsig = self.bbox_face_hw_embed[
dec_lid-self.num_box_decoder_layers](layer_hs_face_bbox)
layer_ref_unsig_face_bbox[..., :2] +=delta_face_bbox_xy_unsig[..., :2]
layer_ref_unsig_face_bbox[..., 2:] +=delta_face_bbox_hw_unsig
layer_ref_sig_face_bbox = layer_ref_unsig_face_bbox.sigmoid()
outputs_face_bbox_list.append(layer_ref_sig_face_bbox)
bs, _, feat_dim = layer_hs.shape
smpl_body_pose_feats = layer_hs[:, effective_dn_number:, :].index_select(
1, torch.tensor(smpl_pose_index, device=layer_hs.device)
).reshape(bs, -1, feat_dim * 4)
smpl_lhand_pose_feats = layer_hs[:, effective_dn_number:, :].index_select(
1, torch.tensor(smpl_lhand_pose_index, device=layer_hs.device)
).reshape(bs, -1, feat_dim * 2)
smpl_rhand_pose_feats = layer_hs[:, effective_dn_number:, :].index_select(
1, torch.tensor(smpl_rhand_pose_index, device=layer_hs.device)
).reshape(bs, -1, feat_dim * 2)
smpl_face_pose_feats = layer_hs[:, effective_dn_number:, :].index_select(
1, torch.tensor(smpl_face_pose_index, device=layer_hs.device)
).reshape(bs, -1, feat_dim * 2)
smpl_pose = self.smpl_pose_embed[
dec_lid - self.num_box_decoder_layers](smpl_body_pose_feats)
smpl_pose = rot6d_to_rotmat(smpl_pose.reshape(-1, 6)).reshape(
bs, self.num_group, self.body_model_joint_num, 3, 3)
smpl_lhand_pose = self.smpl_hand_pose_embed[
dec_lid - self.num_box_decoder_layers](smpl_lhand_pose_feats)
smpl_lhand_pose = rot6d_to_rotmat(smpl_lhand_pose.reshape(
-1, 6)).reshape(bs, self.num_group, 15, 3, 3)
smpl_rhand_pose = self.smpl_hand_pose_embed[
dec_lid - self.num_box_decoder_layers](smpl_rhand_pose_feats)
smpl_rhand_pose = rot6d_to_rotmat(smpl_rhand_pose.reshape(
-1, 6)).reshape(bs, self.num_group, 15, 3, 3)
smpl_expr = self.smpl_expr_embed[
dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats)
smpl_jaw_pose = self.smpl_jaw_embed[
dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats)
smpl_jaw_pose = rot6d_to_rotmat(smpl_jaw_pose.reshape(-1, 6)).reshape(
bs, self.num_group, 1, 3, 3)
smpl_beta = self.smpl_beta_embed[
dec_lid - self.num_box_decoder_layers](smpl_body_pose_feats)
smpl_cam = self.smpl_cam_embed[
dec_lid - self.num_box_decoder_layers](smpl_body_pose_feats)
num_samples = smpl_beta.reshape(-1, 10).shape[0]
device = smpl_beta.device
leye_pose = torch.zeros_like(smpl_jaw_pose)
reye_pose = torch.zeros_like(smpl_jaw_pose)
if self.body_model is not None:
smpl_pose_ = rotmat_to_aa(smpl_pose)
smpl_lhand_pose_ = rotmat_to_aa(smpl_lhand_pose)
smpl_rhand_pose_ = rotmat_to_aa(smpl_rhand_pose)
smpl_jaw_pose_ = rotmat_to_aa(smpl_jaw_pose)
leye_pose_ = rotmat_to_aa(leye_pose)
reye_pose_ = rotmat_to_aa(reye_pose)
pred_output = self.body_model(
betas=smpl_beta.reshape(-1, 10),
body_pose=smpl_pose_[:, :, 1:].reshape(-1, 21 * 3),
global_orient=smpl_pose_[:, :, 0].reshape(
-1, 3).unsqueeze(1),
left_hand_pose=smpl_lhand_pose_.reshape(-1, 15 * 3),
right_hand_pose=smpl_rhand_pose_.reshape(-1, 15 * 3),
leye_pose=leye_pose_,
reye_pose=reye_pose_,
jaw_pose=smpl_jaw_pose_.reshape(-1, 3),
expression=smpl_expr.reshape(-1, 10),
# expression=layer_hs.new_zeros(bs, self.num_group, 10).reshape(-1, 10),
)
smpl_kp3d = pred_output['joints'].reshape(
bs, self.num_group, -1, 3)
smpl_verts = pred_output['vertices'].reshape(
bs, self.num_group, -1, 3)
outputs_smpl_pose_list.append(smpl_pose)
outputs_smpl_rhand_pose_list.append(smpl_rhand_pose)
outputs_smpl_lhand_pose_list.append(smpl_lhand_pose)
outputs_smpl_expr_list.append(smpl_expr)
outputs_smpl_jaw_pose_list.append(smpl_jaw_pose)
outputs_smpl_beta_list.append(smpl_beta)
outputs_smpl_cam_list.append(smpl_cam)
outputs_smpl_kp3d_list.append(smpl_kp3d)
if not self.training:
outputs_smpl_verts_list.append(smpl_verts)
dn_mask_dict = mask_dict
if self.dn_number > 0 and dn_mask_dict is not None:
outputs_class, outputs_body_bbox_list = self.dn_post_process2(
outputs_class, outputs_body_bbox_list, dn_mask_dict)
dn_class_input = dn_mask_dict['known_labels']
dn_bbox_input = dn_mask_dict['known_bboxs']
dn_class_pred = dn_mask_dict['output_known_class']
dn_bbox_pred = dn_mask_dict['output_known_coord']
for idx, (_out_class, _out_bbox) in enumerate(zip(outputs_class, outputs_body_bbox_list)):
assert _out_class.shape[1] == _out_bbox.shape[1]
out = {
'pred_logits': outputs_class[-1],
'pred_boxes': outputs_body_bbox_list[-1],
'pred_lhand_boxes': outputs_lhand_bbox_list[-1],
'pred_rhand_boxes': outputs_rhand_bbox_list[-1],
'pred_face_boxes': outputs_face_bbox_list[-1],
'pred_smpl_pose': outputs_smpl_pose_list[-1],
'pred_smpl_rhand_pose': outputs_smpl_rhand_pose_list[-1],
'pred_smpl_lhand_pose': outputs_smpl_lhand_pose_list[-1],
'pred_smpl_jaw_pose': outputs_smpl_jaw_pose_list[-1],
'pred_smpl_expr': outputs_smpl_expr_list[-1],
'pred_smpl_beta': outputs_smpl_beta_list[-1], # [B, 100, 10]
'pred_smpl_cam': outputs_smpl_cam_list[-1],
'pred_smpl_kp3d': outputs_smpl_kp3d_list[-1]
}
if not self.training:
full_pose = torch.cat((outputs_smpl_pose_list[-1],
outputs_smpl_lhand_pose_list[-1],
outputs_smpl_rhand_pose_list[-1],
outputs_smpl_jaw_pose_list[-1]),dim=2)
bs,num_q,_,_,_ = full_pose.shape
full_pose = rotmat_to_aa(full_pose).reshape(bs,num_q,53*3)
out = {
'pred_logits': outputs_class[-1],
'pred_boxes': outputs_body_bbox_list[-1],
'pred_lhand_boxes': outputs_lhand_bbox_list[-1],
'pred_rhand_boxes': outputs_rhand_bbox_list[-1],
'pred_face_boxes': outputs_face_bbox_list[-1],
'pred_smpl_pose': outputs_smpl_pose_list[-1],
'pred_smpl_rhand_pose': outputs_smpl_rhand_pose_list[-1],
'pred_smpl_lhand_pose': outputs_smpl_lhand_pose_list[-1],
'pred_smpl_jaw_pose': outputs_smpl_jaw_pose_list[-1],
'pred_smpl_expr': outputs_smpl_expr_list[-1],
'pred_smpl_beta': outputs_smpl_beta_list[-1], # [B, 100, 10]
'pred_smpl_cam': outputs_smpl_cam_list[-1],
'pred_smpl_kp3d': outputs_smpl_kp3d_list[-1],
'pred_smpl_verts': outputs_smpl_verts_list[-1],
'pred_smpl_fullpose': full_pose
}
if self.dn_number > 0 and dn_mask_dict is not None:
out.update({
'dn_class_input': dn_class_input,
'dn_bbox_input': dn_bbox_input,
'dn_class_pred': dn_class_pred[-1],
'dn_bbox_pred': dn_bbox_pred[-1],
'num_tgt': dn_mask_dict['pad_size']
})
if self.aux_loss:
out['aux_outputs'] = \
self._set_aux_loss(
outputs_class,
outputs_body_bbox_list,
outputs_lhand_bbox_list,
outputs_rhand_bbox_list,
outputs_face_bbox_list,
outputs_smpl_pose_list,
outputs_smpl_rhand_pose_list,
outputs_smpl_lhand_pose_list,
outputs_smpl_jaw_pose_list,
outputs_smpl_expr_list,
outputs_smpl_beta_list,
outputs_smpl_cam_list,
outputs_smpl_kp3d_list
) # with key pred_logits, pred_bbox, pred_keypoints
if self.dn_number > 0 and dn_mask_dict is not None:
assert len(dn_class_pred[:-1]) == len(
dn_bbox_pred[:-1]) == len(out['aux_outputs'])
for aux_out, dn_class_pred_i, dn_bbox_pred_i in zip(
out['aux_outputs'], dn_class_pred, dn_bbox_pred):
aux_out.update({
'dn_class_input': dn_class_input,
'dn_bbox_input': dn_bbox_input,
'dn_class_pred': dn_class_pred_i,
'dn_bbox_pred': dn_bbox_pred_i,
'num_tgt': dn_mask_dict['pad_size']
})
# for encoder output
if hs_enc is not None:
interm_coord = ref_enc[-1]
interm_class = self.transformer.enc_out_class_embed(hs_enc[-1])
interm_pose = torch.zeros_like(outputs_body_keypoints_list[0])
out['interm_outputs'] = {
'pred_logits': interm_class,
'pred_boxes': interm_coord,
'pred_keypoints': interm_pose
}
return out, targets, data_batch
@torch.jit.unused
def _set_aux_loss(self,
outputs_class,
outputs_body_coord,
outputs_lhand_coord,
outputs_rhand_coord,
outputs_face_coord,
outputs_smpl_pose,
outputs_smpl_rhand_pose,
outputs_smpl_lhand_pose,
outputs_smpl_jaw_pose,
outputs_smpl_expr,
outputs_smpl_beta,
outputs_smpl_cam,
outputs_smpl_kp3d):
return [{
'pred_logits': a,
'pred_boxes': b,
'pred_lhand_boxes': c,
'pred_rhand_boxes': d,
'pred_face_boxes': e,
'pred_smpl_pose': j,
'pred_smpl_rhand_pose': k,
'pred_smpl_lhand_pose': l,
'pred_smpl_jaw_pose': m,
'pred_smpl_expr': n,
'pred_smpl_beta': o,
'pred_smpl_cam': p,
'pred_smpl_kp3d': q
} for a, b, c, d, e, j, k, l, m, n, o, p, q in zip(
outputs_class[:-1],
outputs_body_coord[:-1],
outputs_lhand_coord[:-1],
outputs_rhand_coord[:-1],
outputs_face_coord[:-1],
outputs_smpl_pose[:-1],
outputs_smpl_rhand_pose[:-1],
outputs_smpl_lhand_pose[:-1],
outputs_smpl_jaw_pose[:-1],
outputs_smpl_expr[:-1],
outputs_smpl_beta[:-1],
outputs_smpl_cam[:-1],
outputs_smpl_kp3d[:-1])]
def prepare_targets(self, data_batch):
    """Build the masked image batch and the per-image target dicts.

    Args:
        data_batch (dict): Reads ``'img'``, ``'img_shape'``,
            ``'body_bbox_center'`` and ``'body_bbox_size'``. When
            ``self.inference`` is false it additionally reads the
            ``face``/``lhand``/``rhand`` ``*_bbox_center`` and
            ``*_bbox_size`` entries plus ``'joint_img'`` and
            ``'joint_trunc'``.

    Returns:
        tuple: ``(input_img, data_batch_coco)`` where ``input_img`` is
        ``NestedTensor(img_list, masks)`` and ``data_batch_coco`` is a list
        with one target dict per image.
    """
    img_list = data_batch['img'].float()
    batch_size, _, input_img_h, input_img_w = img_list.shape
    device = img_list.device
    # Starts all True; the top-left (img_h, img_w) region of every image is
    # cleared to False inside the loop below.
    masks = torch.ones((batch_size, input_img_h, input_img_w),
                       dtype=torch.bool,
                       device=device)

    # (target key, convert_kps destination convention) in output order.
    # NOTE(review): body keypoints are always converted to 'coco', regardless
    # of self.num_body_points -- confirm this is intended for the 14-point
    # (crowdpose) configuration.
    kp2d_targets = (
        ('keypoints', 'coco'),
        ('lhand_keypoints', 'smplx_lhand'),
        ('rhand_keypoints', 'smplx_rhand'),
        ('face_keypoints', 'smplx_face'),
    )

    def _part_bbox(part, img_id):
        # Concatenate the stored center and size of one part's boxes.
        return torch.cat([data_batch[f'{part}_bbox_center'][img_id],
                          data_batch[f'{part}_bbox_size'][img_id]], dim=-1)

    def _normalize_and_flatten(kps):
        # Scale x / y by the heat-map width / height from cfg, then lay the
        # result out as [x0, y0, x1, y1, ..., v0, v1, ...] per instance.
        kps[:, :, 0] = kps[:, :, 0] / cfg.output_hm_shape[2]
        kps[:, :, 1] = kps[:, :, 1] / cfg.output_hm_shape[1]
        return torch.cat([kps[:, :, :2].flatten(1), kps[:, :, 2]], dim=-1)

    data_batch_coco = []
    for img_id in range(batch_size):
        img_h, img_w = data_batch['img_shape'][img_id]
        masks[img_id, :img_h, :img_w] = 0
        instance_body_bbox = _part_bbox('body', img_id)

        if self.inference:
            # Inference targets only carry the image size and body boxes.
            data_batch_coco.append({
                'size': data_batch['img_shape'][img_id],  # after augmentation
                'boxes': instance_body_bbox.float(),
            })
            continue

        instance_face_bbox = _part_bbox('face', img_id)
        instance_lhand_bbox = _part_bbox('lhand', img_id)
        instance_rhand_bbox = _part_bbox('rhand', img_id)

        # Overwrite everything from channel 2 onwards with the truncation
        # flags so they travel with the keypoints through convert_kps.
        instance_kp2d = data_batch['joint_img'][img_id].clone().float()
        instance_kp2d[:, :, 2:] = \
            data_batch['joint_trunc'][img_id].clone().float()

        # Convert every part first, then normalize, matching the original
        # order of operations.
        converted = [
            (key, convert_kps(instance_kp2d, 'smplx_137', convention,
                              approximate=True)[0])
            for key, convention in kp2d_targets
        ]

        instance_dict = {
            'boxes': instance_body_bbox.float(),
            'face_boxes': instance_face_bbox.float(),
            'lhand_boxes': instance_lhand_bbox.float(),
            'rhand_boxes': instance_rhand_bbox.float(),
        }
        for key, kps in converted:
            instance_dict[key] = _normalize_and_flatten(kps).float()
        instance_dict['size'] = data_batch['img_shape'][img_id]  # after augmentation
        instance_dict['area'] = \
            instance_body_bbox[:, 2] * instance_body_bbox[:, 3]
        instance_dict['lhand_area'] = \
            instance_lhand_bbox[:, 2] * instance_lhand_bbox[:, 3]
        instance_dict['rhand_area'] = \
            instance_rhand_bbox[:, 2] * instance_rhand_bbox[:, 3]
        instance_dict['face_area'] = \
            instance_face_bbox[:, 2] * instance_face_bbox[:, 3]
        # Every instance gets class label 1.
        instance_dict['labels'] = torch.ones(instance_body_bbox.shape[0],
                                             dtype=torch.long,
                                             device=device)
        data_batch_coco.append(instance_dict)

    input_img = NestedTensor(img_list, masks)
    return input_img, data_batch_coco
def keypoints_to_scaled_bbox_bfh(
        self, keypoints, occ=None,
        body_scale=1.0, fh_scale=1.0,
        convention='smplx'):
    '''Obtain scaled body, head and hand bboxes in xyxy format from keypoints.

    Args:
        keypoints (np.ndarray): Keypoints of shape (1, n, k) or (n, k);
            only the first two columns (x, y) are used.
        occ (np.ndarray, optional): Per-keypoint occlusion values, indexed
            with the same keypoint ids as ``keypoints``. When given, a
            head/hand box whose summed occlusion divided by its number of
            keypoints is >= 0.1 gets confidence 0.
        body_scale (float): Scale applied to the body box width/height.
        fh_scale (float): Scale applied to the head and hand boxes.
        convention (str): Keypoint convention passed to
            ``get_keypoint_idxs_by_part``.

    Returns:
        list[np.ndarray]: Four float32 arrays
        ``[xmin, ymin, xmax, ymax, conf]`` in the order body, head,
        left hand, right hand. The body confidence is always 1.
    '''
    bboxs = []
    # supported kps.shape: (1, n, k) or (n, k), k = 2 or 3
    if keypoints.ndim == 3:
        keypoints = keypoints[0]
    if keypoints.shape[-1] != 2:
        keypoints = keypoints[:, :2]

    for body_part in ['body', 'head', 'left_hand', 'right_hand']:
        conf = 1
        if body_part == 'body':
            scale = body_scale
            kps = keypoints
        else:
            scale = fh_scale
            kp_id = get_keypoint_idxs_by_part(body_part, convention=convention)
            kps = keypoints[kp_id]
            # The occlusion lookup needs kp_id, which only exists for the
            # head/hand parts; doing it for 'body' used an unbound name.
            if occ is not None and np.sum(occ[kp_id]) / len(kp_id) >= 0.1:
                conf = 0

        xmin, ymin = np.amin(kps, axis=0)
        xmax, ymax = np.amax(kps, axis=0)

        # Scale the tight box about its center.
        width = (xmax - xmin) * scale
        height = (ymax - ymin) * scale
        x_center = 0.5 * (xmax + xmin)
        y_center = 0.5 * (ymax + ymin)
        xmin = x_center - 0.5 * width
        xmax = x_center + 0.5 * width
        ymin = y_center - 0.5 * height
        ymax = y_center + 0.5 * height

        bbox = np.stack([xmin, ymin, xmax, ymax, conf],
                        axis=0).astype(np.float32)
        bboxs.append(bbox)
    return bboxs
@MODULE_BUILD_FUNCS.registe_with_name(module_name='aios_smplx_box')
def build_aios_smplx_box(args, cfg):
    """Build the box-only AiOS SMPL-X model, its criterion and postprocessors.

    Args:
        args: Option object read by attribute (``args.num_classes``,
            ``args.device``, the ``*_loss_coef`` values, ...). The
            attributes ``no_interm_box_loss`` and ``interm_loss_coef`` are
            optional and default to ``False`` and ``1.0``.
        cfg: Config object; only ``cfg.losses`` is read here. It is copied,
            not modified.

    Returns:
        tuple: ``(model, criterion, postprocessors, postprocessors_aios)``.
    """
    num_classes = args.num_classes
    device = torch.device(args.device)
    backbone = build_backbone(args)
    transformer = build_transformer(args)

    # Pick the body model matching the run mode.
    if args.eval:
        body_model = args.body_model_test
        train = False
    else:
        body_model = args.body_model_train
        train = True

    model = AiOSSMPLX_Box(
        backbone,
        transformer,
        num_classes=num_classes,
        num_queries=args.num_queries,
        aux_loss=True,
        iter_update=True,
        query_dim=4,
        random_refpoints_xy=args.random_refpoints_xy,
        fix_refpoints_hw=args.fix_refpoints_hw,
        num_feature_levels=args.num_feature_levels,
        nheads=args.nheads,
        dec_pred_class_embed_share=args.dec_pred_class_embed_share,
        dec_pred_bbox_embed_share=args.dec_pred_bbox_embed_share,
        # two stage
        two_stage_type=args.two_stage_type,
        # box_share
        two_stage_bbox_embed_share=args.two_stage_bbox_embed_share,
        two_stage_class_embed_share=args.two_stage_class_embed_share,
        # denoising is switched off entirely when use_dn is false
        dn_number=args.dn_number if args.use_dn else 0,
        dn_box_noise_scale=args.dn_box_noise_scale,
        dn_label_noise_ratio=args.dn_label_noise_ratio,
        dn_batch_gt_fuse=args.dn_batch_gt_fuse,
        dn_attn_mask_type_list=args.dn_attn_mask_type_list,
        dn_labelbook_size=args.dn_labelbook_size,
        cls_no_bias=args.cls_no_bias,
        num_group=args.num_group,
        # the box variant is built without keypoint queries
        num_body_points=0,
        num_hand_points=0,
        num_face_points=0,
        num_box_decoder_layers=args.num_box_decoder_layers,
        num_hand_face_decoder_layers=args.num_hand_face_decoder_layers,
        body_model=body_model,
        train=train,
        inference=args.inference)
    matcher = build_matcher(args)

    # prepare weight dict
    weight_dict = {
        'loss_ce': args.cls_loss_coef,
        # bbox
        'loss_body_bbox': args.body_bbox_loss_coef,
        'loss_rhand_bbox': args.rhand_bbox_loss_coef,
        'loss_lhand_bbox': args.lhand_bbox_loss_coef,
        'loss_face_bbox': args.face_bbox_loss_coef,
        # bbox giou
        'loss_body_giou': args.body_giou_loss_coef,
        'loss_rhand_giou': args.rhand_giou_loss_coef,
        'loss_lhand_giou': args.lhand_giou_loss_coef,
        'loss_face_giou': args.face_giou_loss_coef,
        # smpl param
        'loss_smpl_pose_root': args.smpl_pose_loss_root_coef,
        'loss_smpl_pose_body': args.smpl_pose_loss_body_coef,
        'loss_smpl_pose_lhand': args.smpl_pose_loss_lhand_coef,
        'loss_smpl_pose_rhand': args.smpl_pose_loss_rhand_coef,
        'loss_smpl_pose_jaw': args.smpl_pose_loss_jaw_coef,
        'loss_smpl_beta': args.smpl_beta_loss_coef,
        'loss_smpl_expr': args.smpl_expr_loss_coef,
        # smpl kp3d ra
        'loss_smpl_body_kp3d_ra': args.smpl_body_kp3d_ra_loss_coef,
        'loss_smpl_lhand_kp3d_ra': args.smpl_lhand_kp3d_ra_loss_coef,
        'loss_smpl_rhand_kp3d_ra': args.smpl_rhand_kp3d_ra_loss_coef,
        'loss_smpl_face_kp3d_ra': args.smpl_face_kp3d_ra_loss_coef,
        # smpl kp3d
        'loss_smpl_body_kp3d': args.smpl_body_kp3d_loss_coef,
        'loss_smpl_face_kp3d': args.smpl_face_kp3d_loss_coef,
        'loss_smpl_lhand_kp3d': args.smpl_lhand_kp3d_loss_coef,
        'loss_smpl_rhand_kp3d': args.smpl_rhand_kp3d_loss_coef,
        # smpl kp2d
        'loss_smpl_body_kp2d': args.smpl_body_kp2d_loss_coef,
        'loss_smpl_lhand_kp2d': args.smpl_lhand_kp2d_loss_coef,
        'loss_smpl_rhand_kp2d': args.smpl_rhand_kp2d_loss_coef,
        'loss_smpl_face_kp2d': args.smpl_face_kp2d_loss_coef,
    }
    # Snapshot before the denoising terms are added; used for the
    # intermediate / query-expand weights below.
    clean_weight_dict_wo_dn = copy.deepcopy(weight_dict)

    if args.use_dn:
        weight_dict.update({
            'dn_loss_ce': args.dn_label_coef,
            'dn_loss_bbox': args.bbox_loss_coef * args.dn_bbox_coef,
            'dn_loss_giou': args.giou_loss_coef * args.dn_bbox_coef,
        })
    # Snapshot including the denoising terms; replicated per auxiliary layer.
    clean_weight_dict = copy.deepcopy(weight_dict)

    if args.aux_loss:
        part_box_losses = (
            'loss_rhand_bbox', 'loss_lhand_bbox', 'loss_face_bbox',
            'loss_rhand_giou', 'loss_lhand_giou', 'loss_face_giou')
        part_keypoint_losses = (
            'loss_rhand_keypoints', 'loss_lhand_keypoints',
            'loss_face_keypoints', 'loss_rhand_oks',
            'loss_lhand_oks', 'loss_face_oks')
        aux_weight_dict = {}
        for i in range(args.dec_layers - 1):
            below_box_layers = i < args.num_box_decoder_layers
            below_hand_face_layers = i < args.num_hand_face_decoder_layers
            for k, v in clean_weight_dict.items():
                # Layers below num_box_decoder_layers get no keypoint/oks,
                # part-box or smpl weights.
                if below_box_layers and ('keypoints' in k or 'oks' in k
                                         or k in part_box_losses
                                         or 'smpl' in k):
                    continue
                if below_hand_face_layers and k in part_keypoint_losses:
                    continue
                aux_weight_dict[f'{k}_{i}'] = v
        weight_dict.update(aux_weight_dict)

    if args.two_stage_type != 'no':
        # Both options are optional on args; fall back to the defaults.
        no_interm_box_loss = getattr(args, 'no_interm_box_loss', False)
        interm_loss_coef = getattr(args, 'interm_loss_coef', 1.0)
        # 'loss_ce' always keeps coefficient 1.0; every other loss is zeroed
        # when no_interm_box_loss is set.
        other_coeff = 0.0 if no_interm_box_loss else 1.0
        _coeff_weight_dict = {
            k: (1.0 if k == 'loss_ce' else other_coeff)
            for k in clean_weight_dict_wo_dn
        }
        interm_weight_dict = {
            k + '_interm': v * interm_loss_coef * _coeff_weight_dict[k]
            for k, v in clean_weight_dict_wo_dn.items()
            if 'keypoints' not in k
        }
        interm_weight_dict.update({
            k + '_query_expand': v * interm_loss_coef * _coeff_weight_dict[k]
            for k, v in clean_weight_dict_wo_dn.items()
        })
        weight_dict.update(interm_weight_dict)

    # Copy so the extra entries are not appended to cfg.losses itself;
    # mutating the shared config made repeated builds accumulate duplicates.
    losses = list(cfg.losses)
    if args.dn_number > 0:
        losses += ['dn_label', 'dn_bbox']
    losses += ['matching']

    criterion = SetCriterion_Box(
        num_classes,
        matcher=matcher,
        weight_dict=weight_dict,
        focal_alpha=args.focal_alpha,
        losses=losses,
        num_box_decoder_layers=args.num_box_decoder_layers,
        num_hand_face_decoder_layers=args.num_hand_face_decoder_layers,
        num_body_points=0,
        num_hand_points=0,
        num_face_points=0,
    )
    criterion.to(device)

    # Inference uses its own post-processing class.
    if args.inference:
        postprocessor_cls = PostProcess_SMPLX_Multi_Infer_Box
    else:
        postprocessor_cls = PostProcess_SMPLX_Multi_Box
    postprocessors = {
        'bbox':
        postprocessor_cls(
            num_select=args.num_select,
            nms_iou_threshold=args.nms_iou_threshold,
            num_body_points=0),
    }
    postprocessors_aios = {
        'bbox':
        PostProcess_aios(num_select=args.num_select,
                         nms_iou_threshold=args.nms_iou_threshold,
                         num_body_points=0),
    }
    return model, criterion, postprocessors, postprocessors_aios