AiOS / detrsmpl /models /heads /hybrik_head.py
ttxskk
update
d7e58f0
import numpy as np
import torch
import torch.cuda.comm
import torch.nn as nn
from mmcv.runner.base_module import BaseModule
from torch.nn import functional as F
from detrsmpl.core.conventions.keypoints_mapping import get_flip_pairs
def norm_heatmap(norm_type, heatmap):
    """Normalize a heatmap jointly over all of its trailing axes.

    Args:
        norm_type (str):
            Normalization to apply. Only ``'softmax'`` is implemented.
        heatmap (torch.Tensor):
            Tensor of shape ``(N, C, ...)``. Every axis after the first two
            is flattened into one, so the softmax is global per ``(n, c)``
            (e.g. in ``HybrIKHead.forward`` the last axis is the flattened
            ``depth * height * width`` volume).

    Returns:
        torch.Tensor: Tensor with the same shape as ``heatmap`` whose
        values for each ``(n, c)`` pair sum to 1.

    Raises:
        NotImplementedError: If ``norm_type`` is not ``'softmax'``.
    """
    original_shape = heatmap.shape
    if norm_type != 'softmax':
        raise NotImplementedError
    # Collapse every trailing axis so the softmax spans the whole map.
    flattened = heatmap.reshape(*original_shape[:2], -1)
    normalized = F.softmax(flattened, 2)
    return normalized.reshape(*original_shape)
class HybrIKHead(BaseModule):
    """HybrIK parameters regressor head.

    From a backbone feature map this head predicts per-joint 3D heatmaps
    (soft-argmaxed into normalised uvd joint coordinates), an SMPL shape
    offset added to a stored mean shape, and 23 twist angles encoded as
    ``[cos(phi), sin(phi)]``. It then calls an SMPL layer to obtain pose,
    joints and vertices.

    Args:
        feature_channel (int):
            Number of input channels. Default: 512.
        deconv_dim (Sequence[int]):
            Output channels of the three deconvolution stages.
            Default: (256, 256, 256).
        num_joints (int):
            Number of keypoints predicted by the heatmap branch. Default: 29.
        depth_dim (int):
            Depth resolution of each joint heatmap. Default: 64.
        height_dim (int):
            Height resolution of each joint heatmap. Default: 64.
        width_dim (int):
            Width resolution of each joint heatmap. Default: 64.
        smpl_mean_params (str):
            Path of the mean SMPL shape file, read with ``np.load``. Its
            content is added to the 10-dim shape offset, so it is expected
            to hold 10 values. Required despite the ``None`` default.
    """
    def __init__(
        self,
        feature_channel=512,
        deconv_dim=(256, 256, 256),
        num_joints=29,
        depth_dim=64,
        height_dim=64,
        width_dim=64,
        smpl_mean_params=None,
    ):
        super(HybrIKHead, self).__init__()
        # Copy into a fresh list: avoids a shared mutable default while
        # keeping the attribute a list as before.
        self.deconv_dim = list(deconv_dim)
        self._norm_layer = nn.BatchNorm2d
        self.num_joints = num_joints
        self.norm_type = 'softmax'
        self.depth_dim = depth_dim
        self.height_dim = height_dim
        self.width_dim = width_dim
        self.smpl_dtype = torch.float32
        self.feature_channel = feature_channel

        self.deconv_layers = self._make_deconv_layer()
        # One output channel per (joint, depth bin) pair.
        self.final_layer = nn.Conv2d(self.deconv_dim[2],
                                     self.num_joints * self.depth_dim,
                                     kernel_size=1,
                                     stride=1,
                                     padding=0)

        self.joint_pairs_24 = get_flip_pairs('smpl')
        self.joint_pairs_29 = get_flip_pairs('hybrik_29')
        self.leaf_pairs = ((0, 1), (3, 4))
        self.root_idx_smpl = 0

        # mean shape
        if smpl_mean_params is None:
            # np.load(None) would also raise TypeError, just less clearly.
            raise TypeError('smpl_mean_params must be the path of the mean '
                            'SMPL shape file, got None')
        init_shape = np.load(smpl_mean_params)
        self.register_buffer('init_shape', torch.Tensor(init_shape).float())

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # NOTE(review): there is no non-linearity between fc1 and fc2;
        # kept as-is to stay compatible with existing checkpoints.
        self.fc1 = nn.Linear(self.feature_channel, 1024)
        self.drop1 = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(1024, 1024)
        self.drop2 = nn.Dropout(p=0.5)
        self.decshape = nn.Linear(1024, 10)
        self.decphi = nn.Linear(1024, 23 * 2)  # [cos(phi), sin(phi)]

    def _make_deconv_layer(self):
        """Build three ``ConvTranspose2d -> BatchNorm2d -> ReLU`` stages.

        Each transposed convolution uses kernel 4, stride 2 and padding 1,
        so every stage doubles the spatial size (x8 over three stages).
        Modules are appended in the same order and positions as before, so
        ``state_dict`` keys are unchanged.

        Returns:
            nn.Sequential: the upsampling stack.
        """
        layers = []
        in_channels = self.feature_channel
        for out_channels in self.deconv_dim[:3]:
            layers.append(
                nn.ConvTranspose2d(in_channels,
                                   out_channels,
                                   kernel_size=4,
                                   stride=2,
                                   padding=int(4 / 2) - 1,
                                   bias=False))
            layers.append(self._norm_layer(out_channels))
            layers.append(nn.ReLU(inplace=True))
            in_channels = out_channels
        return nn.Sequential(*layers)

    def _initialize(self):
        """Re-initialise deconv/BN weights and the final 1x1 conv.

        Transposed and final conv weights ~ N(0, 0.001); BatchNorm weight 1
        and bias 0; final conv bias 0.
        """
        for m in self.deconv_layers.modules():
            if isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight, std=0.001)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        for m in self.final_layer.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.001)
                nn.init.constant_(m.bias, 0)

    @staticmethod
    def _axis_expectation(marginal):
        """Return the expected index along the last axis of ``marginal``.

        ``marginal`` is a (B, K, L) distribution; the result is (B, K, 1).
        The index vector is float32 on the same device as ``marginal``
        (previously built with ``torch.cuda.comm.broadcast``, which only
        worked for CUDA tensors).
        """
        positions = torch.arange(marginal.shape[-1],
                                 dtype=torch.float32,
                                 device=marginal.device)
        return (marginal * positions).sum(dim=2, keepdim=True)

    def uvd_to_cam(self,
                   uvd_jts,
                   trans_inv,
                   intrinsic_param,
                   joint_root,
                   depth_factor,
                   return_relative=True):
        """Back-project normalised uvd joints into the camera frame.

        Args:
            uvd_jts (torch.Tensor):
                uvd coordinates with shape (B, num_joints, 3). u and v are
                normalised heatmap coordinates (roughly -0.5 ~ 0.5) and d is
                a normalised depth offset from the root.
            trans_inv (torch.Tensor):
                affine matrix with shape (B, 2, 3) mapping input (crop)
                pixels back to image pixels.
            intrinsic_param (torch.Tensor):
                (B, 3, 3) matrix multiplied onto homogeneous image points
                (u, v, 1) to obtain camera rays, i.e. it plays the role of
                the inverse intrinsic matrix K^-1.
                NOTE(review): confirm callers pass K^-1 rather than K.
            joint_root (torch.Tensor):
                root joint camera coordinate with shape (B, 3); its z gives
                the absolute depth the relative depths are added to.
            depth_factor (torch.Tensor):
                depth scale with shape (B, 1).
            return_relative (bool):
                If True, subtract ``joint_root`` from the result.
                Default: True.

        Returns:
            torch.Tensor: camera-frame coordinates with shape
            (B, num_joints, 3), divided by ``depth_factor``.
        """
        assert uvd_jts.dim() == 3 and uvd_jts.shape[2] == 3, uvd_jts.shape
        uvd_jts_new = uvd_jts.clone()
        assert torch.sum(torch.isnan(uvd_jts)) == 0, ('uvd_jts', uvd_jts)

        # remap uv coordinate to input space
        # (the factor 4 is presumably the input / heatmap stride -
        # NOTE(review): confirm against the data pipeline)
        uvd_jts_new[:, :, 0] = (uvd_jts[:, :, 0] + 0.5) * self.width_dim * 4
        uvd_jts_new[:, :, 1] = (uvd_jts[:, :, 1] + 0.5) * self.height_dim * 4
        # remap d to absolute units (presumably mm)
        uvd_jts_new[:, :, 2] = uvd_jts[:, :, 2] * depth_factor
        assert torch.sum(torch.isnan(uvd_jts_new)) == 0, ('uvd_jts_new',
                                                          uvd_jts_new)

        dz = uvd_jts_new[:, :, 2]

        # transform in-bbox coordinate to image coordinate
        uv_homo_jts = torch.cat(
            (uvd_jts_new[:, :, :2], torch.ones_like(uvd_jts_new)[:, :, 2:]),
            dim=2)
        # batch-wise matrix multiply : (B,1,2,3) * (B,K,3,1) -> (B,K,2,1)
        uv_jts = torch.matmul(trans_inv.unsqueeze(1),
                              uv_homo_jts.unsqueeze(-1))
        # transform (u,v,1) to (x,y,z)
        cam_2d_homo = torch.cat((uv_jts, torch.ones_like(uv_jts)[:, :, :1, :]),
                                dim=2)
        # batch-wise matrix multiply : (B,1,3,3) * (B,K,3,1) -> (B,K,3,1)
        xyz_jts = torch.matmul(intrinsic_param.unsqueeze(1), cam_2d_homo)
        xyz_jts = xyz_jts.squeeze(dim=3)
        # recover absolute z : (B,K) + (B,1)
        abs_z = dz + joint_root[:, 2].unsqueeze(-1)
        # multiply absolute z : (B,K,3) * (B,K,1)
        xyz_jts = xyz_jts * abs_z.unsqueeze(-1)

        if return_relative:
            # (B,K,3) - (B,1,3)
            xyz_jts = xyz_jts - joint_root.unsqueeze(1)

        xyz_jts = xyz_jts / depth_factor.unsqueeze(-1)
        return xyz_jts

    def flip_uvd_coord(self, pred_jts, flip=False, flatten=True, shift=None):
        """Horizontally mirror uvd coordinates and swap left/right joints.

        The tensor is modified in place: ``pred_jts`` itself when
        ``flatten`` is False, otherwise its reshaped view whenever
        ``reshape`` returns a view.

        Args:
            pred_jts (torch.Tensor):
                uvd coordinates, shape (B, num_joints * 3) if ``flatten``
                else (B, num_joints, 3).
            flip (bool):
                If True, mirror u as ``-u``; otherwise as
                ``-1 / width_dim - u`` (a one-heatmap-cell shifted mirror).
                Default: False.
            flatten (bool):
                If True, ``pred_jts`` is 2-D and is reshaped to
                (B, num_joints, 3) first. Default: True.
            shift (bool | None):
                Alias for ``flip`` (the keyword ``forward`` used to pass);
                overrides ``flip`` when not None. Default: None.

        Returns:
            torch.Tensor: mirrored coordinates with shape (B, num_joints, 3).
        """
        if shift is not None:
            flip = shift

        if flatten:
            assert pred_jts.dim() == 2
            num_batches = pred_jts.shape[0]
            pred_jts = pred_jts.reshape(num_batches, self.num_joints, 3)
        else:
            assert pred_jts.dim() == 3

        if flip:
            pred_jts[:, :, 0] = -pred_jts[:, :, 0]
        else:
            pred_jts[:, :, 0] = -1 / self.width_dim - pred_jts[:, :, 0]

        for dim0, dim1 in self.joint_pairs_29:
            idx = torch.tensor((dim0, dim1), dtype=torch.long)
            inv_idx = torch.tensor((dim1, dim0), dtype=torch.long)
            # The advanced-indexed RHS is a copy, so this is a safe swap.
            pred_jts[:, idx] = pred_jts[:, inv_idx]

        return pred_jts

    def flip_phi(self, pred_phi):
        """Mirror twist angles in place and swap left/right entries.

        Negates the second component (``sin(phi)``) and swaps the entries
        of each pair in ``self.joint_pairs_24``. Pair indices are offset by
        -1, presumably because twist entry ``k`` belongs to SMPL joint
        ``k + 1`` (the root has no twist).

        Args:
            pred_phi (torch.Tensor): phi with shape (B, 23, 2).

        Returns:
            torch.Tensor: flipped phi with shape (B, 23, 2).
        """
        pred_phi[:, :, 1] = -1 * pred_phi[:, :, 1]

        for dim0, dim1 in self.joint_pairs_24:
            idx = torch.tensor((dim0 - 1, dim1 - 1), dtype=torch.long)
            inv_idx = torch.tensor((dim1 - 1, dim0 - 1), dtype=torch.long)
            pred_phi[:, idx] = pred_phi[:, inv_idx]

        return pred_phi

    def forward(self,
                feature,
                trans_inv,
                intrinsic_param,
                joint_root,
                depth_factor,
                smpl_layer,
                flip_item=None,
                flip_output=False):
        """Forward function.

        Args:
            feature (torch.Tensor):
                backbone features with shape (B, feature_channel, h, w).
                After the x8 deconv upsampling the heatmap reshape expects
                ``8h * 8w == height_dim * width_dim``.
            trans_inv (torch.Tensor):
                inverse affine transformation matrix with shape (B, 2, 3).
            intrinsic_param (torch.Tensor):
                (B, 3, 3) matrix forwarded to ``uvd_to_cam`` (acts as the
                inverse intrinsic there).
            joint_root (torch.Tensor):
                root joint coordinate with shape (B, 3).
            depth_factor (torch.Tensor):
                depth scale with shape (B, 1).
            smpl_layer (callable):
                SMPL layer called with ``pose_skeleton``, ``betas``,
                ``phis``, ``global_orient`` and ``return_verts``; must
                return a mapping with keys ``'vertices'``, ``'joints'``,
                ``'joints_from_verts'`` and ``'poses'``.
            flip_item (Sequence[torch.Tensor] | None):
                ``(uvd_jts, phi, leaf, shape)`` from the un-flipped pass,
                averaged with this (flipped) pass. Requires ``flip_output``.
            flip_output (bool):
                If True, un-flip this pass's predictions. Default: False.

        Returns:
            dict: ``pred_phi`` (B, 23, 2), ``pred_delta_shape`` (B, 10),
            ``pred_shape`` (B, 10), ``pred_pose`` (B, 24, 3, 3),
            ``pred_uvd_jts`` (B, num_joints * 3), ``pred_xyz_jts_24``
            (B, 72), ``pred_xyz_jts_24_struct`` (B, 72),
            ``pred_xyz_jts_17`` (B, 51), ``pred_vertices`` (float tensor
            returned by ``smpl_layer``), ``maxvals`` (B, num_joints, 1).
        """
        batch_size = feature.shape[0]
        x0 = feature

        # ---- heatmap branch: 3D heatmaps -> uvd joints ----
        out = self.deconv_layers(x0)
        out = self.final_layer(out)
        out = out.reshape((out.shape[0], self.num_joints, -1))
        out = norm_heatmap(self.norm_type, out)
        assert out.dim() == 3, out.shape

        if self.norm_type == 'sigmoid':
            maxvals, _ = torch.max(out, dim=2, keepdim=True)
        else:
            maxvals = torch.ones((*out.shape[:2], 1),
                                 dtype=torch.float,
                                 device=out.device)

        heatmaps = out / out.sum(dim=2, keepdim=True)
        heatmaps = heatmaps.reshape(
            (heatmaps.shape[0], self.num_joints, self.depth_dim,
             self.height_dim, self.width_dim))

        # Soft-argmax: marginalise the (D, H, W) volume onto each axis and
        # take the expected index along it.
        coord_x = self._axis_expectation(heatmaps.sum((2, 3)))
        coord_y = self._axis_expectation(heatmaps.sum((2, 4)))
        coord_z = self._axis_expectation(heatmaps.sum((3, 4)))

        # -0.5 ~ 0.5
        coord_x = coord_x / float(self.width_dim) - 0.5
        coord_y = coord_y / float(self.height_dim) - 0.5
        coord_z = coord_z / float(self.depth_dim) - 0.5
        pred_uvd_jts_29 = torch.cat((coord_x, coord_y, coord_z), dim=2)

        # ---- global branch: SMPL shape offset and twist ----
        x0 = self.avg_pool(x0)
        x0 = x0.view(x0.size(0), -1)
        init_shape = self.init_shape.expand(batch_size, -1)  # (B, 10,)

        xc = self.drop1(self.fc1(x0))
        xc = self.drop2(self.fc2(xc))
        delta_shape = self.decshape(xc)
        pred_shape = delta_shape + init_shape
        pred_phi = self.decphi(xc)

        # ---- optional flip test-time augmentation ----
        if flip_item is not None:
            assert flip_output
            pred_uvd_jts_29_orig, pred_phi_orig, _, pred_shape_orig = \
                flip_item

        if flip_output:
            # flip=True: plain negation of u (no half-cell shift).
            pred_uvd_jts_29 = self.flip_uvd_coord(pred_uvd_jts_29,
                                                  flip=True,
                                                  flatten=False)
        if flip_output and flip_item is not None:
            pred_uvd_jts_29 = (pred_uvd_jts_29 + pred_uvd_jts_29_orig.reshape(
                batch_size, self.num_joints, 3)) / 2

        # -0.5 ~ 0.5; built after flipping/averaging so it always matches.
        pred_uvd_jts_29_flat = pred_uvd_jts_29.reshape(
            (batch_size, self.num_joints * 3))

        # Rotate back into the camera frame, relative to the predicted root.
        pred_xyz_jts_29 = self.uvd_to_cam(pred_uvd_jts_29, trans_inv,
                                          intrinsic_param, joint_root,
                                          depth_factor)
        assert torch.sum(
            torch.isnan(pred_xyz_jts_29)) == 0, ('pred_xyz_jts_29',
                                                 pred_xyz_jts_29)
        pred_xyz_jts_29 = pred_xyz_jts_29 - \
            pred_xyz_jts_29[:, self.root_idx_smpl, :].unsqueeze(1)

        pred_phi = pred_phi.reshape(batch_size, 23, 2)
        if flip_output:
            pred_phi = self.flip_phi(pred_phi)
        if flip_output and flip_item is not None:
            pred_phi = (pred_phi + pred_phi_orig) / 2
            pred_shape = (pred_shape + pred_shape_orig) / 2

        # Joints are scaled by 2 going into the SMPL layer and the returned
        # joints are divided by 2 (back to the "-0.5 ~ 0.5" range above).
        hybrik_output = smpl_layer(
            pose_skeleton=pred_xyz_jts_29.type(self.smpl_dtype) * 2,
            betas=pred_shape.type(self.smpl_dtype),
            phis=pred_phi.type(self.smpl_dtype),
            global_orient=None,
            return_verts=True)

        pred_vertices = hybrik_output['vertices'].float()
        # -0.5 ~ 0.5
        pred_xyz_jts_24_struct = hybrik_output['joints'].float() / 2
        # -0.5 ~ 0.5
        pred_xyz_jts_17 = hybrik_output['joints_from_verts'].float() / 2
        pred_poses = hybrik_output['poses'].float().reshape(
            batch_size, 24, 3, 3)
        pred_xyz_jts_24 = pred_xyz_jts_29[:, :24, :].reshape(batch_size, 72)
        pred_xyz_jts_24_struct = pred_xyz_jts_24_struct.reshape(batch_size, 72)
        pred_xyz_jts_17 = pred_xyz_jts_17.reshape(batch_size, 17 * 3)

        output = {
            'pred_phi': pred_phi,
            'pred_delta_shape': delta_shape,
            'pred_shape': pred_shape,
            'pred_pose': pred_poses,
            'pred_uvd_jts': pred_uvd_jts_29_flat,
            'pred_xyz_jts_24': pred_xyz_jts_24,
            'pred_xyz_jts_24_struct': pred_xyz_jts_24_struct,
            'pred_xyz_jts_17': pred_xyz_jts_17,
            'pred_vertices': pred_vertices,
            'maxvals': maxvals,
        }
        return output