# AiOS/models/aios/postprocesses.py
import copy
import os
import math
from scipy.optimize import linear_sum_assignment
from typing import List
import torch
import torch.nn.functional as F
from torch import nn
from torchvision.ops.boxes import nms
from torch import Tensor
from pycocotools.coco import COCO
from util import box_ops
from util.misc import (NestedTensor, nested_tensor_from_tensor_list, accuracy,
get_world_size, interpolate,
is_dist_avail_and_initialized, inverse_sigmoid)
from detrsmpl.utils.demo_utils import convert_verts_to_cam_coord, xywh2xyxy, xyxy2xywh
import numpy as np
from detrsmpl.core.conventions.keypoints_mapping import convert_kps
from detrsmpl.models.body_models.builder import build_body_model
from detrsmpl.utils.geometry import batch_rodrigues, project_points, weak_perspective_projection, project_points_new
from util.human_models import smpl_x
from detrsmpl.core.conventions.keypoints_mapping import get_keypoint_idx
class PostProcess(nn.Module):
"""This module converts the model's output into the format expected by the
coco api."""
def __init__(self,
num_select=100,
nms_iou_threshold=-1,
num_body_points=17,
body_model=None) -> None:
super().__init__()
self.num_select = num_select
self.nms_iou_threshold = nms_iou_threshold
self.num_body_points = num_body_points
self.body_model = build_body_model(
dict(type='GenderedSMPL',
keypoint_src='h36m',
keypoint_dst='h36m',
model_path='data/body_models/smpl',
keypoint_approximate=True,
joints_regressor=
'data/body_models/J_regressor_h36m.npy'))
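# A gendered SMPL with an H36M joint regressor is built here so that, in
# forward below, predictions and GTs can be compared on the same H36M
# skeleton (see the H36M_TO_J14 mapping used for the 3D keypoint error).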
@torch.no_grad()
def forward(self,
outputs,
target_sizes,
targets,
data_batch_nc,
device,
not_to_xyxy=False,
test=False):
num_select = self.num_select
self.body_model.to(device)
out_logits, out_bbox, out_keypoints= \
outputs['pred_logits'], outputs['pred_boxes'], \
outputs['pred_keypoints']
out_smpl_pose, out_smpl_beta, out_smpl_cam, out_smpl_kp3d = \
outputs['pred_smpl_pose'], outputs['pred_smpl_beta'], \
outputs['pred_smpl_cam'], outputs['pred_smpl_kp3d']
assert len(out_logits) == len(target_sizes)
assert target_sizes.shape[1] == 2
prob = out_logits.sigmoid()
topk_values, topk_indexes = \
torch.topk(prob.view(out_logits.shape[0], -1), num_select, dim=1)
scores = topk_values
# bbox
topk_boxes = topk_indexes // out_logits.shape[2]
labels = topk_indexes % out_logits.shape[2]
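# The class probabilities were flattened over (num_queries * num_classes)
# before the top-k, so each top-k index encodes both the query and the
# class: floor division by the class count recovers the query index and
# the remainder is the label. A minimal sketch, assuming 3 queries and
# 2 classes:
#   flat_idx = 4  ->  query = 4 // 2 = 2, label = 4 % 2 = 0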
if not_to_xyxy:
boxes = out_bbox
else:
boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
if test:
assert not not_to_xyxy
boxes[:, :, 2:] = boxes[:, :, 2:] - boxes[:, :, :2]
boxes_norm = torch.gather(boxes, 1,
topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
target_sizes = target_sizes.type_as(boxes)
# from relative [0, 1] to absolute [0, height] coordinates
img_h, img_w = target_sizes.unbind(1)
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
boxes = boxes_norm * scale_fct[:, None, :]
# keypoints
topk_keypoints = topk_indexes // out_logits.shape[2]
labels = topk_indexes % out_logits.shape[2]
keypoints = torch.gather(
out_keypoints, 1,
topk_keypoints.unsqueeze(-1).repeat(1, 1,
self.num_body_points * 3))
Z_pred = keypoints[:, :, :(self.num_body_points * 2)]
V_pred = keypoints[:, :, (self.num_body_points * 2):]
img_h, img_w = target_sizes.unbind(1)
Z_pred = Z_pred * torch.stack([img_w, img_h], dim=1).repeat(
1, self.num_body_points)[:, None, :]
keypoints_res = torch.zeros_like(keypoints)
keypoints_res[..., 0::3] = Z_pred[..., 0::2]
keypoints_res[..., 1::3] = Z_pred[..., 1::2]
keypoints_res[..., 2::3] = V_pred[..., 0::1]
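# The keypoint head emits [x1, y1, ..., xN, yN, v1, ..., vN]: 2N
# normalized coordinates followed by N per-joint scores. After scaling
# x/y to pixels, the values are re-interleaved into COCO-style
# (x, y, v) triplets.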
# smpl out_smpl_pose, out_smpl_beta, out_smpl_cam, out_smpl_kp3d
topk_smpl = topk_indexes // out_logits.shape[2]
labels = topk_indexes % out_logits.shape[2]
smpl_pose = torch.gather(
out_smpl_pose, 1, topk_smpl[:, :, None, None,
None].repeat(1, 1, 24, 3, 3))
smpl_beta = torch.gather(out_smpl_beta, 1,
topk_smpl[:, :, None].repeat(1, 1, 10))
smpl_cam = torch.gather(out_smpl_cam, 1,
topk_smpl[:, :, None].repeat(1, 1, 3))
smpl_kp3d = torch.gather(
out_smpl_kp3d, 1,
topk_smpl[:, :, None, None].repeat(1, 1, out_smpl_kp3d.shape[-2],
3))
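# torch.gather requires the index tensor to match the rank of the
# source, so the (B, k) query indices are unsqueezed and repeated to
# each field's trailing shape, e.g. (B, k) -> (B, k, 24, 3, 3) for the
# per-joint rotation matrices.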
if False:  # debug visualization (disabled): render predicted vs. GT boxes and keypoints
import cv2
import mmcv
img = cv2.imread(data_batch_nc['img_metas'][0]['image_path'])
render_img = mmcv.imshow_bboxes(img.copy(),
boxes[0][:3].cpu().numpy(),
show=False)
cv2.imwrite('r_bbox.png', render_img)
gt_bbox_xyxy = xywh2xyxy(
data_batch_nc['bbox_xywh'][0].cpu().numpy())
render_img = mmcv.imshow_bboxes(img.copy(),
gt_bbox_xyxy,
show=False)
cv2.imwrite('r_bbox.png', render_img)
from detrsmpl.core.visualization.visualize_keypoints3d import visualize_kp3d
visualize_kp3d(smpl_kp3d[0][[0]].cpu().numpy(),
output_path='.',
data_source='smpl_54')
from detrsmpl.core.visualization.visualize_keypoints2d import visualize_kp2d
visualize_kp2d(keypoints_res[0].reshape(-1, 17,
3)[[0]].cpu().numpy(),
output_path='.',
image_array=img.copy()[None],
data_source='coco',
overwrite=True)
tgt_smpl_kp3d = data_batch_nc['keypoints3d_smpl']
tgt_smpl_pose = [
torch.concat([
data_batch_nc['smpl_global_orient'][i][:, None],
data_batch_nc['smpl_body_pose'][i]
],
dim=-2)
for i in range(len(data_batch_nc['smpl_body_pose']))
]
tgt_smpl_beta = data_batch_nc['smpl_betas']
tgt_keypoints = data_batch_nc['keypoints2d_ori']
tgt_bbox = data_batch_nc['bbox_xywh']
indices = []
# pred
pred_smpl_kp3d = []
pred_smpl_pose = []
pred_smpl_beta = []
pred_scores = []
pred_labels = []
pred_boxes = []
pred_keypoints = []
pred_smpl_cam = []
# gt
gt_smpl_kp3d = []
gt_smpl_pose = []
gt_smpl_beta = []
gt_boxes = []
gt_keypoints = []
image_idx = []
results = []
for i, kp3d in enumerate(tgt_smpl_kp3d):
# kp3d
conf = tgt_smpl_kp3d[i][..., [3]]
gt_kp3d = tgt_smpl_kp3d[i][..., :3]
pred_kp3d = smpl_kp3d[i]
gt_output = self.body_model(
betas=tgt_smpl_beta[i].float(),
body_pose=tgt_smpl_pose[i][:, 1:].float().reshape(-1, 69),
global_orient=tgt_smpl_pose[i][:, [0]].float().reshape(-1, 3),
gender=torch.zeros(tgt_smpl_beta[i].shape[0]),
pose2rot=True)
gt_kp3d = gt_output['joints']
# gt_kp3d,_ = convert_kps(
# gt_kp3d,
# src='smpl_54',
# dst='h36m',
# )
assert gt_kp3d.shape[-2] == 17
H36M_TO_J17 = [
6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9
]
H36M_TO_J14 = H36M_TO_J17[:14]
joint_mapper = H36M_TO_J14
pred_pelvis = pred_kp3d[:, 0]
gt_pelvis = gt_kp3d[:, 0]
gt_keypoints3d = gt_kp3d[:, joint_mapper, :]
pred_keypoints3d = pred_kp3d[:, joint_mapper, :]
pred_keypoints3d = (pred_keypoints3d -
pred_pelvis[:, None, :]) * 1000
gt_keypoints3d = (gt_keypoints3d - gt_pelvis[:, None, :]) * 1000
cost_kp3d = torch.abs((pred_keypoints3d[:, None] -
gt_keypoints3d[None])).sum([-2, -1])
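# Per-image Hungarian matching between the k kept predictions and the
# GT instances. Note that although an L1 pelvis-aligned keypoint cost
# is computed above, the assignment below is driven only by the
# negative box GIoU.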
tgt_bbox[i][..., 2] = tgt_bbox[i][..., 0] + tgt_bbox[i][..., 2]
tgt_bbox[i][..., 3] = tgt_bbox[i][..., 1] + tgt_bbox[i][..., 3]
gt_bbox = tgt_bbox[i][..., :4].float()
pred_bbox = boxes[i]
# box_iou = box_ops.box_iou(pred_bbox,gt_bbox)[0]
cost_giou = -box_ops.generalized_box_iou(pred_bbox, gt_bbox)
indice = linear_sum_assignment(cost_giou.cpu())
pred_ind, gt_ind = indice
indices.append(indice)
# bbox
# cost_bbox = torch.cdist(pred_bbox, gt_bbox, p=1)
# indice = linear_sum_assignment(cost_giou.cpu())
# pred_ind, gt_ind = indice
# indices.append(indice)
# pred
pred_scores.append(scores[i][pred_ind].detach().cpu().numpy())
pred_labels.append(labels[i][pred_ind].detach().cpu().numpy())
pred_boxes.append(boxes[i][pred_ind].detach().cpu().numpy())
pred_keypoints.append(
keypoints_res[i][pred_ind].detach().cpu().numpy())
pred_smpl_kp3d.append(
smpl_kp3d[i][pred_ind].detach().cpu().numpy())
pred_smpl_pose.append(
smpl_pose[i][pred_ind].detach().cpu().numpy())
pred_smpl_beta.append(
smpl_beta[i][pred_ind].detach().cpu().numpy())
pred_smpl_cam.append(smpl_cam[i][pred_ind].detach().cpu().numpy())
# gt
gt_smpl_kp3d.append(
tgt_smpl_kp3d[i][gt_ind].detach().cpu().numpy())
gt_smpl_pose.append(
tgt_smpl_pose[i][gt_ind].detach().cpu().numpy())
gt_smpl_beta.append(
tgt_smpl_beta[i][gt_ind].detach().cpu().numpy())
gt_boxes.append(tgt_bbox[i][gt_ind].detach().cpu().numpy())
gt_keypoints.append(
tgt_keypoints[i][gt_ind].detach().cpu().numpy())
image_idx.append(targets[i]['image_id'].detach().cpu().numpy())
results.append({
'scores': pred_scores,
'labels': pred_labels,
'boxes': pred_boxes,
'keypoints': pred_keypoints,
'pred_smpl_pose': pred_smpl_pose,
'pred_smpl_beta': pred_smpl_beta,
'pred_smpl_cam': pred_smpl_cam,
'pred_smpl_kp3d': pred_smpl_kp3d,
'gt_smpl_pose': gt_smpl_pose,
'gt_smpl_beta': gt_smpl_beta,
'gt_smpl_kp3d': gt_smpl_kp3d,
'gt_boxes': gt_boxes,
'gt_keypoints': gt_keypoints,
'image_idx': image_idx,
})
# results.append({
# 'scores': scores[i][pred_ind],
# 'labels': labels[i][pred_ind],
# 'boxes': boxes[i][pred_ind],
# 'keypoints': keypoints_res[i][pred_ind],
# 'pred_smpl_pose': smpl_pose[i][pred_ind],
# 'pred_smpl_beta': tgt_smpl_beta[i][gt_ind],
# 'pred_smpl_cam': smpl_cam[i][pred_ind],
# 'pred_smpl_kp3d': smpl_kp3d[i][pred_ind],
# 'gt_smpl_pose': tgt_smpl_pose[i][gt_ind],
# 'gt_smpl_beta': tgt_smpl_beta[i][gt_ind],
# 'gt_smpl_kp3d': tgt_smpl_kp3d[i][gt_ind],
# 'gt_boxes': tgt_bbox[i][gt_ind],
# 'gt_keypoints': tgt_keypoints[i][gt_ind],
# 'image_idx': targets[i]['image_id'],
# }
# )
if self.nms_iou_threshold > 0:
raise NotImplementedError
item_indices = [
nms(b, s, iou_threshold=self.nms_iou_threshold)
for b, s in zip(boxes, scores)
]
results = [{
'scores': s[i],
'labels': l[i],
'boxes': b[i]
} for s, l, b, i in zip(scores, labels, boxes, item_indices)]
return results
class PostProcess_aios(nn.Module):
"""This module converts the model's output into the format expected by the
coco api."""
def __init__(self,
num_select=100,
nms_iou_threshold=-1,
num_body_points=17) -> None:
super().__init__()
self.num_select = num_select
self.nms_iou_threshold = nms_iou_threshold
self.num_body_points = num_body_points
@torch.no_grad()
def forward(self, outputs, target_sizes, not_to_xyxy=False, test=False):
num_select = self.num_select
out_logits, out_bbox, out_keypoints = outputs['pred_logits'], outputs[
'pred_boxes'], outputs['pred_keypoints']
assert len(out_logits) == len(target_sizes)
assert target_sizes.shape[1] == 2
prob = out_logits.sigmoid()
topk_values, topk_indexes = torch.topk(prob.view(
out_logits.shape[0], -1),
num_select,
dim=1)
scores = topk_values
# bbox
topk_boxes = topk_indexes // out_logits.shape[2]
labels = topk_indexes % out_logits.shape[2]
if not_to_xyxy:
boxes = out_bbox
else:
boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
if test:
assert not not_to_xyxy
boxes[:, :, 2:] = boxes[:, :, 2:] - boxes[:, :, :2]
boxes = torch.gather(boxes, 1,
topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
# from relative [0, 1] to absolute [0, height] coordinates
img_h, img_w = target_sizes.unbind(1)
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
boxes = boxes * scale_fct[:, None, :]
# keypoints
topk_keypoints = topk_indexes // out_logits.shape[2]
labels = topk_indexes % out_logits.shape[2]
keypoints = torch.gather(
out_keypoints, 1,
topk_keypoints.unsqueeze(-1).repeat(1, 1,
self.num_body_points * 3))
Z_pred = keypoints[:, :, :(self.num_body_points * 2)]
V_pred = keypoints[:, :, (self.num_body_points * 2):]
img_h, img_w = target_sizes.unbind(1)
Z_pred = Z_pred * torch.stack([img_w, img_h], dim=1).repeat(
1, self.num_body_points)[:, None, :]
keypoints_res = torch.zeros_like(keypoints)
keypoints_res[..., 0::3] = Z_pred[..., 0::2]
keypoints_res[..., 1::3] = Z_pred[..., 1::2]
keypoints_res[..., 2::3] = V_pred[..., 0::1]
if self.nms_iou_threshold > 0:
raise NotImplementedError
item_indices = [
nms(b, s, iou_threshold=self.nms_iou_threshold)
for b, s in zip(boxes, scores)
]
results = [{
'scores': s[i],
'labels': l[i],
'boxes': b[i]
} for s, l, b, i in zip(scores, labels, boxes, item_indices)]
else:
results = [{
'scores': s,
'labels': l,
'boxes': b,
'keypoints': k
} for s, l, b, k in zip(scores, labels, boxes, keypoints_res)]
return results
class PostProcess_SMPLX(nn.Module):
""" This module converts the model's output into the format expected by the coco api"""
def __init__(
self,
num_select=100,
nms_iou_threshold=-1,
num_body_points=17,
body_model= dict(
type='smplx',
keypoint_src='smplx',
num_expression_coeffs=10,
keypoint_dst='smplx_137',
model_path='data/body_models/smplx',
use_pca=False,
use_face_contour=True)
) -> None:
super().__init__()
self.num_select = num_select
self.nms_iou_threshold = nms_iou_threshold
self.num_body_points=num_body_points
self.body_model = build_body_model(body_model)
@torch.no_grad()
def forward(self, outputs, target_sizes, targets, data_batch_nc, not_to_xyxy=False, test=False):
num_select = self.num_select
out_logits, out_bbox, out_keypoints= \
outputs['pred_logits'], outputs['pred_boxes'], \
outputs['pred_keypoints']
out_smpl_pose, out_smpl_beta, out_smpl_expr, out_smpl_cam, out_smpl_kp3d, out_smpl_verts = \
outputs['pred_smpl_fullpose'], outputs['pred_smpl_beta'], outputs['pred_smpl_expr'], \
outputs['pred_smpl_cam'], outputs['pred_smpl_kp3d'], outputs['pred_smpl_verts']
assert len(out_logits) == len(target_sizes)
assert target_sizes.shape[1] == 2
prob = out_logits.sigmoid()
topk_values, topk_indexes = \
torch.topk(prob.view(out_logits.shape[0], -1), num_select, dim=1)
scores = topk_values
# bbox
topk_boxes = topk_indexes // out_logits.shape[2]
labels = topk_indexes % out_logits.shape[2]
if not_to_xyxy:
boxes = out_bbox
else:
boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
if test:
assert not not_to_xyxy
boxes[:,:,2:] = boxes[:,:,2:] - boxes[:,:,:2]
boxes_norm = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4))
target_sizes = target_sizes.type_as(boxes)
# from relative [0, 1] to absolute [0, height] coordinates
img_h, img_w = target_sizes.unbind(1)
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
boxes = boxes_norm * scale_fct[:, None, :]
# keypoints
topk_keypoints = topk_indexes // out_logits.shape[2]
labels = topk_indexes % out_logits.shape[2]
keypoints = torch.gather(out_keypoints, 1, topk_keypoints.unsqueeze(-1).repeat(1, 1, self.num_body_points*3))
Z_pred = keypoints[:, :, :(self.num_body_points*2)]
V_pred = keypoints[:, :, (self.num_body_points*2):]
img_h, img_w = target_sizes.unbind(1)
Z_pred = Z_pred * torch.stack([img_w, img_h], dim=1).repeat(1, self.num_body_points)[:, None, :]
keypoints_res = torch.zeros_like(keypoints)
keypoints_res[..., 0::3] = Z_pred[..., 0::2]
keypoints_res[..., 1::3] = Z_pred[..., 1::2]
keypoints_res[..., 2::3] = V_pred[..., 0::1]
# smpl out_smpl_pose, out_smpl_beta, out_smpl_cam, out_smpl_kp3d
topk_smpl = topk_indexes // out_logits.shape[2]
labels = topk_indexes % out_logits.shape[2]
smpl_pose = torch.gather(out_smpl_pose, 1, topk_smpl[:,:,None].repeat(1, 1, 159))
smpl_beta = torch.gather(out_smpl_beta, 1, topk_smpl[:,:,None].repeat(1, 1, 10))
smpl_expr = torch.gather(out_smpl_expr, 1, topk_smpl[:,:,None].repeat(1, 1, 10))
smpl_cam = torch.gather(out_smpl_cam, 1, topk_smpl[:,:,None].repeat(1, 1, 3))
smpl_kp3d = torch.gather(out_smpl_kp3d, 1, topk_smpl[:,:,None, None].repeat(1, 1, out_smpl_kp3d.shape[-2],3))
smpl_verts = torch.gather(out_smpl_verts, 1, topk_smpl[:,:,None, None].repeat(1, 1, out_smpl_verts.shape[-2],3))
if False:  # debug visualization (disabled): render predicted vs. GT boxes and keypoints
import cv2
import mmcv
img = (data_batch_nc['img'][1].permute(1,2,0)*255).int().detach().cpu().numpy()
# img = cv2.imread(data_batch_nc['img_metas'][1]['image_path'])
tgt_bbox_center = torch.stack(data_batch_nc['body_bbox_center'])
tgt_bbox_size = torch.stack(data_batch_nc['body_bbox_size']).cpu().numpy()
tgt_bbox = torch.cat([tgt_bbox_center-tgt_bbox_size/2,tgt_bbox_center+tgt_bbox_size/2],dim=-1)
tgt_img_shape = data_batch_nc['img_shape']
bbox = tgt_bbox.cpu().numpy()*(tgt_img_shape.repeat(1,2).cpu().numpy()[:,::-1])
render_img = mmcv.imshow_bboxes(img.copy(), boxes[1][:3].cpu().numpy(), show=False)
cv2.imwrite('r_bbox.png',render_img)
render_img = mmcv.imshow_bboxes(img.copy(), bbox, show=False)
# cv2.imwrite('r_bbox.png',render_img)
# from detrsmpl.core.visualization.visualize_keypoints3d import visualize_kp3d
# visualize_kp3d(smpl_kp3d[1][[0]].cpu().numpy(),output_path='.',data_source='smpl_54')
from detrsmpl.core.visualization.visualize_keypoints2d import visualize_kp2d
visualize_kp2d(keypoints_res[0].reshape(-1,17,3)[[3]].cpu().numpy(), output_path='.', image_array=img.copy()[None], data_source='coco',overwrite=True)
# TODO: align it with agora
tgt_smpl_kp3d = data_batch_nc['joint_cam']
tgt_smpl_kp3d_conf = data_batch_nc['joint_valid']
tgt_smpl_pose = data_batch_nc['smplx_pose']
tgt_smpl_beta = data_batch_nc['smplx_shape']
tgt_smpl_expr = data_batch_nc['smplx_expr']
tgt_keypoints = data_batch_nc['joint_img']
tgt_img_shape = data_batch_nc['img_shape']
tgt_ann_idx = data_batch_nc['ann_idx']
# tgt_img_path = data_batch_nc['img_shape']
tgt_bbox_center = torch.stack(data_batch_nc['body_bbox_center'])
tgt_bbox_size = torch.stack(data_batch_nc['body_bbox_size'])
tgt_bbox = torch.cat([tgt_bbox_center-tgt_bbox_size/2,tgt_bbox_size],dim=-1)
tgt_bbox = tgt_bbox * scale_fct
tgt_verts = data_batch_nc['smplx_mesh_cam']
tgt_bb2img_trans = data_batch_nc['bb2img_trans']
indices = []
# pred
pred_smpl_kp3d = []
pred_smpl_pose = []
pred_smpl_beta = []
pred_smpl_verts = []
pred_smpl_expr = []
pred_scores = []
pred_labels = []
pred_boxes = []
pred_keypoints = []
pred_smpl_cam = []
# gt
gt_smpl_kp3d = []
gt_smpl_pose = []
gt_smpl_beta = []
gt_smpl_expr = []
gt_smpl_verts = []
gt_boxes = []
gt_keypoints = []
gt_bb2img_trans = []
image_idx = []
results = []
for i, kp3d in enumerate(tgt_smpl_kp3d):
# kp3d
conf = tgt_smpl_kp3d_conf[i]
gt_kp3d = tgt_smpl_kp3d[i][...,:3]
pred_kp3d = smpl_kp3d[i]
pred_kp3d_match, _ = convert_kps(pred_kp3d, 'smplx', 'smplx_137')
# pred_kp3d_match = pred_kp3d
cost_kp3d = torch.abs((pred_kp3d_match[:,None] -
gt_kp3d[None])* conf[None]).sum([-2,-1])
# bbox
tgt_bbox[i][...,2] = tgt_bbox[i][...,0] + tgt_bbox[i][...,2]
tgt_bbox[i][...,3] = tgt_bbox[i][...,1] + tgt_bbox[i][...,3]
gt_bbox = tgt_bbox[i][..., :4][None].float()
pred_bbox = boxes[i]
# box_iou = box_ops.box_iou(pred_bbox,gt_bbox)[0]
cost_giou = -box_ops.generalized_box_iou(pred_bbox,gt_bbox)
# cost_bbox = torch.cdist(pred_bbox, gt_bbox, p=1)
indice = linear_sum_assignment(cost_kp3d.cpu())
pred_ind, gt_ind = indice
indices = indice
# pred
pred_scores = scores[i][pred_ind].detach().cpu().numpy()
pred_labels = labels[i][pred_ind].detach().cpu().numpy()
pred_boxes = boxes[i][pred_ind].detach().cpu().numpy()
pred_keypoints = keypoints_res[i][pred_ind].detach().cpu().numpy()
pred_smpl_kp3d = smpl_kp3d[i][pred_ind].detach().cpu().numpy()
pred_smpl_pose = smpl_pose[i][pred_ind].detach().cpu().numpy()
pred_smpl_beta = smpl_beta[i][pred_ind].detach().cpu().numpy()
pred_smpl_cam = smpl_cam[i][pred_ind].detach().cpu().numpy()
pred_smpl_expr = smpl_expr[i][pred_ind].detach().cpu().numpy()
pred_smpl_verts = smpl_verts[i][pred_ind].detach().cpu().numpy()
# gt
# gt_smpl_kp3d=(tgt_smpl_kp3d[i][gt_ind].detach().cpu().numpy())
# gt_smpl_pose=(tgt_smpl_pose[i][gt_ind].detach().cpu().numpy())
# gt_smpl_beta=(tgt_smpl_beta[i][gt_ind].detach().cpu().numpy())
# gt_boxes=(tgt_bbox[i][gt_ind].detach().cpu().numpy())
# gt_smpl_expr=(tgt_smpl_expr[i][gt_ind].detach().cpu().numpy())
# gt_smpl_verts=(tgt_verts[i][gt_ind].detach().cpu().numpy())
# gt_keypoints=(tgt_keypoints[i][gt_ind].detach().cpu().numpy())
# gt_bb2img_trans=(tgt_bb2img_trans[i][gt_ind].detach().cpu().numpy())
gt_smpl_kp3d = tgt_smpl_kp3d[i].detach().cpu().numpy()
gt_smpl_pose = tgt_smpl_pose[i].detach().cpu().numpy()
gt_smpl_beta = tgt_smpl_beta[i].detach().cpu().numpy()
gt_boxes = tgt_bbox[i].detach().cpu().numpy()
gt_smpl_expr = tgt_smpl_expr[i].detach().cpu().numpy()
gt_smpl_verts = tgt_verts[i].detach().cpu().numpy()
gt_ann_idx = tgt_ann_idx[i].detach().cpu().numpy()
gt_keypoints = tgt_keypoints[i].detach().cpu().numpy()
gt_img_shape = tgt_img_shape[i].detach().cpu().numpy()
gt_bb2img_trans = tgt_bb2img_trans[i].detach().cpu().numpy()
if 'image_id' in targets[i]:
image_idx = targets[i]['image_id'].detach().cpu().numpy()
# pred_smpl_pose = np.concatenate(pred_smpl_pose,axis = 0)
# gt_bb2img_trans = np.concatenate(gt_bb2img_trans,axis = 0)
# gt_smpl_verts = np.concatenate(gt_smpl_verts,axis = 0)
# pred_smpl_verts = np.concatenate(pred_smpl_verts, axis = 0)
# pred_smpl_cam = np.concatenate(pred_smpl_cam, axis = 0)
smplx_root_pose = pred_smpl_pose[:,:3]
smplx_body_pose = pred_smpl_pose[:,3:66]
smplx_lhand_pose = pred_smpl_pose[:,66:111]
smplx_rhand_pose = pred_smpl_pose[:,111:156]
smplx_jaw_pose = pred_smpl_pose[:,156:]
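# SMPL-X full-pose layout (159 = 53 joints * 3 axis-angle values):
# [0:3] global orient, [3:66] body (21 joints), [66:111] left hand
# (15 joints), [111:156] right hand (15 joints), [156:159] jaw.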
# pred_smpl_kp3d = np.concatenate(pred_smpl_kp3d,axis = 0)
pred_smpl_cam = torch.Tensor(pred_smpl_cam)
pred_smpl_kp3d = torch.Tensor(pred_smpl_kp3d)
# pred_smpl_kp2d = weak_perspective_projection(pred_smpl_kp3d, scale=pred_smpl_cam[:, :1], translation=pred_smpl_cam[:, 1:3])
# pred_smpl_verts2d = weak_perspective_projection(pred_smpl_kp3d, scale=pred_smpl_cam[:, :1], translation=pred_smpl_cam[:, 1:3])
img_wh = tgt_img_shape[i].flip(-1)[None]
pred_smpl_kp2d = project_points_new(
points_3d=pred_smpl_kp3d,
pred_cam=pred_smpl_cam,
focal_length=5000,
camera_center=img_wh/2
)
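# Projection convention used throughout this file: a fixed focal length
# of 5000 px with the principal point at the image center, and the
# predicted camera providing the translation. For a point already in
# camera coordinates this is the standard pinhole projection:
#   u = f * X / Z + cx,  v = f * Y / Z + cy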
pred_smpl_kp2d = pred_smpl_kp2d.numpy()
pred_smpl_cam = pred_smpl_cam.numpy()
# cam_trans = get_camera_trans(pred_smpl_cam)
# pred_smpl_kp2d = (pred_smpl_kp2d+1)/2
# pred_smpl_kp2d[:, :,0] = pred_smpl_kp2d[:, :, 0] * gt_img_shape[1]
# pred_smpl_kp2d[:, :, 1] = pred_smpl_kp2d[:, :, 1] * gt_img_shape[0]
# # joint_proj = np.dot(out['bb2img_trans'], joint_proj.transpose(1, 0)).transpose(1, 0)
# # joint_proj[:, 0] = joint_proj[:, 0] / self.resolution[1] * 3840 # restore to original resolution
# # joint_proj[:, 1] = joint_proj[:, 1] / self.resolution[0] * 2160 # restore to original resolution
vis = False
if vis:
from pytorch3d.io import save_obj
from detrsmpl.core.visualization.visualize_keypoints2d import visualize_kp2d
from detrsmpl.core.visualization.visualize_smpl import visualize_smpl_hmr,render_smpl
from detrsmpl.utils.demo_utils import get_default_hmr_intrinsic
# img = (data_batch_nc['img'][i]*255).permute(1,2,0).int().detach().cpu().numpy()
# (s, tx, ty) = (pred_smpl_cam[:, 0] + 1e-9), pred_smpl_cam[:, 1], pred_smpl_cam[:, 2]
# depth, dx, dy = 1./s, tx/s, ty/s
# cam_t = np.stack([dx, dy, depth], 1)
# K = torch.Tensor(
# get_default_hmr_intrinsic(focal_length=5000,
# det_height=750,
# det_width=1333))
# render_smpl(verts = pred_smpl_verts+cam_t[:,None,:],
# image_array=img.copy()[None],
# body_model=self.body_model,convention='opencv',
# output_path='.',overwrite=True,K=K)
# save_obj(
# 'pred.obj',
# torch.tensor(pred_smpl_verts[0]),
# torch.tensor(self.body_model.faces.astype(np.float)))
import mmcv
import cv2
import numpy as np
from detrsmpl.core.visualization.visualize_keypoints2d import visualize_kp2d
from detrsmpl.core.visualization.visualize_smpl import visualize_smpl_hmr,render_smpl
from detrsmpl.models.body_models.builder import build_body_model
from pytorch3d.io import save_obj
from detrsmpl.core.visualization.visualize_keypoints3d import visualize_kp3d
img = mmcv.imdenormalize(
img=(data_batch_nc['img'][i].cpu().numpy()).transpose(1, 2, 0),
mean=np.array([123.675, 116.28, 103.53]),
std=np.array([58.395, 57.12, 57.375]),
to_bgr=True).astype(np.uint8)
img = mmcv.imshow_bboxes(img,pred_boxes,show=False)
img= visualize_kp2d(pred_smpl_kp2d, output_path='.', image_array=img.copy()[None], data_source='smplx',overwrite=True)[0]
name = str(pred_smpl_kp2d[0,0,0]).replace('.','')
cv2.imwrite('res_vis/%s.png'%name, img)
# # joint_proj = np.dot(out['bb2img_trans'], joint_proj.transpose(1, 0)).transpose(1, 0)
# # joint_proj[:, 0] = joint_proj[:, 0] / self.resolution[1] * 3840 # restore to original resolution
# # joint_proj[:, 1] = joint_proj[:, 1] / self.resolution[0] * 2160 # restore to original resolution
#
# from detrsmpl.core.visualization.visualize_keypoints2d import visualize_kp2d
# from detrsmpl.core.visualization.visualize_smpl import visualize_smpl_hmr,render_smpl
# from detrsmpl.models.body_models.builder import build_body_model
# body_model = dict(
# type='smplx',
# keypoint_src='smplx',
# num_expression_coeffs=10,
# keypoint_dst='smplx_137',
# model_path='data/body_models/smplx',
# use_pca=False,
# use_face_contour=True)
# body_modeltest = build_body_model(body_model)
# # device =gt_betas.device
# # body_modeltest.to(device)
# gt_output = body_modeltest(betas=torch.Tensor(gt_smpl_beta[None].reshape(-1, 10)),body_pose=torch.Tensor(gt_smpl_pose[3:66][None].reshape(-1, 21*3)), global_orient=torch.Tensor(gt_smpl_pose[:3][None].reshape(-1, 3)),left_hand_pose=torch.Tensor(gt_smpl_pose[66:111][None].reshape(-1, 15*3)),right_hand_pose=torch.Tensor(gt_smpl_pose[111:156][None].reshape(-1, 15*3)),jaw_pose=torch.Tensor(gt_smpl_pose[156:][None].reshape(-1, 3)),)
# img = (data_batch_nc['img'][i]*255).permute(1,2,0).int().detach().cpu().numpy()
# render_smpl(verts = gt_output['vertices'],image_array=img.copy()[None],body_model=self.body_model,convention='opencv',orig_cam = np.concatenate([pred_smpl_cam[:,:1],pred_smpl_cam[:,:1],pred_smpl_cam[:,1:]],axis=-1),output_path='.',overwrite=True)
# img_new = visualize_smpl_hmr(
# cam_transl=pred_smpl_cam,
# verts = pred_smpl_verts,
# body_model=self.body_model,
# bbox = np.array([0,0,gt_img_shape[1],gt_img_shape[0]]),
# det_width = gt_img_shape[1],
# det_height=gt_img_shape[0],
# image_array=img.copy()[None],
# output_path='.',
# overwrite=True
# )
results.append({
'scores': pred_scores,
'labels': pred_labels,
'boxes': pred_boxes[0],
'keypoints': pred_keypoints[0],
'smplx_root_pose': smplx_root_pose[0],
'smplx_body_pose': smplx_body_pose[0],
'smplx_lhand_pose': smplx_lhand_pose[0],
'smplx_rhand_pose': smplx_rhand_pose[0],
'smplx_jaw_pose': smplx_jaw_pose[0],
'smplx_shape': pred_smpl_beta[0],
'smplx_expr': pred_smpl_expr[0],
'cam_trans': pred_smpl_cam[0],
'smplx_mesh_cam': pred_smpl_verts[0],
'smplx_mesh_cam_target': gt_smpl_verts,
'gt_ann_idx':gt_ann_idx,
'gt_smpl_kp3d':gt_smpl_kp3d,
'smplx_joint_proj': pred_smpl_kp2d[0],
'image_idx': image_idx,
'bb2img_trans': gt_bb2img_trans,
'img_shape': gt_img_shape
})
# results.append({
# 'scores': scores[i][pred_ind],
# 'labels': labels[i][pred_ind],
# 'boxes': boxes[i][pred_ind],
# 'keypoints': keypoints_res[i][pred_ind],
# 'pred_smpl_pose': smpl_pose[i][pred_ind],
# 'pred_smpl_beta': tgt_smpl_beta[i][gt_ind],
# 'pred_smpl_cam': smpl_cam[i][pred_ind],
# 'pred_smpl_kp3d': smpl_kp3d[i][pred_ind],
# 'gt_smpl_pose': tgt_smpl_pose[i][gt_ind],
# 'gt_smpl_beta': tgt_smpl_beta[i][gt_ind],
# 'gt_smpl_kp3d': tgt_smpl_kp3d[i][gt_ind],
# 'gt_boxes': tgt_bbox[i][gt_ind],
# 'gt_keypoints': tgt_keypoints[i][gt_ind],
# 'image_idx': targets[i]['image_id'],
# }
# )
if self.nms_iou_threshold > 0:
raise NotImplementedError
item_indices = [nms(b, s, iou_threshold=self.nms_iou_threshold) for b,s in zip(boxes, scores)]
results = [{'scores': s[i], 'labels': l[i], 'boxes': b[i]} for s, l, b, i in zip(scores, labels, boxes, item_indices)]
return results
class PostProcess_SMPLX_Multi(nn.Module):
""" This module converts the model's output into the format expected by the coco api"""
def __init__(
self,
num_select=100,
nms_iou_threshold=-1,
num_body_points=17,
body_model= dict(
type='smplx',
keypoint_src='smplx',
num_expression_coeffs=10,
num_betas=10,
gender='neutral',
keypoint_dst='smplx_137',
model_path='data/body_models/smplx',
use_pca=False,
use_face_contour=True,
),
) -> None:
super().__init__()
self.num_select = num_select
self.nms_iou_threshold = nms_iou_threshold
self.num_body_points=num_body_points
# -1 for neutral; 0 for male; 1 for female
gender_body_model = {}
gender_body_model[-1] = build_body_model(body_model)
body_model['gender']='male'
gender_body_model[0] = build_body_model(body_model)
body_model['gender']='female'
gender_body_model[1] = build_body_model(body_model)
self.body_model = gender_body_model
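# GT meshes are later rebuilt with the model matching each instance's
# annotated gender, so vertex comparisons are not skewed by a
# neutral-vs-gendered shape-space mismatch.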
@torch.no_grad()
def forward(self, outputs, target_sizes, targets, data_batch_nc, not_to_xyxy=False, test=False, dataset = None):
batch_size = outputs['pred_keypoints'].shape[0]
results = []
device = outputs['pred_keypoints'].device
for body_model in self.body_model.values():
body_model.to(device)
# select only the top-scoring query per image here; the commented
# alternatives select one query per GT instance or self.num_select
# num_select = data_batch_nc['joint_img'][0].shape[0]
# num_select = self.num_select
num_select = 1
out_logits, out_bbox, out_keypoints= \
outputs['pred_logits'], outputs['pred_boxes'], \
outputs['pred_keypoints']
out_smpl_pose, out_smpl_beta, out_smpl_expr, out_smpl_cam, out_smpl_kp3d, out_smpl_verts = \
outputs['pred_smpl_fullpose'], outputs['pred_smpl_beta'], outputs['pred_smpl_expr'], \
outputs['pred_smpl_cam'], outputs['pred_smpl_kp3d'], outputs['pred_smpl_verts']
out_smpl_kp2d = []
for bs in range(batch_size):
out_kp3d_i = out_smpl_kp3d[bs]
out_cam_i = out_smpl_cam[bs]
out_img_shape = data_batch_nc['img_shape'][bs].flip(-1)[None]
# out_kp3d_i = out_kp3d_i - out_kp3d_i[:, [0]]
out_kp2d_i = project_points_new(
points_3d=out_kp3d_i,
pred_cam=out_cam_i,
focal_length=5000,
camera_center=out_img_shape/2
)
out_smpl_kp2d.append(out_kp2d_i.detach().cpu().numpy())
out_smpl_kp2d = torch.tensor(out_smpl_kp2d).to(device)
assert len(out_logits) == len(target_sizes)
assert target_sizes.shape[1] == 2
prob = out_logits.sigmoid()
topk_values, topk_indexes = \
torch.topk(prob.view(out_logits.shape[0], -1), num_select, dim=1)
scores = topk_values
# bbox
topk_boxes = topk_indexes // out_logits.shape[2]
labels = topk_indexes % out_logits.shape[2]
if not_to_xyxy:
boxes = out_bbox
else:
boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
if test:
assert not not_to_xyxy
boxes[:,:,2:] = boxes[:,:,2:] - boxes[:,:,:2]
# gather gt bbox
boxes_norm = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4))
target_sizes = target_sizes.type_as(boxes)
# from relative [0, 1] to absolute [0, height] coordinates
img_h, img_w = target_sizes.unbind(1)
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
boxes = boxes_norm * scale_fct[:, None, :]
# smplx kp2d
topk_smpl_kp2d = topk_indexes // out_logits.shape[2]
labels = topk_indexes % out_logits.shape[2]
pred_smpl_kp2d = torch.gather(
out_smpl_kp2d, 1,
topk_smpl_kp2d.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 137, 2))
# keypoints
topk_keypoints = topk_indexes // out_logits.shape[2]
labels = topk_indexes % out_logits.shape[2]
keypoints = torch.gather(
out_keypoints, 1,
topk_keypoints.unsqueeze(-1).repeat(1, 1, self.num_body_points*3))
Z_pred = keypoints[:, :, :(self.num_body_points * 2)]
V_pred = keypoints[:, :, (self.num_body_points * 2):]
img_h, img_w = target_sizes.unbind(1)
Z_pred = Z_pred * torch.stack([img_w, img_h], dim=1).repeat(1, self.num_body_points)[:, None, :]
keypoints_res = torch.zeros_like(keypoints)
keypoints_res[..., 0::3] = Z_pred[..., 0::2]
keypoints_res[..., 1::3] = Z_pred[..., 1::2]
keypoints_res[..., 2::3] = V_pred[..., 0::1]
# smpl out_smpl_pose, out_smpl_beta, out_smpl_cam, out_smpl_kp3d
topk_smpl = topk_indexes // out_logits.shape[2]
labels = topk_indexes % out_logits.shape[2]
smpl_pose = torch.gather(out_smpl_pose, 1, topk_smpl[:,:,None].repeat(1, 1, 159))
smpl_beta = torch.gather(out_smpl_beta, 1, topk_smpl[:,:,None].repeat(1, 1, 10))
smpl_expr = torch.gather(out_smpl_expr, 1, topk_smpl[:,:,None].repeat(1, 1, 10))
smpl_cam = torch.gather(out_smpl_cam, 1, topk_smpl[:,:,None].repeat(1, 1, 3))
smpl_kp3d = torch.gather(out_smpl_kp3d, 1, topk_smpl[:,:,None, None].repeat(1, 1, out_smpl_kp3d.shape[-2],3))
smpl_verts = torch.gather(out_smpl_verts, 1, topk_smpl[:,:,None, None].repeat(1, 1, out_smpl_verts.shape[-2],3))
tgt_smpl_kp3d = data_batch_nc['joint_cam']
# tgt_smpl_kp3d_conf = data_batch_nc['joint_valid']
tgt_smpl_pose = data_batch_nc['smplx_pose']
tgt_smpl_beta = data_batch_nc['smplx_shape']
tgt_smpl_expr = data_batch_nc['smplx_expr']
tgt_keypoints = data_batch_nc['joint_img']
tgt_img_shape = data_batch_nc['img_shape']
# tgt_bbox_center = data_batch_nc['body_bbox_center']
# tgt_bbox_size = data_batch_nc['body_bbox_size']
tgt_bb2img_trans = data_batch_nc['bb2img_trans']
tgt_ann_idx = data_batch_nc['ann_idx']
pred_indice_list = []
gt_indice_list = []
tgt_verts = []
tgt_kp3d = []
tgt_bbox = []
for bbox_center, bbox_size, pose, \
beta, expr, gender, gt_kp2d, _, pred_kp2d, pred_kp3d, boxe, scale \
in zip(
data_batch_nc['body_bbox_center'],
data_batch_nc['body_bbox_size'],
# data_batch_nc['bb2img_trans'],
data_batch_nc['smplx_pose'],
data_batch_nc['smplx_shape'],
data_batch_nc['smplx_expr'],
data_batch_nc['gender'],
data_batch_nc['joint_img'],
data_batch_nc['joint_cam'],
# keypoints_res, smpl_kp3d, boxes, scale_fct,
pred_smpl_kp2d, smpl_kp3d, boxes, scale_fct,
):
# build smplx verts
gt_verts = []
gt_kp3d = []
gt_bbox = []
gender_ = gender.cpu().numpy()
for i, g in enumerate(gender_):
gt_out = self.body_model[g](
betas=beta[i].reshape(-1, 10),
global_orient=pose[i, :3].reshape(-1, 3).unsqueeze(1),
body_pose=pose[i, 3:66].reshape(-1, 21 * 3),
left_hand_pose=pose[i, 66:111].reshape(-1, 15 * 3),
right_hand_pose=pose[i, 111:156].reshape(-1, 15 * 3),
jaw_pose=pose[i, 156:159].reshape(-1, 3),
leye_pose=torch.zeros_like(pose[i, 156:159]),
reye_pose=torch.zeros_like(pose[i, 156:159]),
expression=expr[i].reshape(-1, 10),
)
gt_verts.append(gt_out['vertices'][0].detach().cpu().numpy())
gt_kp3d.append(gt_out['joints'][0].detach().cpu().numpy())
tgt_verts.append(gt_verts)
tgt_kp3d.append(gt_kp3d)
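# GT vertices/joints are rebuilt per instance with the annotated
# gender's model; eye poses are zeroed because the annotations carry no
# eye rotation.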
# bbox
gt_bbox = torch.cat(
[bbox_center - bbox_size / 2, bbox_size ], dim=-1)
gt_bbox = gt_bbox * scale
# xywh2xyxy
gt_bbox[..., 2] = gt_bbox[..., 0] + gt_bbox[..., 2]
gt_bbox[..., 3] = gt_bbox[..., 1] + gt_bbox[..., 3]
tgt_bbox.append(gt_bbox[..., :4].float())
pred_bbox = boxe.clone()
# box_iou = box_ops.box_iou(pred_bbox,gt_bbox)[0]
cost_giou = -box_ops.generalized_box_iou(pred_bbox, gt_bbox)
cost_bbox = torch.cdist(
box_ops.box_xyxy_to_cxcywh(pred_bbox)/scale,
box_ops.box_xyxy_to_cxcywh(gt_bbox)/scale, p=1)
# smpl kp2d
gt_kp2d_conf = gt_kp2d[:,:,2:3]
gt_kp2d_ = (gt_kp2d[:, :, :2] * scale[:2]) / torch.tensor([12, 16]).to(device)
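# joint_img appears to be stored in output-heatmap coordinates with 12
# bins in x and 16 in y (an assumption based on the OSX-style
# convention): multiplying by scale_fct and dividing by (12, 16) maps
# the GT joints to input-image pixels.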
gt_kp2d_body = gt_kp2d_[:, smpl_x.joint_part['body']]
gt_kp2d_body_conf = gt_kp2d_conf[:, smpl_x.joint_part['body']]
pred_kp2d_body = pred_kp2d[:, smpl_x.joint_part['body']]  # body subset of the SMPL-X keypoints
# gt_kp2d_body_conf, _ = convert_kps(gt_kp2d_conf,'smplx_137', 'coco', approximate=True)
# gt_kp2d_body, _ = convert_kps(gt_kp2d_,'smplx_137', 'coco', approximate=True)
# pred_kp2d_body, _ = convert_kps(pred_kp2d,'smplx_137', 'coco', approximate=True)
# cost_keypoints = torch.abs(
# (pred_kp2d_body[:, None]/scale[:2] - gt_kp2d_body[None]/scale[:2])
# ).sum([-2,-1])
if dataset.__class__.__name__ == 'UBody_MM':
cost_keypoints = torch.abs(
(pred_kp2d_body[:, None]/scale[:2] - gt_kp2d_body[None]/scale[:2])*gt_kp2d_body_conf[None]
).sum([-2,-1])/gt_kp2d_body_conf[None].sum()
else:
cost_keypoints = torch.abs(
(pred_kp2d_body[:, None]/scale[:2] - gt_kp2d_body[None]/scale[:2])
).sum([-2,-1])
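# For UBody the 2D cost is confidence-weighted and normalized by the
# total confidence, so instances with few visible joints do not get a
# systematically smaller (easier) cost; other datasets use a plain L1
# cost over the body joints.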
# smpl kp3d
gt_kp3d_ = torch.tensor(np.array(gt_kp3d) - np.array(gt_kp3d)[:, [0]]).to(device)
pred_kp3d_ = (pred_kp3d - pred_kp3d[:, [0]])
cost_kp3d = torch.abs((pred_kp3d_[:, None] - gt_kp3d_[None])).sum([-2,-1])
# 1. kps
indice = linear_sum_assignment(cost_keypoints.cpu())
# 2. bbox giou
# indice = linear_sum_assignment(cost_giou.cpu())
# 3. bbox
# indice = linear_sum_assignment(cost_bbox.cpu())
# 4. all
# indice = linear_sum_assignment(
# 10* (cost_keypoints).cpu() + 5 * cost_bbox.cpu())
# 5. kp3d
# indice = linear_sum_assignment(cost_kp3d.cpu())
pred_ind, gt_ind = indice
pred_indice_list.append(pred_ind)
gt_indice_list.append(gt_ind)
pred_scores = torch.cat(
[t[i] for t, i in zip(scores, pred_indice_list)]
).detach().cpu().numpy()
pred_labels = torch.cat(
[t[i] for t, i in zip(labels, pred_indice_list)]
).detach().cpu().numpy()
pred_boxes = torch.cat(
[t[i] for t, i in zip(boxes, pred_indice_list)]
).detach().cpu().numpy()
pred_keypoints = torch.cat(
[t[i] for t, i in zip(keypoints_res, pred_indice_list)]
).detach().cpu().numpy()
pred_smpl_kp2d = []
pred_smpl_kp3d = []
pred_smpl_cam = []
img_wh_list = []
for i, img_wh in enumerate(tgt_img_shape):
kp3d = smpl_kp3d[i][pred_indice_list[i]]
cam = smpl_cam[i][pred_indice_list[i]]
img_wh = img_wh.flip(-1)[None]
kp2d = project_points_new(
points_3d=kp3d,
pred_cam=cam,
focal_length=5000,
camera_center=img_wh/2
)
num_instance = kp2d.shape[0]
img_wh_list.append(img_wh.repeat(num_instance,1).cpu().numpy())
pred_smpl_kp2d.append(kp2d.detach().cpu().numpy())
pred_smpl_kp3d.append(kp3d.detach().cpu().numpy())
pred_smpl_cam.append(cam.detach().cpu().numpy())
# pred_smpl_cam = torch.cat(
# [t[i] for t, i in zip(smpl_cam, pred_indice_list)]
# ).detach().cpu().numpy()
# pred_smpl_kp3d = torch.cat(
# [t[i] for t, i in zip(smpl_kp3d, pred_indice_list)]
# )
pred_smpl_pose = torch.cat(
[t[i] for t, i in zip(smpl_pose, pred_indice_list)]
).detach().cpu().numpy()
pred_smpl_beta = torch.cat(
[t[i] for t, i in zip(smpl_beta, pred_indice_list)]
).detach().cpu().numpy()
pred_smpl_expr = torch.cat(
[t[i] for t, i in zip(smpl_expr, pred_indice_list)]
).detach().cpu().numpy()
pred_smpl_verts = torch.cat(
[t[i] for t, i in zip(smpl_verts, pred_indice_list)]
).detach().cpu().numpy()
# from pytorch3d.io import save_obj
# for m_i,(mesh_out_i) in enumerate(smpl_verts[0].detach().cpu()):
# save_obj('temp_smpl_%d.obj'%m_i,verts=(mesh_out_i),faces=torch.tensor([]))
# for m_i,(mesh_out_i) in enumerate(pred_smpl_verts):
# save_obj('temp_pred_%d.obj'%m_i,verts=torch.Tensor(mesh_out_i),faces=torch.tensor([]))
pred_smpl_kp2d = np.concatenate(pred_smpl_kp2d, 0)
pred_smpl_kp3d = np.concatenate(pred_smpl_kp3d, 0)
pred_smpl_cam = np.concatenate(pred_smpl_cam, 0)
img_wh_list = np.concatenate(img_wh_list, 0)
gt_smpl_kp3d = torch.cat(tgt_smpl_kp3d).detach().cpu().numpy()
gt_smpl_pose = torch.cat(tgt_smpl_pose).detach().cpu().numpy()
gt_smpl_beta = torch.cat(tgt_smpl_beta).detach().cpu().numpy()
gt_boxes = torch.cat(tgt_bbox).detach().cpu().numpy()
gt_smpl_expr = torch.cat(tgt_smpl_expr).detach().cpu().numpy()
# gt_img_shape = torch.cat(tgt_img_shape).detach().cpu().numpy()
gt_smpl_verts = np.concatenate(
[np.array(t)[i] for t, i in zip(tgt_verts, gt_indice_list)], 0)
gt_ann_idx = torch.cat([t.repeat(len(i)) for t, i in zip(tgt_ann_idx, gt_indice_list)],dim=0).cpu().numpy()
gt_keypoints = torch.cat(tgt_keypoints).detach().cpu().numpy()
# gt_img_shape = tgt_img_shape.detach().cpu().numpy()
gt_bb2img_trans = torch.stack(tgt_bb2img_trans).detach().cpu().numpy()
if 'image_id' in targets[i]:
image_idx = targets[i]['image_id'].detach().cpu().numpy()
smplx_root_pose = pred_smpl_pose[:,:3]
smplx_body_pose = pred_smpl_pose[:,3:66]
smplx_lhand_pose = pred_smpl_pose[:,66:111]
smplx_rhand_pose = pred_smpl_pose[:,111:156]
smplx_jaw_pose = pred_smpl_pose[:,156:]
results.append({
'scores': pred_scores,
'labels': pred_labels,
'boxes': pred_boxes,
'keypoints': pred_keypoints,
'smplx_root_pose': smplx_root_pose,
'smplx_body_pose': smplx_body_pose,
'smplx_lhand_pose': smplx_lhand_pose,
'smplx_rhand_pose': smplx_rhand_pose,
'smplx_jaw_pose': smplx_jaw_pose,
'smplx_shape': pred_smpl_beta,
'smplx_expr': pred_smpl_expr,
'cam_trans': pred_smpl_cam,
'smplx_mesh_cam': pred_smpl_verts,
'smplx_mesh_cam_target': gt_smpl_verts,
'gt_smpl_kp3d':gt_smpl_kp3d,
'smplx_joint_proj': pred_smpl_kp2d,
# 'image_idx': image_idx,
"img": data_batch_nc['img'].cpu().numpy(),
'bb2img_trans': gt_bb2img_trans,
'img_shape': img_wh_list,
'gt_ann_idx': gt_ann_idx
})
if self.nms_iou_threshold > 0:
raise NotImplementedError
item_indices = [nms(b, s, iou_threshold=self.nms_iou_threshold) for b,s in zip(boxes, scores)]
results = [{'scores': s[i], 'labels': l[i], 'boxes': b[i]} for s, l, b, i in zip(scores, labels, boxes, item_indices)]
return results
class PostProcess_SMPLX_Multi_Infer(nn.Module):
""" This module converts the model's output into the format expected by the coco api"""
def __init__(
self,
num_select=100,
nms_iou_threshold=-1,
num_body_points=17,
body_model= dict(
type='smplx',
keypoint_src='smplx',
num_expression_coeffs=10,
num_betas=10,
gender='neutral',
keypoint_dst='smplx_137',
model_path='data/body_models/smplx',
use_pca=False,
use_face_contour=True)
) -> None:
super().__init__()
self.num_select = num_select
self.nms_iou_threshold = nms_iou_threshold
self.num_body_points=num_body_points
# -1 for neutral; 0 for male; 1 for female
gender_body_model = {}
gender_body_model[-1] = build_body_model(body_model)
body_model['gender']='male'
gender_body_model[0] = build_body_model(body_model)
body_model['gender']='female'
gender_body_model[1] = build_body_model(body_model)
self.body_model = gender_body_model
@torch.no_grad()
def forward(self, outputs, target_sizes, targets, data_batch_nc, image_shape= None, not_to_xyxy=False, test=False):
"""
image_shape(target_sizes): input image shape
"""
batch_size = outputs['pred_keypoints'].shape[0]
results = []
device = outputs['pred_keypoints'].device
# for body_model in self.body_model.values():
# body_model.to(device)
pred_kp_coco = outputs['pred_keypoints']
num_select = self.num_select
out_logits, out_bbox= outputs['pred_logits'], outputs['pred_boxes']
out_body_bbox, out_lhand_bbox, out_rhand_bbox, out_face_bbox = \
outputs['pred_boxes'], outputs['pred_lhand_boxes'], \
outputs['pred_rhand_boxes'], outputs['pred_face_boxes']
out_smpl_pose, out_smpl_beta, out_smpl_expr, out_smpl_cam, out_smpl_kp3d, out_smpl_verts = \
outputs['pred_smpl_fullpose'], outputs['pred_smpl_beta'], outputs['pred_smpl_expr'], \
outputs['pred_smpl_cam'], outputs['pred_smpl_kp3d'], outputs['pred_smpl_verts']
out_smpl_kp2d = []
for bs in range(batch_size):
out_kp3d_i = out_smpl_kp3d[bs]
out_cam_i = out_smpl_cam[bs]
out_img_shape = data_batch_nc['img_shape'][bs].flip(-1)[None]
out_kp2d_i = project_points_new(
points_3d=out_kp3d_i,
pred_cam=out_cam_i,
focal_length=5000,
camera_center=out_img_shape/2
)
out_smpl_kp2d.append(out_kp2d_i.detach().cpu().numpy())
out_smpl_kp2d = torch.tensor(out_smpl_kp2d).to(device)
# assert len(out_logits) == len(target_sizes)
# assert target_sizes.shape[1] == 2
prob = out_logits.sigmoid()
topk_values, topk_indexes = \
torch.topk(prob.view(out_logits.shape[0], -1), num_select, dim=1)
scores = topk_values
# bbox
topk_boxes = topk_indexes // out_logits.shape[2]
labels = topk_indexes % out_logits.shape[2]
if not_to_xyxy:
boxes = out_bbox
else:
boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
out_body_bbox = box_ops.box_cxcywh_to_xyxy(out_body_bbox)
out_lhand_bbox = box_ops.box_cxcywh_to_xyxy(out_lhand_bbox)
out_rhand_bbox = box_ops.box_cxcywh_to_xyxy(out_rhand_bbox)
out_face_bbox = box_ops.box_cxcywh_to_xyxy(out_face_bbox)
# gather body bbox
target_sizes = target_sizes.type_as(boxes)
img_h, img_w = target_sizes.unbind(1)
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
boxes_norm = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4))
boxes = boxes_norm * scale_fct[:, None, :]
body_bbox_norm = torch.gather(out_body_bbox, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4))
body_boxes = body_bbox_norm * scale_fct[:, None, :]
lhand_bbox_norm = torch.gather(out_lhand_bbox, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4))
lhand_boxes = lhand_bbox_norm * scale_fct[:, None, :]
rhand_bbox_norm = torch.gather(out_rhand_bbox, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4))
rhand_boxes = rhand_bbox_norm * scale_fct[:, None, :]
face_bbox_norm = torch.gather(out_face_bbox, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4))
face_boxes = face_bbox_norm * scale_fct[:, None, :]
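# All part boxes are gathered with the same top-k query indices as the
# body boxes, so the body, hand and face boxes of each selected
# instance stay aligned.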
# from relative [0, 1] to absolute [0, height] coordinates
# smplx kp2d
topk_smpl_kp2d = topk_indexes // out_logits.shape[2]
labels = topk_indexes % out_logits.shape[2]
pred_smpl_kp2d = torch.gather(
out_smpl_kp2d, 1,
topk_smpl_kp2d.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 144, 2))
# pred_smpl_kp2d = np.concatenate(pred_smpl_kp2d, 0)
pred_kp_coco = pred_kp_coco[..., 0:17*2].reshape(pred_kp_coco.shape[0], pred_kp_coco.shape[1], 17, 2)
# pred_kp_coco_norm = torch.gather(
# pred_kp_coco, 1,
# topk_smpl_kp2d.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 17, 2))
# pred_kp_coco = pred_kp_coco_norm * scale_fct[:, None, :2]
# smpl param
topk_smpl = topk_indexes // out_logits.shape[2]
labels = topk_indexes % out_logits.shape[2]
smpl_pose = torch.gather(out_smpl_pose, 1, topk_smpl[:,:,None].repeat(1, 1, 159))
smpl_beta = torch.gather(out_smpl_beta, 1, topk_smpl[:,:,None].repeat(1, 1, 10))
smpl_expr = torch.gather(out_smpl_expr, 1, topk_smpl[:,:,None].repeat(1, 1, 10))
smpl_cam = torch.gather(out_smpl_cam, 1, topk_smpl[:,:,None].repeat(1, 1, 3))
smpl_kp3d = torch.gather(out_smpl_kp3d, 1, topk_smpl[:,:,None, None].repeat(1, 1, out_smpl_kp3d.shape[-2],3))
smpl_verts = torch.gather(out_smpl_verts, 1, topk_smpl[:,:,None, None].repeat(1, 1, out_smpl_verts.shape[-2],3))
# smpl_verts = smpl_verts - smpl_kp3d[:,:, [0]]
s, tx, ty = smpl_cam[..., 0] + 1e-9, smpl_cam[..., 1], smpl_cam[..., 2]
depth, dx, dy = 1. / s, tx / s, ty / s
transl = torch.stack([dx, dy, depth], -1)
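# Weak-perspective camera (s, tx, ty) to camera-space translation:
# depth = 1 / s and the image-plane offsets are divided by s (the 1e-9
# above guards the division). A numeric sketch:
#   s, tx, ty = 0.5, 0.1, -0.2  ->  transl = (0.2, -0.4, 2.0)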
smplx_root_pose = smpl_pose[:, :, :3]
smplx_body_pose = smpl_pose[:, :, 3:66]
smplx_lhand_pose = smpl_pose[:, :, 66:111]
smplx_rhand_pose = smpl_pose[:, :, 111:156]
smplx_jaw_pose = smpl_pose[:, :, 156:]
if 'ann_idx' in data_batch_nc:
image_idx=[target.cpu().numpy()[0] for target in data_batch_nc['ann_idx']]
for bs in range(batch_size):
results.append({
'scores': scores[bs],
'labels': labels[bs],
'keypoints_coco': pred_kp_coco[bs],
'smpl_kp3d': smpl_kp3d[bs],
'smplx_root_pose': smplx_root_pose[bs],
'smplx_body_pose': smplx_body_pose[bs],
'smplx_lhand_pose': smplx_lhand_pose[bs],
'smplx_rhand_pose': smplx_rhand_pose[bs],
'smplx_jaw_pose': smplx_jaw_pose[bs],
'smplx_shape': smpl_beta[bs],
'smplx_expr': smpl_expr[bs],
'smplx_joint_proj': pred_smpl_kp2d[bs],
'smpl_verts': smpl_verts[bs],
'image_idx': image_idx[bs],
'cam_trans': transl[bs],
'body_bbox': body_boxes[bs],
'lhand_bbox': lhand_boxes[bs],
'rhand_bbox': rhand_boxes[bs],
'face_bbox': face_boxes[bs],
'bb2img_trans': data_batch_nc['bb2img_trans'][bs],
'img2bb_trans': data_batch_nc['img2bb_trans'][bs],
'img': data_batch_nc['img'][bs],
'img_shape': data_batch_nc['img_shape'][bs]
})
if self.nms_iou_threshold > 0:
raise NotImplementedError
item_indices = [nms(b, s, iou_threshold=self.nms_iou_threshold) for b,s in zip(boxes, scores)]
results = [{'scores': s[i], 'labels': l[i], 'boxes': b[i]} for s, l, b, i in zip(scores, labels, boxes, item_indices)]
return results
class PostProcess_SMPLX_Multi_Box(nn.Module):
""" This module converts the model's output into the format expected by the coco api"""
def __init__(
self,
num_select=100,
nms_iou_threshold=-1,
num_body_points=17,
body_model= dict(
type='smplx',
keypoint_src='smplx',
num_expression_coeffs=10,
num_betas=10,
gender='neutral',
keypoint_dst='smplx_137',
model_path='data/body_models/smplx',
use_pca=False,
use_face_contour=True)
) -> None:
super().__init__()
self.num_select = num_select
self.nms_iou_threshold = nms_iou_threshold
self.num_body_points=num_body_points
# -1 for neutral; 0 for male; 1 for female
gender_body_model = {}
gender_body_model[-1] = build_body_model(body_model)
body_model['gender']='male'
gender_body_model[0] = build_body_model(body_model)
body_model['gender']='female'
gender_body_model[1] = build_body_model(body_model)
self.body_model = gender_body_model
@torch.no_grad()
def forward(self, outputs, target_sizes, targets, data_batch_nc, not_to_xyxy=False, test=False):
batch_size = outputs['pred_smpl_beta'].shape[0]
results = []
device = outputs['pred_smpl_beta'].device
for body_model in self.body_model.values():
body_model.to(device)
# test with instance num
# num_select=data_batch_nc['joint_img'][0].shape[0]
num_select = self.num_select
out_logits, out_bbox= outputs['pred_logits'], outputs['pred_boxes']
out_smpl_pose, out_smpl_beta, out_smpl_expr, out_smpl_cam, out_smpl_kp3d, out_smpl_verts = \
outputs['pred_smpl_fullpose'], outputs['pred_smpl_beta'], outputs['pred_smpl_expr'], \
outputs['pred_smpl_cam'], outputs['pred_smpl_kp3d'], outputs['pred_smpl_verts']
out_smpl_kp2d = []
for bs in range(batch_size):
out_kp3d_i = out_smpl_kp3d[bs]
out_cam_i = out_smpl_cam[bs]
out_img_shape = data_batch_nc['img_shape'][bs].flip(-1)[None]
# out_kp3d_i = out_kp3d_i - out_kp3d_i[:, [0]]
out_kp2d_i = project_points_new(
points_3d=out_kp3d_i,
pred_cam=out_cam_i,
focal_length=5000,
camera_center=out_img_shape/2
)
out_smpl_kp2d.append(out_kp2d_i.detach().cpu().numpy())
out_smpl_kp2d = torch.tensor(out_smpl_kp2d).to(device)
assert len(out_logits) == len(target_sizes)
assert target_sizes.shape[1] == 2
prob = out_logits.sigmoid()
topk_values, topk_indexes = \
torch.topk(prob.view(out_logits.shape[0], -1), num_select, dim=1)
scores = topk_values
# bbox
topk_boxes = topk_indexes // out_logits.shape[2]
labels = topk_indexes % out_logits.shape[2]
if not_to_xyxy:
boxes = out_bbox
else:
boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
if test:
assert not not_to_xyxy
boxes[:,:,2:] = boxes[:,:,2:] - boxes[:,:,:2]
# gather gt bbox
boxes_norm = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4))
target_sizes = target_sizes.type_as(boxes)
# from relative [0, 1] to absolute [0, height] coordinates
img_h, img_w = target_sizes.unbind(1)
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
boxes = boxes_norm * scale_fct[:, None, :]
# smplx kp2d
topk_smpl_kp2d = topk_indexes // out_logits.shape[2]
labels = topk_indexes % out_logits.shape[2]
pred_smpl_kp2d = torch.gather(
out_smpl_kp2d, 1,
topk_smpl_kp2d.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 137, 2))
# smpl out_smpl_pose, out_smpl_beta, out_smpl_cam, out_smpl_kp3d
topk_smpl = topk_indexes // out_logits.shape[2]
labels = topk_indexes % out_logits.shape[2]
smpl_pose = torch.gather(out_smpl_pose, 1, topk_smpl[:,:,None].repeat(1, 1, 159))
smpl_beta = torch.gather(out_smpl_beta, 1, topk_smpl[:,:,None].repeat(1, 1, 10))
smpl_expr = torch.gather(out_smpl_expr, 1, topk_smpl[:,:,None].repeat(1, 1, 10))
smpl_cam = torch.gather(out_smpl_cam, 1, topk_smpl[:,:,None].repeat(1, 1, 3))
smpl_kp3d = torch.gather(out_smpl_kp3d, 1, topk_smpl[:,:,None, None].repeat(1, 1, out_smpl_kp3d.shape[-2],3))
smpl_verts = torch.gather(out_smpl_verts, 1, topk_smpl[:,:,None, None].repeat(1, 1, out_smpl_verts.shape[-2],3))
tgt_smpl_kp3d = data_batch_nc['joint_cam']
tgt_smpl_pose = data_batch_nc['smplx_pose']
tgt_smpl_beta = data_batch_nc['smplx_shape']
tgt_smpl_expr = data_batch_nc['smplx_expr']
tgt_keypoints = data_batch_nc['joint_img']
tgt_img_shape = data_batch_nc['img_shape']
tgt_bb2img_trans = data_batch_nc['bb2img_trans']
tgt_ann_idx = data_batch_nc['ann_idx']
pred_indice_list = []
gt_indice_list = []
tgt_verts = []
tgt_kp3d = []
tgt_bbox = []
for bbox_center, bbox_size, pose, \
beta, expr, gender, gt_kp2d, _, pred_kp2d, pred_kp3d, boxe, scale \
in zip(
data_batch_nc['body_bbox_center'],
data_batch_nc['body_bbox_size'],
data_batch_nc['smplx_pose'],
data_batch_nc['smplx_shape'],
data_batch_nc['smplx_expr'],
data_batch_nc['gender'],
data_batch_nc['joint_img'],
data_batch_nc['joint_cam'],
pred_smpl_kp2d, smpl_kp3d, boxes, scale_fct,
):
# build smplx verts
gt_verts = []
gt_kp3d = []
gt_bbox = []
gender_ = gender.cpu().numpy()
for i, g in enumerate(gender_):
gt_out = self.body_model[g](
betas=beta[i].reshape(-1, 10),
global_orient=pose[i, :3].reshape(-1, 3).unsqueeze(1),
body_pose=pose[i, 3:66].reshape(-1, 21 * 3),
left_hand_pose=pose[i, 66:111].reshape(-1, 15 * 3),
right_hand_pose=pose[i, 111:156].reshape(-1, 15 * 3),
jaw_pose=pose[i, 156:159].reshape(-1, 3),
leye_pose=torch.zeros_like(pose[i, 156:159]),
reye_pose=torch.zeros_like(pose[i, 156:159]),
expression=expr[i].reshape(-1, 10),
)
gt_verts.append(gt_out['vertices'][0].detach().cpu().numpy())
gt_kp3d.append(gt_out['joints'][0].detach().cpu().numpy())
tgt_verts.append(gt_verts)
tgt_kp3d.append(gt_kp3d)
# bbox
gt_bbox = torch.cat(
[bbox_center - bbox_size / 2, bbox_size ], dim=-1)
gt_bbox = gt_bbox * scale
# xywh2xyxy
gt_bbox[..., 2] = gt_bbox[..., 0] + gt_bbox[..., 2]
gt_bbox[..., 3] = gt_bbox[..., 1] + gt_bbox[..., 3]
tgt_bbox.append(gt_bbox[..., :4].float())
pred_bbox = boxe.clone()
# box_iou = box_ops.box_iou(pred_bbox,gt_bbox)[0]
cost_giou = -box_ops.generalized_box_iou(pred_bbox, gt_bbox)
cost_bbox = torch.cdist(
box_ops.box_xyxy_to_cxcywh(pred_bbox)/scale,
box_ops.box_xyxy_to_cxcywh(gt_bbox)/scale, p=1)
# smpl kp2d
gt_kp2d_conf = gt_kp2d[:,:,2:3]
gt_kp2d_ = (gt_kp2d[:, :, :2] * scale[:2]) / torch.tensor([12, 16]).to(device)
# gt_kp2d_conf, _ = convert_kps(gt_kp2d_conf,'smplx_137', 'coco', approximate=True)
# cost_keypoints = torch.abs(
# (pred_kp2d_body[:, None]/scale[:2] - gt_kp2d_body[None]/scale[:2])*gt_kp2d_conf[None]
# ).sum([-2,-1])/gt_kp2d_conf[None].sum()
gt_kp2d_body, _ = convert_kps(gt_kp2d_,'smplx_137', 'coco', approximate=True)
pred_kp2d_body, _ = convert_kps(pred_kp2d,'smplx_137', 'coco', approximate=True)
cost_keypoints = torch.abs(
(pred_kp2d_body[:, None]/scale[:2] - gt_kp2d_body[None]/scale[:2])
).sum([-2,-1])
# cost_keypoints = torch.abs(
# (pred_kp2d_body[:, None]/scale[:2] - gt_kp2d_body[None]/scale[:2])*gt_kp2d_body_conf[None]
# ).sum([-2,-1])/gt_kp2d_body_conf[None].sum()
# coco kp2d
# gt_kp2d_conf, _ = convert_kps(gt_kp2d_conf,'smplx_137', 'coco', approximate=True)
# keypoints_coco = Z_pred.reshape(num_select, 17,2)
# ubody
# cost_keypoints_coco = torch.abs(
# (keypoints_coco[:, None]/scale[:2] - gt_kp2d_body[None]/scale[:2])*gt_kp2d_conf[None]
# ).sum([-2,-1])/gt_kp2d_conf[None].sum()
# others
# cost_keypoints_coco = torch.abs(
# (keypoints_coco[:, None]/scale[:2] - gt_kp2d_body[None]/scale[:2])
# ).sum([-2,-1])
# smpl kp3d
gt_kp3d_ = torch.tensor(np.array(gt_kp3d) - np.array(gt_kp3d)[:, [0]]).to(device)
pred_kp3d_ = (pred_kp3d - pred_kp3d[:, [0]])
cost_kp3d = torch.abs((pred_kp3d_[:, None] - gt_kp3d_[None])).sum([-2,-1])
# 1. kps
indice = linear_sum_assignment(cost_keypoints.cpu())
# 2. bbox giou
# indice = linear_sum_assignment(cost_giou.cpu())
# 3. bbox
# indice = linear_sum_assignment(cost_bbox.cpu())
# 4. all
# indice = linear_sum_assignment(
# 10* (cost_keypoints).cpu() + 5 * cost_bbox.cpu())
# 5. kp3d
# indice = linear_sum_assignment(cost_kp3d.cpu())
# 5. kp2d coco
# indice = linear_sum_assignment(cost_keypoints_coco.cpu())
pred_ind, gt_ind = indice
pred_indice_list.append(pred_ind)
gt_indice_list.append(gt_ind)
pred_scores = torch.cat(
[t[i] for t, i in zip(scores, pred_indice_list)]
).detach().cpu().numpy()
pred_labels = torch.cat(
[t[i] for t, i in zip(labels, pred_indice_list)]
).detach().cpu().numpy()
pred_boxes = torch.cat(
[t[i] for t, i in zip(boxes, pred_indice_list)]
).detach().cpu().numpy()
# pred_keypoints = torch.cat(
# [t[i] for t, i in zip(keypoints_res, pred_indice_list)]
# ).detach().cpu().numpy()
pred_smpl_kp2d = []
pred_smpl_kp3d = []
pred_smpl_cam = []
img_wh_list = []
for i, img_wh in enumerate(tgt_img_shape):
kp3d = smpl_kp3d[i][pred_indice_list[i]]
cam = smpl_cam[i][pred_indice_list[i]]
img_wh = img_wh.flip(-1)[None]
kp2d = project_points_new(
points_3d=kp3d,
pred_cam=cam,
focal_length=5000,
camera_center=img_wh/2
)
num_instance = kp2d.shape[0]
img_wh_list.append(img_wh.repeat(num_instance,1).cpu().numpy())
pred_smpl_kp2d.append(kp2d.detach().cpu().numpy())
pred_smpl_kp3d.append(kp3d.detach().cpu().numpy())
pred_smpl_cam.append(cam.detach().cpu().numpy())
# pred_smpl_cam = torch.cat(
# [t[i] for t, i in zip(smpl_cam, pred_indice_list)]
# ).detach().cpu().numpy()
# pred_smpl_kp3d = torch.cat(
# [t[i] for t, i in zip(smpl_kp3d, pred_indice_list)]
# )
pred_smpl_pose = torch.cat(
[t[i] for t, i in zip(smpl_pose, pred_indice_list)]
).detach().cpu().numpy()
pred_smpl_beta = torch.cat(
[t[i] for t, i in zip(smpl_beta, pred_indice_list)]
).detach().cpu().numpy()
pred_smpl_expr = torch.cat(
[t[i] for t, i in zip(smpl_expr, pred_indice_list)]
).detach().cpu().numpy()
pred_smpl_verts = torch.cat(
[t[i] for t, i in zip(smpl_verts, pred_indice_list)]
).detach().cpu().numpy()
pred_smpl_kp2d = np.concatenate(pred_smpl_kp2d, 0)
pred_smpl_kp3d = np.concatenate(pred_smpl_kp3d, 0)
pred_smpl_cam = np.concatenate(pred_smpl_cam, 0)
img_wh_list = np.concatenate(img_wh_list, 0)
gt_smpl_kp3d = torch.cat(tgt_smpl_kp3d).detach().cpu().numpy()
gt_smpl_pose = torch.cat(tgt_smpl_pose).detach().cpu().numpy()
gt_smpl_beta = torch.cat(tgt_smpl_beta).detach().cpu().numpy()
gt_boxes = torch.cat(tgt_bbox).detach().cpu().numpy()
gt_smpl_expr = torch.cat(tgt_smpl_expr).detach().cpu().numpy()
# gt_img_shape = torch.cat(tgt_img_shape).detach().cpu().numpy()
gt_smpl_verts = np.concatenate(
[np.array(t)[i] for t, i in zip(tgt_verts, gt_indice_list)], 0)
gt_ann_idx = torch.cat([t.repeat(len(i)) for t, i in zip(tgt_ann_idx, gt_indice_list)],dim=0).cpu().numpy()
gt_keypoints = torch.cat(tgt_keypoints).detach().cpu().numpy()
# gt_img_shape = tgt_img_shape.detach().cpu().numpy()
gt_bb2img_trans = torch.stack(tgt_bb2img_trans).detach().cpu().numpy()
# NOTE: `i` here is left over from the loop above, so this reads the last
# sample's id; it is only used if 'image_idx' below is re-enabled.
if 'image_id' in targets[i]:
image_idx = (targets[i]['image_id'].detach().cpu().numpy())
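# split the 159-dim SMPL-X full pose into its segments:
# root(3) + body(63) + left hand(45) + right hand(45) + jaw(3)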
smplx_root_pose = pred_smpl_pose[:,:3]
smplx_body_pose = pred_smpl_pose[:,3:66]
smplx_lhand_pose = pred_smpl_pose[:,66:111]
smplx_rhand_pose = pred_smpl_pose[:,111:156]
smplx_jaw_pose = pred_smpl_pose[:,156:]
results.append({
'scores': pred_scores,
'labels': pred_labels,
'boxes': pred_boxes,
# 'keypoints': pred_keypoints,
'smplx_root_pose': smplx_root_pose,
'smplx_body_pose': smplx_body_pose,
'smplx_lhand_pose': smplx_lhand_pose,
'smplx_rhand_pose': smplx_rhand_pose,
'smplx_jaw_pose': smplx_jaw_pose,
'smplx_shape': pred_smpl_beta,
'smplx_expr': pred_smpl_expr,
'cam_trans': pred_smpl_cam,
'smplx_mesh_cam': pred_smpl_verts,
'smplx_mesh_cam_target': gt_smpl_verts,
'gt_smpl_kp3d': gt_smpl_kp3d,
'smplx_joint_proj': pred_smpl_kp2d,
# 'image_idx': image_idx,
'img': data_batch_nc['img'].cpu().numpy(),
'bb2img_trans': gt_bb2img_trans,
'img_shape': img_wh_list,
'gt_ann_idx': gt_ann_idx
})
if self.nms_iou_threshold > 0:
# NMS-based filtering is not implemented for this post-processor; the
# lines below are unreachable and kept only for reference.
raise NotImplementedError
item_indices = [nms(b, s, iou_threshold=self.nms_iou_threshold) for b, s in zip(boxes, scores)]
results = [{'scores': s[i], 'labels': l[i], 'boxes': b[i]} for s, l, b, i in zip(scores, labels, boxes, item_indices)]
return results
class PostProcess_SMPLX_Multi_Infer_Box(nn.Module):
"""This module converts the model's output into the format expected by the
coco api."""
def __init__(
self,
num_select=100,
nms_iou_threshold=-1,
num_body_points=17,
body_model= dict(
type='smplx',
keypoint_src='smplx',
num_expression_coeffs=10,
num_betas=10,
gender='neutral',
keypoint_dst='smplx_137',
model_path='data/body_models/smplx',
use_pca=False,
use_face_contour=True)
) -> None:
super().__init__()
self.num_select = num_select
self.nms_iou_threshold = nms_iou_threshold
self.num_body_points=num_body_points
# -1 for neutral; 0 for male; 1 for female
# copy the config so we do not mutate the (shared) default-argument dict
# when switching genders below
body_model = copy.deepcopy(body_model)
gender_body_model = {}
gender_body_model[-1] = build_body_model(body_model)
body_model['gender'] = 'male'
gender_body_model[0] = build_body_model(body_model)
body_model['gender'] = 'female'
gender_body_model[1] = build_body_model(body_model)
self.body_model = gender_body_model
@torch.no_grad()
def forward(self, outputs, target_sizes, targets, data_batch_nc, image_shape= None, not_to_xyxy=False, test=False):
"""
image_shape(target_sizes): input image shape
"""
batch_size = outputs['pred_smpl_beta'].shape[0]
results = []
device = outputs['pred_smpl_beta'].device
num_select = self.num_select
out_logits, out_bbox= outputs['pred_logits'], outputs['pred_boxes']
out_body_bbox, out_lhand_bbox, out_rhand_bbox, out_face_bbox = \
outputs['pred_boxes'], outputs['pred_lhand_boxes'], \
outputs['pred_rhand_boxes'], outputs['pred_face_boxes']
out_smpl_pose, out_smpl_beta, out_smpl_expr, out_smpl_cam, out_smpl_kp3d, out_smpl_verts = \
outputs['pred_smpl_fullpose'], outputs['pred_smpl_beta'], outputs['pred_smpl_expr'], \
outputs['pred_smpl_cam'], outputs['pred_smpl_kp3d'], outputs['pred_smpl_verts']
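# decoder outputs: classification logits, body/part boxes, and SMPL-X
# predictions (full pose, betas, expression, weak-perspective cam,
# 3D keypoints, mesh vertices)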
# project each image's predicted 3D keypoints to 2D with a fixed virtual
# focal length (5000 px) and the principal point at the image center
out_smpl_kp2d = []
for bs in range(batch_size):
out_kp3d_i = out_smpl_kp3d[bs]
out_cam_i = out_smpl_cam[bs]
# img_shape is (h, w); flip to (w, h) so the center below is (w/2, h/2)
out_img_shape = data_batch_nc['img_shape'][bs].flip(-1)[None]
out_kp2d_i = project_points_new(
points_3d=out_kp3d_i,
pred_cam=out_cam_i,
focal_length=5000,
camera_center=out_img_shape / 2
)
out_smpl_kp2d.append(out_kp2d_i)
# stack on-device instead of round-tripping through numpy
out_smpl_kp2d = torch.stack(out_smpl_kp2d, dim=0)
prob = out_logits.sigmoid()
topk_values, topk_indexes = \
torch.topk(prob.view(out_logits.shape[0], -1), num_select, dim=1)
scores = topk_values
# bbox
topk_boxes = topk_indexes // out_logits.shape[2]
labels = topk_indexes % out_logits.shape[2]
if not_to_xyxy:
boxes = out_bbox
else:
boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
out_body_bbox = box_ops.box_cxcywh_to_xyxy(out_body_bbox)
out_lhand_bbox = box_ops.box_cxcywh_to_xyxy(out_lhand_bbox)
out_rhand_bbox = box_ops.box_cxcywh_to_xyxy(out_rhand_bbox)
out_face_bbox = box_ops.box_cxcywh_to_xyxy(out_face_bbox)
# gather body bbox
target_sizes = target_sizes.type_as(boxes)
img_h, img_w = target_sizes.unbind(1)
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
boxes_norm = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4))
boxes = boxes_norm * scale_fct[:, None, :]
body_bbox_norm = torch.gather(out_body_bbox, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4))
body_boxes = body_bbox_norm * scale_fct[:, None, :]
lhand_bbox_norm = torch.gather(out_lhand_bbox, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4))
lhand_boxes = lhand_bbox_norm * scale_fct[:, None, :]
rhand_bbox_norm = torch.gather(out_rhand_bbox, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4))
rhand_boxes = rhand_bbox_norm * scale_fct[:, None, :]
face_bbox_norm = torch.gather(out_face_bbox, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4))
face_boxes = face_bbox_norm * scale_fct[:, None, :]
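# all part boxes are gathered with the same top-k indices as the body
# queries, so each detection keeps its body, hand and face boxes aligned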
# smplx kp2d (already in absolute pixel coordinates from the projection above)
topk_smpl_kp2d = topk_indexes // out_logits.shape[2]
labels = topk_indexes % out_logits.shape[2]
pred_smpl_kp2d = torch.gather(
out_smpl_kp2d, 1,
topk_smpl_kp2d.unsqueeze(-1).unsqueeze(-1).repeat(
1, 1, out_smpl_kp2d.shape[-2], 2))
# smpl param
topk_smpl = topk_indexes // out_logits.shape[2]
labels = topk_indexes % out_logits.shape[2]
smpl_pose = torch.gather(out_smpl_pose, 1, topk_smpl[:,:,None].repeat(1, 1, 159))
smpl_beta = torch.gather(out_smpl_beta, 1, topk_smpl[:,:,None].repeat(1, 1, 10))
smpl_expr = torch.gather(out_smpl_expr, 1, topk_smpl[:,:,None].repeat(1, 1, 10))
smpl_cam = torch.gather(out_smpl_cam, 1, topk_smpl[:,:,None].repeat(1, 1, 3))
smpl_kp3d = torch.gather(out_smpl_kp3d, 1, topk_smpl[:,:,None, None].repeat(1, 1, out_smpl_kp3d.shape[-2],3))
smpl_verts = torch.gather(out_smpl_verts, 1, topk_smpl[:,:,None, None].repeat(1, 1, out_smpl_verts.shape[-2],3))
# smpl_verts = smpl_verts - smpl_kp3d[:,:, [0]]
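# convert the weak-perspective camera (s, tx, ty) into a camera-space
# translation: depth = 1/s, x = tx/s, y = ty/s; 1e-9 guards against s == 0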
(s, tx, ty) = (smpl_cam[..., 0] + 1e-9), smpl_cam[..., 1], smpl_cam[..., 2]
depth, dx, dy = 1./s, tx/s, ty/s
transl = torch.stack([dx, dy, depth], -1)
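# 159-dim SMPL-X full pose -> root(3) / body(63) / hands(2x45) / jaw(3)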
smplx_root_pose = smpl_pose[:, :, :3]
smplx_body_pose = smpl_pose[:, :, 3:66]
smplx_lhand_pose = smpl_pose[:, :, 66:111]
smplx_rhand_pose = smpl_pose[:, :, 111:156]
smplx_jaw_pose = smpl_pose[:, :, 156:]
if 'ann_idx' in data_batch_nc:
image_idx = [target.cpu().numpy()[0] for target in data_batch_nc['ann_idx']]
else:
# avoid a NameError below when 'ann_idx' is absent
image_idx = [None] * batch_size
for bs in range(batch_size):
results.append({
'scores': scores[bs],
'labels': labels[bs],
'smpl_kp3d': smpl_kp3d[bs],
'smplx_root_pose': smplx_root_pose[bs],
'smplx_body_pose': smplx_body_pose[bs],
'smplx_lhand_pose': smplx_lhand_pose[bs],
'smplx_rhand_pose': smplx_rhand_pose[bs],
'smplx_jaw_pose': smplx_jaw_pose[bs],
'smplx_shape': smpl_beta[bs],
'smplx_expr': smpl_expr[bs],
'smplx_joint_proj': pred_smpl_kp2d[bs],
'smpl_verts': smpl_verts[bs],
'image_idx': image_idx[bs],
'cam_trans': transl[bs],
'body_bbox': body_boxes[bs],
'lhand_bbox': lhand_boxes[bs],
'rhand_bbox': rhand_boxes[bs],
'face_bbox': face_boxes[bs],
'bb2img_trans': data_batch_nc['bb2img_trans'][bs],
'img2bb_trans': data_batch_nc['img2bb_trans'][bs],
'img': data_batch_nc['img'][bs],
'img_shape': data_batch_nc['img_shape'][bs]
})
if self.nms_iou_threshold > 0:
# NMS-based filtering is not implemented for this post-processor; the
# lines below are unreachable and kept only for reference.
raise NotImplementedError
item_indices = [nms(b, s, iou_threshold=self.nms_iou_threshold) for b, s in zip(boxes, scores)]
results = [{'scores': s[i], 'labels': l[i], 'boxes': b[i]} for s, l, b, i in zip(scores, labels, boxes, item_indices)]
return results
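# A minimal usage sketch (illustrative only): `outputs`, `target_sizes`,
# `targets` and `data_batch_nc` are assumed to come from the AiOS model and
# dataloader with the keys accessed above.
#
# postprocessor = PostProcess_SMPLX_Multi_Infer_Box(num_select=100)
# results = postprocessor(outputs, target_sizes, targets, data_batch_nc)
# body_bbox = results[0]['body_bbox']  # (num_select, 4), xyxy in pixels
# cam_trans = results[0]['cam_trans']  # (num_select, 3), camera translation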