import numpy as np
import torch
import torch.nn as nn
from mmcv.runner.base_module import BaseModule
from torch.nn import functional as F

from detrsmpl.core.conventions.keypoints_mapping import get_flip_pairs
def norm_heatmap(norm_type, heatmap):
    """Normalize heatmap.

    Args:
        norm_type (str):
            type of normalization. Currently only 'softmax' is supported.
        heatmap (torch.Tensor):
            model output heatmap with shape (B, K, L), where K is the
            number of joints and L is the flattened spatial resolution
            (e.g. depth_dim * height_dim * width_dim).

    Returns:
        heatmap (torch.Tensor):
            normalized heatmap of the same shape, where the values of
            each joint sum to one.
    """
    # Input tensor shape: [N, C, ...]; all trailing dims are flattened,
    # normalized jointly, then restored.
    shape = heatmap.shape
    if norm_type == 'softmax':
        heatmap = heatmap.reshape(*shape[:2], -1)
        # global softmax over the flattened spatial dims
        heatmap = F.softmax(heatmap, 2)
        return heatmap.reshape(*shape)
    else:
        raise NotImplementedError
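

# A minimal usage sketch of norm_heatmap (illustration only, not part of the
# original module): after normalization, each joint's heatmap is a probability
# distribution over its flattened voxels.
#
#   hm = torch.randn(2, 29, 64 * 64 * 64)   # (B, K, D*H*W) logits
#   hm = norm_heatmap('softmax', hm)
#   assert torch.allclose(hm.sum(dim=2), torch.ones(2, 29))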


class HybrIKHead(BaseModule):
    """HybrIK parameters regressor head.

    Args:
        feature_channel (int):
            Number of input channels.
        deconv_dim (List[int]):
            List of deconvolution dimensions.
        num_joints (int):
            Number of keypoints.
        depth_dim (int):
            Depth dimension.
        height_dim (int):
            Height dimension.
        width_dim (int):
            Width dimension.
        smpl_mean_params (str):
            file name of the mean SMPL parameters.
    """

    def __init__(
        self,
        feature_channel=512,
        deconv_dim=[256, 256, 256],
        num_joints=29,
        depth_dim=64,
        height_dim=64,
        width_dim=64,
        smpl_mean_params=None,
    ):
        super(HybrIKHead, self).__init__()

        self.deconv_dim = deconv_dim
        self._norm_layer = nn.BatchNorm2d
        self.num_joints = num_joints
        self.norm_type = 'softmax'
        self.depth_dim = depth_dim
        self.height_dim = height_dim
        self.width_dim = width_dim
        self.smpl_dtype = torch.float32
        self.feature_channel = feature_channel

        self.deconv_layers = self._make_deconv_layer()
        self.final_layer = nn.Conv2d(self.deconv_dim[2],
                                     self.num_joints * self.depth_dim,
                                     kernel_size=1,
                                     stride=1,
                                     padding=0)

        self.joint_pairs_24 = get_flip_pairs('smpl')
        self.joint_pairs_29 = get_flip_pairs('hybrik_29')

        self.leaf_pairs = ((0, 1), (3, 4))
        self.root_idx_smpl = 0

        # mean shape
        init_shape = np.load(smpl_mean_params)
        self.register_buffer('init_shape', torch.Tensor(init_shape).float())

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(self.feature_channel, 1024)
        self.drop1 = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(1024, 1024)
        self.drop2 = nn.Dropout(p=0.5)
        self.decshape = nn.Linear(1024, 10)
        self.decphi = nn.Linear(1024, 23 * 2)  # [cos(phi), sin(phi)]

    def _make_deconv_layer(self):
        deconv_layers = []
        deconv1 = nn.ConvTranspose2d(self.feature_channel,
                                     self.deconv_dim[0],
                                     kernel_size=4,
                                     stride=2,
                                     padding=1,
                                     bias=False)
        bn1 = self._norm_layer(self.deconv_dim[0])
        deconv2 = nn.ConvTranspose2d(self.deconv_dim[0],
                                     self.deconv_dim[1],
                                     kernel_size=4,
                                     stride=2,
                                     padding=1,
                                     bias=False)
        bn2 = self._norm_layer(self.deconv_dim[1])
        deconv3 = nn.ConvTranspose2d(self.deconv_dim[1],
                                     self.deconv_dim[2],
                                     kernel_size=4,
                                     stride=2,
                                     padding=1,
                                     bias=False)
        bn3 = self._norm_layer(self.deconv_dim[2])

        deconv_layers.append(deconv1)
        deconv_layers.append(bn1)
        deconv_layers.append(nn.ReLU(inplace=True))
        deconv_layers.append(deconv2)
        deconv_layers.append(bn2)
        deconv_layers.append(nn.ReLU(inplace=True))
        deconv_layers.append(deconv3)
        deconv_layers.append(bn3)
        deconv_layers.append(nn.ReLU(inplace=True))

        return nn.Sequential(*deconv_layers)
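
    # Worked example of the upsampling arithmetic above (illustrative): a
    # transposed conv with kernel 4, stride 2, padding 1 maps a spatial size S
    # to (S - 1) * 2 - 2 * 1 + 4 = 2 * S, so three such layers take an 8x8
    # backbone feature map to the 64x64 heatmap resolution (width_dim and
    # height_dim), before the 1x1 final_layer produces
    # num_joints * depth_dim channels.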

    def _initialize(self):
        for name, m in self.deconv_layers.named_modules():
            if isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight, std=0.001)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        for m in self.final_layer.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.001)
                nn.init.constant_(m.bias, 0)

    def uvd_to_cam(self,
                   uvd_jts,
                   trans_inv,
                   intrinsic_param,
                   joint_root,
                   depth_factor,
                   return_relative=True):
        """Project uvd coordinates to camera frame.

        Args:
            uvd_jts (torch.Tensor):
                uvd coordinates with shape (BxNum_jointsx3)
            trans_inv (torch.Tensor):
                inverse affine transformation matrix with shape (Bx2x3)
            intrinsic_param (torch.Tensor):
                camera intrinsic matrix with shape (Bx3x3)
            joint_root (torch.Tensor):
                root joint coordinate with shape (Bx3)
            depth_factor (torch.Tensor):
                depth factor with shape (Bx1)
            return_relative (bool):
                Store True to return root normalized relative coordinates.
                Default: True.

        Returns:
            xyz_jts (torch.Tensor):
                xyz coordinates in camera frame with shape (BxNum_jointsx3)
        """
        assert uvd_jts.dim() == 3 and uvd_jts.shape[2] == 3, uvd_jts.shape
        uvd_jts_new = uvd_jts.clone()
        assert torch.sum(torch.isnan(uvd_jts)) == 0, ('uvd_jts', uvd_jts)

        # remap uv coordinate to input space
        uvd_jts_new[:, :, 0] = (uvd_jts[:, :, 0] + 0.5) * self.width_dim * 4
        uvd_jts_new[:, :, 1] = (uvd_jts[:, :, 1] + 0.5) * self.height_dim * 4
        # remap d to mm
        uvd_jts_new[:, :, 2] = uvd_jts[:, :, 2] * depth_factor
        assert torch.sum(torch.isnan(uvd_jts_new)) == 0, ('uvd_jts_new',
                                                          uvd_jts_new)

        dz = uvd_jts_new[:, :, 2]

        # transform in-bbox coordinate to image coordinate
        uv_homo_jts = torch.cat(
            (uvd_jts_new[:, :, :2], torch.ones_like(uvd_jts_new[:, :, 2:])),
            dim=2)
        # batch-wise matrix multiply : (B,1,2,3) * (B,K,3,1) -> (B,K,2,1)
        uv_jts = torch.matmul(trans_inv.unsqueeze(1),
                              uv_homo_jts.unsqueeze(-1))
        # transform (u,v,1) to (x,y,z)
        cam_2d_homo = torch.cat((uv_jts, torch.ones_like(uv_jts)[:, :, :1, :]),
                                dim=2)
        # batch-wise matrix multiply : (B,1,3,3) * (B,K,3,1) -> (B,K,3,1)
        xyz_jts = torch.matmul(intrinsic_param.unsqueeze(1), cam_2d_homo)
        xyz_jts = xyz_jts.squeeze(dim=3)
        # recover absolute z : (B,K) + (B,1)
        abs_z = dz + joint_root[:, 2].unsqueeze(-1)
        # multiply absolute z : (B,K,3) * (B,K,1)
        xyz_jts = xyz_jts * abs_z.unsqueeze(-1)

        if return_relative:
            # (B,K,3) - (B,1,3)
            xyz_jts = xyz_jts - joint_root.unsqueeze(1)

        xyz_jts = xyz_jts / depth_factor.unsqueeze(-1)
        return xyz_jts
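
    # Pinhole-camera sketch of what uvd_to_cam computes (illustration only):
    # trans_inv undoes the bbox crop, mapping heatmap uv back to image pixels,
    # and the subsequent matmul scales the homogeneous pixel (u, v, 1) by the
    # absolute depth z. For that product to land in metric camera coordinates,
    # x = z * K^-1 @ (u, v, 1), the `intrinsic_param` passed in is assumed to
    # act as the inverse camera matrix K^-1.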

    def flip_uvd_coord(self, pred_jts, shift=False, flatten=True):
        """Flip uvd coordinates horizontally.

        Args:
            pred_jts (torch.Tensor):
                predicted uvd coordinates with shape (Bx87)
            shift (bool):
                Store True to mirror the u coordinate as -u; store False to
                additionally offset it by one heatmap pixel (-1/width_dim).
                Default: False.
            flatten (bool):
                Store True to reshape uvd coordinates to shape (Bx29x3).
                Default: True.

        Returns:
            pred_jts (torch.Tensor):
                flipped uvd coordinates with shape (Bx29x3)
        """
        if flatten:
            assert pred_jts.dim() == 2
            num_batches = pred_jts.shape[0]
            pred_jts = pred_jts.reshape(num_batches, self.num_joints, 3)
        else:
            assert pred_jts.dim() == 3
            num_batches = pred_jts.shape[0]

        # mirror the u axis
        if shift:
            pred_jts[:, :, 0] = -pred_jts[:, :, 0]
        else:
            pred_jts[:, :, 0] = -1 / self.width_dim - pred_jts[:, :, 0]

        # swap left/right joints
        for pair in self.joint_pairs_29:
            dim0, dim1 = pair
            idx = torch.tensor((dim0, dim1), dtype=torch.long)
            inv_idx = torch.tensor((dim1, dim0), dtype=torch.long)
            pred_jts[:, idx] = pred_jts[:, inv_idx]

        return pred_jts
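
    # Note on the -1/width_dim offset above (a hedged reading of the code):
    # u is a pixel-center coordinate in [-0.5, 0.5], so mirroring the discrete
    # pixel grid is not exactly u -> -u; the extra one-pixel shift keeps the
    # flipped coordinate aligned with the mirrored pixel centers, while
    # shift=True skips this correction and applies a plain sign flip.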

    def flip_phi(self, pred_phi):
        """Flip phi.

        Args:
            pred_phi (torch.Tensor): phi in shape (B, Num_twist, 2)

        Returns:
            pred_phi (torch.Tensor): flipped phi in shape (B, Num_twist, 2)
        """
        # mirroring negates sin(phi) while keeping cos(phi)
        pred_phi[:, :, 1] = -1 * pred_phi[:, :, 1]

        # swap left/right twists; indices are shifted by -1 because phi
        # excludes the root joint
        for pair in self.joint_pairs_24:
            dim0, dim1 = pair
            idx = torch.tensor((dim0 - 1, dim1 - 1), dtype=torch.long)
            inv_idx = torch.tensor((dim1 - 1, dim0 - 1), dtype=torch.long)
            pred_phi[:, idx] = pred_phi[:, inv_idx]

        return pred_phi

    def forward(self,
                feature,
                trans_inv,
                intrinsic_param,
                joint_root,
                depth_factor,
                smpl_layer,
                flip_item=None,
                flip_output=False):
        """Forward function.

        Args:
            feature (torch.Tensor): features extracted from backbone
            trans_inv (torch.Tensor):
                inverse affine transformation matrix with shape (Bx2x3)
            intrinsic_param (torch.Tensor):
                camera intrinsic matrix with shape (Bx3x3)
            joint_root (torch.Tensor):
                root joint coordinate with shape (Bx3)
            depth_factor (torch.Tensor):
                depth factor with shape (Bx1)
            smpl_layer (nn.Module):
                SMPL body model
            flip_item (List[torch.Tensor]|None):
                list containing items to flip
            flip_output (bool):
                Store True to flip output. Default: False

        Returns:
            output (dict): Dict containing model predictions.
        """
        batch_size = feature.shape[0]

        x0 = feature
        out = self.deconv_layers(x0)
        out = self.final_layer(out)

        out = out.reshape((out.shape[0], self.num_joints, -1))
        out = norm_heatmap(self.norm_type, out)
        assert out.dim() == 3, out.shape

        if self.norm_type == 'sigmoid':
            maxvals, _ = torch.max(out, dim=2, keepdim=True)
        else:
            # norm_type is fixed to 'softmax' in __init__, so confidences
            # default to ones here
            maxvals = torch.ones((*out.shape[:2], 1),
                                 dtype=torch.float,
                                 device=out.device)

        heatmaps = out / out.sum(dim=2, keepdim=True)

        heatmaps = heatmaps.reshape(
            (heatmaps.shape[0], self.num_joints, self.depth_dim,
             self.height_dim, self.width_dim))

        # marginal distributions over each axis
        hm_x = heatmaps.sum((2, 3))
        hm_y = heatmaps.sum((2, 4))
        hm_z = heatmaps.sum((3, 4))
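
        # Soft-argmax (illustrative summary of the next few lines): each
        # marginal hm_* is a discrete distribution over bin indices, so
        # weighting by torch.arange and summing yields the expected index
        # E[i] = sum_i i * p(i); dividing by the bin count and subtracting
        # 0.5 maps it into the normalized range [-0.5, 0.5].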
        # index weights on the same device as the marginals (avoids the
        # CUDA-only broadcast helper, which crashes on CPU)
        range_x = torch.arange(hm_x.shape[-1],
                               dtype=torch.float32,
                               device=hm_x.device)
        range_y = torch.arange(hm_y.shape[-1],
                               dtype=torch.float32,
                               device=hm_y.device)
        range_z = torch.arange(hm_z.shape[-1],
                               dtype=torch.float32,
                               device=hm_z.device)
        hm_x = hm_x * range_x
        hm_y = hm_y * range_y
        hm_z = hm_z * range_z
        coord_x = hm_x.sum(dim=2, keepdim=True)
        coord_y = hm_y.sum(dim=2, keepdim=True)
        coord_z = hm_z.sum(dim=2, keepdim=True)

        coord_x = coord_x / float(self.width_dim) - 0.5
        coord_y = coord_y / float(self.height_dim) - 0.5
        coord_z = coord_z / float(self.depth_dim) - 0.5

        # -0.5 ~ 0.5
        pred_uvd_jts_29 = torch.cat((coord_x, coord_y, coord_z), dim=2)
        pred_uvd_jts_29_flat = pred_uvd_jts_29.reshape(
            (batch_size, self.num_joints * 3))

        x0 = self.avg_pool(x0)
        x0 = x0.view(x0.size(0), -1)

        init_shape = self.init_shape.expand(batch_size, -1)  # (B, 10)

        xc = x0

        xc = self.fc1(xc)
        xc = self.drop1(xc)
        xc = self.fc2(xc)
        xc = self.drop2(xc)

        delta_shape = self.decshape(xc)
        pred_shape = delta_shape + init_shape
        pred_phi = self.decphi(xc)
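
        # Shape walk-through of the regressor branch (illustrative): x0 is
        # avg-pooled to (B, 512), lifted to (B, 1024) twice with dropout,
        # then decoded into a 10-dim shape residual added to the mean betas
        # and 23 * 2 twist angles stored as (cos(phi), sin(phi)) pairs.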

        if flip_item is not None:
            assert flip_output
            pred_uvd_jts_29_orig, pred_phi_orig, pred_leaf_orig, \
                pred_shape_orig = flip_item

        if flip_output:
            pred_uvd_jts_29 = self.flip_uvd_coord(pred_uvd_jts_29,
                                                  flatten=False,
                                                  shift=True)
        if flip_output and flip_item is not None:
            pred_uvd_jts_29 = (pred_uvd_jts_29 + pred_uvd_jts_29_orig.reshape(
                batch_size, 29, 3)) / 2

        pred_uvd_jts_29_flat = pred_uvd_jts_29.reshape(
            (batch_size, self.num_joints * 3))

        # -0.5 ~ 0.5
        # Rotate back
        pred_xyz_jts_29 = self.uvd_to_cam(pred_uvd_jts_29, trans_inv,
                                          intrinsic_param, joint_root,
                                          depth_factor)
        assert torch.sum(
            torch.isnan(pred_xyz_jts_29)) == 0, ('pred_xyz_jts_29',
                                                 pred_xyz_jts_29)

        pred_xyz_jts_29 = pred_xyz_jts_29 - \
            pred_xyz_jts_29[:, self.root_idx_smpl, :].unsqueeze(1)

        pred_phi = pred_phi.reshape(batch_size, 23, 2)

        if flip_output:
            pred_phi = self.flip_phi(pred_phi)

        if flip_output and flip_item is not None:
            pred_phi = (pred_phi + pred_phi_orig) / 2
            pred_shape = (pred_shape + pred_shape_orig) / 2

        hybrik_output = smpl_layer(
            pose_skeleton=pred_xyz_jts_29.type(self.smpl_dtype) * 2,
            betas=pred_shape.type(self.smpl_dtype),
            phis=pred_phi.type(self.smpl_dtype),
            global_orient=None,
            return_verts=True)

        pred_vertices = hybrik_output['vertices'].float()
        # -0.5 ~ 0.5
        pred_xyz_jts_24_struct = hybrik_output['joints'].float() / 2
        # -0.5 ~ 0.5
        pred_xyz_jts_17 = hybrik_output['joints_from_verts'].float() / 2
        pred_poses = hybrik_output['poses'].float().reshape(
            batch_size, 24, 3, 3)
        pred_xyz_jts_24 = pred_xyz_jts_29[:, :24, :].reshape(batch_size, 72)
        pred_xyz_jts_24_struct = pred_xyz_jts_24_struct.reshape(batch_size, 72)
        pred_xyz_jts_17 = pred_xyz_jts_17.reshape(batch_size, 17 * 3)

        output = {
            'pred_phi': pred_phi,
            'pred_delta_shape': delta_shape,
            'pred_shape': pred_shape,
            'pred_pose': pred_poses,
            'pred_uvd_jts': pred_uvd_jts_29_flat,
            'pred_xyz_jts_24': pred_xyz_jts_24,
            'pred_xyz_jts_24_struct': pred_xyz_jts_24_struct,
            'pred_xyz_jts_17': pred_xyz_jts_17,
            'pred_vertices': pred_vertices,
            'maxvals': maxvals,
        }
        return output
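

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original module): it only
    # exercises the deconvolution stack and the 1x1 heatmap head, faking the
    # mean-shape file with zeros. forward() is not called because it would
    # additionally require an SMPL layer and camera parameters.
    import os
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        mean_path = os.path.join(tmp_dir, 'smpl_mean_beta.npy')  # hypothetical
        np.save(mean_path, np.zeros(10, dtype=np.float32))
        head = HybrIKHead(smpl_mean_params=mean_path)

    feature = torch.randn(2, 512, 8, 8)  # (B, C, H, W) backbone feature map
    logits = head.final_layer(head.deconv_layers(feature))
    # three stride-2 deconvs: 8x8 -> 64x64; channels = num_joints * depth_dim
    assert logits.shape == (2, 29 * 64, 64, 64)
    print('heatmap logits shape:', tuple(logits.shape))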