"""This script is modified from [PARE](https://github.com/ | |
mkocabas/PARE/tree/master/pare/models/layers). | |
Original license please see docs/additional_licenses.md. | |
""" | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from mmcv.runner.base_module import BaseModule | |
from torch.nn.modules.utils import _pair | |
from detrsmpl.utils.geometry import rot6d_to_rotmat | |


class LocallyConnected2d(nn.Module):
    """Locally connected layer: a convolution-like layer whose weights are
    not shared across spatial locations.

    Args:
        in_channels (int):
            Number of input feature channels.
        out_channels (int):
            Number of output feature channels.
        output_size (List[int]):
            Spatial size of the output feature map.
        kernel_size (int):
            Size of the kernel.
        stride (int):
            Stride of the kernel.

    Returns:
        attended_features (torch.Tensor):
            Attended feature maps.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 output_size,
                 kernel_size,
                 stride,
                 bias=False):
        super(LocallyConnected2d, self).__init__()
        output_size = _pair(output_size)
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, output_size[0],
                        output_size[1], kernel_size**2),
            requires_grad=True,
        )
        if bias:
            self.bias = nn.Parameter(torch.randn(1, out_channels,
                                                 output_size[0],
                                                 output_size[1]),
                                     requires_grad=True)
        else:
            self.register_parameter('bias', None)
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)

    def forward(self, x):
        _, c, h, w = x.size()
        kh, kw = self.kernel_size
        dh, dw = self.stride
        x = x.unfold(2, kh, dh).unfold(3, kw, dw)
        x = x.contiguous().view(*x.size()[:-2], -1)
        # Sum over the in_channels and kernel_size dims
        out = (x.unsqueeze(1) * self.weight).sum([2, -1])
        if self.bias is not None:
            out += self.bias
        return out
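
# Usage sketch (illustrative, kept as a comment so it is not executed at
# import time): unlike nn.Conv2d, every output location has its own weights.
#   lc = LocallyConnected2d(in_channels=256, out_channels=6,
#                           output_size=[24, 1], kernel_size=1, stride=1)
#   x = torch.randn(2, 256, 24, 1)
#   y = lc(x)  # -> (2, 6, 24, 1): one 6-D rotation per joint location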


class KeypointAttention(nn.Module):
    """Keypoint attention layer.

    Args:
        use_conv (bool):
            Whether to apply a 1x1 conv to the attended feature map.
            Default: False
        in_channels (List[int]):
            Input channels of the pose features and the shape_cam features.
            Default: (256, 64)
        out_channels (List[int]):
            Output channels of the pose features and the shape_cam features.
            Default: (256, 64)
        act (str):
            Activation used to normalize the heatmaps,
            'softmax' or 'sigmoid'. Default: 'softmax'
        use_scale (bool):
            Whether to scale the heatmaps by 1/sqrt(H*W) before
            normalization. Default: False

    Returns:
        attended_features (torch.Tensor):
            Attended feature maps.
    """

    def __init__(self,
                 use_conv=False,
                 in_channels=(256, 64),
                 out_channels=(256, 64),
                 act='softmax',
                 use_scale=False):
        super(KeypointAttention, self).__init__()
        self.use_conv = use_conv
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.act = act
        self.use_scale = use_scale
        if use_conv:
            self.conv1x1_pose = nn.Conv1d(in_channels[0],
                                          out_channels[0],
                                          kernel_size=1)
            self.conv1x1_shape_cam = nn.Conv1d(in_channels[1],
                                               out_channels[1],
                                               kernel_size=1)

    def forward(self, features, heatmaps):
        batch_size, num_joints, height, width = heatmaps.shape
        if self.use_scale:
            scale = 1.0 / np.sqrt(height * width)
            heatmaps = heatmaps * scale
        if self.act == 'softmax':
            normalized_heatmap = F.softmax(heatmaps.reshape(
                batch_size, num_joints, -1),
                                           dim=-1)
        elif self.act == 'sigmoid':
            normalized_heatmap = torch.sigmoid(
                heatmaps.reshape(batch_size, num_joints, -1))
        else:
            raise ValueError(f'Unsupported activation: {self.act}')
        features = features.reshape(batch_size, -1, height * width)

        attended_features = torch.matmul(normalized_heatmap,
                                         features.transpose(2, 1))
        attended_features = attended_features.transpose(2, 1)

        if self.use_conv:
            if attended_features.shape[1] == self.in_channels[0]:
                attended_features = self.conv1x1_pose(attended_features)
            else:
                attended_features = self.conv1x1_shape_cam(attended_features)
        return attended_features
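
# Usage sketch (illustrative): each joint heatmap acts as a soft spatial
# attention mask that pools the feature map into one vector per joint.
#   attn = KeypointAttention(use_conv=False)
#   feats = torch.randn(2, 128, 56, 56)    # [B, C, H, W]
#   heatmaps = torch.randn(2, 24, 56, 56)  # [B, J, H, W]
#   out = attn(feats, heatmaps)            # -> (2, 128, 24)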


def interpolate(feat, uv):
    """Sample image features at normalized uv coordinates.

    Args:
        feat (torch.Tensor): [B, C, H, W] image features
        uv (torch.Tensor): [B, 2, N] uv coordinates
            in the image plane, range [-1, 1]

    Returns:
        samples[:, :, :, 0] (torch.Tensor):
            [B, C, N] image features at the uv coordinates
    """
    if uv.shape[-1] != 2:
        uv = uv.transpose(1, 2)  # [B, N, 2]
    uv = uv.unsqueeze(2)  # [B, N, 1, 2]
    # NOTE: for newer PyTorch, training results seem to degrade due to an
    # implementation difference in F.grid_sample; for old versions, simply
    # remove the align_corners argument.
    major, minor = (int(v) for v in torch.__version__.split('.')[:2])
    if (major, minor) < (1, 4):
        # old PyTorch without the align_corners argument
        samples = torch.nn.functional.grid_sample(feat, uv)  # [B, C, N, 1]
    else:
        samples = torch.nn.functional.grid_sample(
            feat, uv, align_corners=True)  # [B, C, N, 1]
    return samples[:, :, :, 0]  # [B, C, N]
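
# Usage sketch (illustrative): bilinear sampling at 24 keypoint locations.
#   feat = torch.randn(2, 64, 56, 56)
#   uv = torch.rand(2, 2, 24) * 2 - 1  # [B, 2, N], normalized to [-1, 1]
#   sampled = interpolate(feat, uv)    # -> (2, 64, 24)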


def _softmax(tensor, temperature, dim=-1):
    return F.softmax(tensor * temperature, dim=dim)


def softargmax2d(
    heatmaps,
    temperature=None,
    normalize_keypoints=True,
):
    """Softargmax layer for heatmaps."""
    dtype, device = heatmaps.dtype, heatmaps.device
    if temperature is None:
        temperature = torch.tensor(1.0, dtype=dtype, device=device)
    batch_size, num_channels, height, width = heatmaps.shape
    x = torch.arange(0, width, device=device, dtype=dtype).reshape(
        1, 1, 1, width).expand(batch_size, -1, height, -1)
    y = torch.arange(0, height, device=device,
                     dtype=dtype).reshape(1, 1, height,
                                          1).expand(batch_size, -1, -1, width)
    # Should be Bx2xHxW
    points = torch.cat([x, y], dim=1)
    normalized_heatmap = _softmax(heatmaps.reshape(batch_size, num_channels,
                                                   -1),
                                  temperature=temperature.reshape(1, -1, 1),
                                  dim=-1)
    # Should be BxJx2
    keypoints = (
        normalized_heatmap.reshape(batch_size, -1, 1, height * width) *
        points.reshape(batch_size, 1, 2, -1)).sum(dim=-1)
    if normalize_keypoints:
        # Normalize keypoints to [-1, 1]
        keypoints[:, :, 0] = (keypoints[:, :, 0] / (width - 1) * 2 - 1)
        keypoints[:, :, 1] = (keypoints[:, :, 1] / (height - 1) * 2 - 1)
    return keypoints, normalized_heatmap.reshape(batch_size, -1, height,
                                                 width)
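
# Usage sketch (illustrative): soft-argmax is a differentiable argmax, so a
# sharp peak yields keypoints close to the true (normalized) peak location.
#   hm = torch.zeros(1, 1, 8, 8)
#   hm[0, 0, 3, 5] = 10.0     # peak at x=5, y=3
#   kp, _ = softargmax2d(hm)  # kp[0, 0] ~ (5/7*2-1, 3/7*2-1)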


class PareHead(BaseModule):

    def __init__(
        self,
        num_joints=24,
        num_input_features=480,
        softmax_temp=1.0,
        num_deconv_layers=3,
        num_deconv_filters=(256, 256, 256),
        num_deconv_kernels=(4, 4, 4),
        num_camera_params=3,
        num_features_smpl=64,
        final_conv_kernel=1,
        pose_mlp_num_layers=1,
        shape_mlp_num_layers=1,
        pose_mlp_hidden_size=256,
        shape_mlp_hidden_size=256,
        bn_momentum=0.1,
        use_heatmaps='part_segm',
        use_keypoint_attention=False,
        use_postconv_keypoint_attention=False,
        keypoint_attention_act='softmax',  # softmax, sigmoid
        use_scale_keypoint_attention=False,
        backbone='hrnet_w32-conv',  # hrnet, resnet
        smpl_mean_params=None,
        deconv_with_bias=False,
    ):
"""PARE parameters regressor head. This class is modified from. | |
[PARE](hhttps://github.com/ | |
mkocabas/PARE/blob/master/pare/models/head/pare_head.py). Original | |
license please see docs/additional_licenses.md. | |
Args: | |
num_joints (int): | |
Number of joints, should be 24 for smpl. | |
num_input_features (int): | |
Number of input featuremap channels. | |
softmax_temp (float): | |
Softmax tempreture | |
num_deconv_layers (int): | |
Number of deconvolution layers. | |
num_deconv_filters (List[int]): | |
Number of filters for each deconvolution layer, | |
len(num_deconv_filters) == num_deconv_layers. | |
num_deconv_kernels (List[int]): | |
Kernel size for each deconvolution layer, | |
len(num_deconv_kernels) == num_deconv_layers. | |
num_camera_params (int): | |
Number of predicted camera parameter dimension. | |
num_features_smpl (int): | |
Number of feature map channels. | |
final_conv_kernel (int): | |
Kernel size for the final deconvolution feature map channels. | |
pose_mlp_num_layers (int): | |
Number of mpl layers for pose parameter regression. | |
shape_mlp_num_layers (int): | |
Number of mpl layers for pose parameter regression. | |
pose_mlp_hidden_size (int): | |
Hidden size for pose mpl layers. | |
shape_mlp_hidden_size (int): | |
Hidden size for pose mpl layers. | |
bn_momemtum (float): | |
Momemtum for batch normalization. | |
use_heatmaps (str): | |
Types of heat maps to use. | |
use_keypoint_attention (bool) | |
Whether to use attention based on heat maps. | |
keypoint_attention_act (str): | |
Types of activation function for attention layers. | |
use_scale_keypoint_attention (str): | |
Whether to scale the attention | |
according to the size of the attention map. | |
deconv_with_bias (bool) | |
Whether to deconv with bias. | |
backbone (str): | |
Types of the backbone. | |
smpl_mean_params (str): | |
File name of the mean SMPL parameters | |
""" | |
        super(PareHead, self).__init__()
        self.backbone = backbone
        self.num_joints = num_joints
        self.deconv_with_bias = deconv_with_bias
        self.use_heatmaps = use_heatmaps
        self.pose_mlp_num_layers = pose_mlp_num_layers
        self.shape_mlp_num_layers = shape_mlp_num_layers
        self.pose_mlp_hidden_size = pose_mlp_hidden_size
        self.shape_mlp_hidden_size = shape_mlp_hidden_size
        self.use_keypoint_attention = use_keypoint_attention
        self.num_input_features = num_input_features
        self.bn_momentum = bn_momentum

        if self.use_heatmaps == 'part_segm':
            self.use_keypoint_attention = True

        if backbone.startswith('hrnet'):
            self.keypoint_deconv_layers = self._make_conv_layer(
                num_deconv_layers,
                num_deconv_filters,
                (3, ) * num_deconv_layers,
            )
            self.num_input_features = num_input_features
            self.smpl_deconv_layers = self._make_conv_layer(
                num_deconv_layers,
                num_deconv_filters,
                (3, ) * num_deconv_layers,
            )
        else:
            # part branch that estimates 2d keypoints
            conv_fn = self._make_deconv_layer
            self.keypoint_deconv_layers = conv_fn(
                num_deconv_layers,
                num_deconv_filters,
                num_deconv_kernels,
            )
            # reset inplanes to 2048 -> final resnet layer
            self.num_input_features = num_input_features
            self.smpl_deconv_layers = conv_fn(
                num_deconv_layers,
                num_deconv_filters,
                num_deconv_kernels,
            )

        pose_mlp_inp_dim = num_deconv_filters[-1]
        smpl_final_dim = num_features_smpl
        shape_mlp_inp_dim = num_joints * smpl_final_dim

        self.keypoint_final_layer = nn.Conv2d(
            in_channels=num_deconv_filters[-1],
            # one extra channel for the background in part segmentation modes
            out_channels=(num_joints +
                          1 if self.use_heatmaps in ('part_segm',
                                                     'part_segm_pool') else
                          num_joints),
            kernel_size=final_conv_kernel,
            stride=1,
            padding=1 if final_conv_kernel == 3 else 0,
        )
        self.smpl_final_layer = nn.Conv2d(
            in_channels=num_deconv_filters[-1],
            out_channels=smpl_final_dim,
            kernel_size=final_conv_kernel,
            stride=1,
            padding=1 if final_conv_kernel == 3 else 0,
        )

        # temperature for softargmax function
        self.register_buffer('temperature', torch.tensor(softmax_temp))

        mean_params = np.load(smpl_mean_params)
        init_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
        init_shape = torch.from_numpy(
            mean_params['shape'][:].astype('float32')).unsqueeze(0)
        init_cam = torch.from_numpy(mean_params['cam']).unsqueeze(0)
        self.register_buffer('init_pose', init_pose)
        self.register_buffer('init_shape', init_shape)
        self.register_buffer('init_cam', init_cam)

        self.pose_mlp_inp_dim = pose_mlp_inp_dim
        self.shape_mlp_inp_dim = shape_mlp_inp_dim

        self.shape_mlp = self._get_shape_mlp(output_size=10)
        self.cam_mlp = self._get_shape_mlp(output_size=num_camera_params)
        self.pose_mlp = self._get_pose_mlp(num_joints=num_joints,
                                           output_size=6)

        self.keypoint_attention = KeypointAttention(
            use_conv=use_postconv_keypoint_attention,
            in_channels=(self.pose_mlp_inp_dim, smpl_final_dim),
            out_channels=(self.pose_mlp_inp_dim, smpl_final_dim),
            act=keypoint_attention_act,
            use_scale=use_scale_keypoint_attention,
        )

    def _get_shape_mlp(self, output_size):
        """MLP layers for shape regression."""
        if self.shape_mlp_num_layers == 1:
            return nn.Linear(self.shape_mlp_inp_dim, output_size)

        module_list = []
        for i in range(self.shape_mlp_num_layers):
            if i == 0:
                module_list.append(
                    nn.Linear(self.shape_mlp_inp_dim,
                              self.shape_mlp_hidden_size))
            elif i == self.shape_mlp_num_layers - 1:
                module_list.append(
                    nn.Linear(self.shape_mlp_hidden_size, output_size))
            else:
                module_list.append(
                    nn.Linear(self.shape_mlp_hidden_size,
                              self.shape_mlp_hidden_size))
        return nn.Sequential(*module_list)

    def _get_pose_mlp(self, num_joints, output_size):
        """MLP layers for pose regression."""
        if self.pose_mlp_num_layers == 1:
            return LocallyConnected2d(
                in_channels=self.pose_mlp_inp_dim,
                out_channels=output_size,
                output_size=[num_joints, 1],
                kernel_size=1,
                stride=1,
            )

        module_list = []
        for i in range(self.pose_mlp_num_layers):
            if i == 0:
                module_list.append(
                    LocallyConnected2d(
                        in_channels=self.pose_mlp_inp_dim,
                        out_channels=self.pose_mlp_hidden_size,
                        output_size=[num_joints, 1],
                        kernel_size=1,
                        stride=1,
                    ))
            elif i == self.pose_mlp_num_layers - 1:
                module_list.append(
                    LocallyConnected2d(
                        in_channels=self.pose_mlp_hidden_size,
                        out_channels=output_size,
                        output_size=[num_joints, 1],
                        kernel_size=1,
                        stride=1,
                    ))
            else:
                module_list.append(
                    LocallyConnected2d(
                        in_channels=self.pose_mlp_hidden_size,
                        out_channels=self.pose_mlp_hidden_size,
                        output_size=[num_joints, 1],
                        kernel_size=1,
                        stride=1,
                    ))
        return nn.Sequential(*module_list)

    def _get_deconv_cfg(self, deconv_kernel):
        """Get deconv padding and output padding according to kernel size."""
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        else:
            raise ValueError(
                f'Unsupported deconv kernel size: {deconv_kernel}')
        return deconv_kernel, padding, output_padding

    def _make_conv_layer(self, num_layers, num_filters, num_kernels):
        """Make convolution layers."""
        assert num_layers == len(num_filters), \
            'ERROR: num_layers is different from len(num_filters)'
        assert num_layers == len(num_kernels), \
            'ERROR: num_layers is different from len(num_kernels)'
        layers = []
        for i in range(num_layers):
            kernel, padding, output_padding = \
                self._get_deconv_cfg(num_kernels[i])
            planes = num_filters[i]
            layers.append(
                nn.Conv2d(in_channels=self.num_input_features,
                          out_channels=planes,
                          kernel_size=kernel,
                          stride=1,
                          padding=padding,
                          bias=self.deconv_with_bias))
            layers.append(nn.BatchNorm2d(planes, momentum=self.bn_momentum))
            layers.append(nn.ReLU(inplace=True))
            self.num_input_features = planes
        return nn.Sequential(*layers)

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        """Make deconvolution layers."""
        assert num_layers == len(num_filters), \
            'ERROR: num_deconv_layers differs from len(num_deconv_filters)'
        assert num_layers == len(num_kernels), \
            'ERROR: num_deconv_layers differs from len(num_deconv_kernels)'
        layers = []
        for i in range(num_layers):
            kernel, padding, output_padding = \
                self._get_deconv_cfg(num_kernels[i])
            planes = num_filters[i]
            layers.append(
                nn.ConvTranspose2d(in_channels=self.num_input_features,
                                   out_channels=planes,
                                   kernel_size=kernel,
                                   stride=2,
                                   padding=padding,
                                   output_padding=output_padding,
                                   bias=self.deconv_with_bias))
            layers.append(nn.BatchNorm2d(planes, momentum=self.bn_momentum))
            layers.append(nn.ReLU(inplace=True))
            # if self.use_self_attention:
            #     layers.append(SelfAttention(planes))
            self.num_input_features = planes
        return nn.Sequential(*layers)
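
    # Size check (illustrative): with the default kernel 4, stride 2,
    # padding 1 and output_padding 0, each ConvTranspose2d doubles the
    # spatial size:
    #   H_out = (H_in - 1) * 2 - 2 * 1 + 4 + 0 = 2 * H_in
    # so three deconv layers upsample a 7x7 ResNet feature map to 56x56.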

    def forward(self, features):
        batch_size = features.shape[0]

        init_pose = self.init_pose.expand(batch_size, -1)  # N, Jx6
        init_shape = self.init_shape.expand(batch_size, -1)
        init_cam = self.init_cam.expand(batch_size, -1)

        output = {}

        part_feats = self._get_2d_branch_feats(features)
        part_attention = self._get_part_attention_map(part_feats, output)
        smpl_feats = self._get_3d_smpl_feats(features, part_feats)
        point_local_feat, cam_shape_feats = self._get_local_feats(
            smpl_feats, part_attention, output)
        pred_pose, pred_shape, pred_cam = self._get_final_preds(
            point_local_feat, cam_shape_feats, init_pose, init_shape,
            init_cam)

        pred_rotmat = rot6d_to_rotmat(pred_pose).reshape(batch_size, 24, 3, 3)

        output.update({
            'pred_pose': pred_rotmat,
            'pred_cam': pred_cam,
            'pred_shape': pred_shape,
        })
        return output

    def _get_local_feats(self, smpl_feats, part_attention, output):
        """Get per-joint local features and cam/shape features via a 1x1
        conv, pooled either with keypoint attention or by sampling at the
        predicted 2D keypoints."""
        cam_shape_feats = self.smpl_final_layer(smpl_feats)
        if self.use_keypoint_attention:
            point_local_feat = self.keypoint_attention(smpl_feats,
                                                       part_attention)
            cam_shape_feats = self.keypoint_attention(cam_shape_feats,
                                                      part_attention)
        else:
            point_local_feat = interpolate(smpl_feats, output['pred_kp2d'])
            cam_shape_feats = interpolate(cam_shape_feats,
                                          output['pred_kp2d'])
        return point_local_feat, cam_shape_feats

    def _get_2d_branch_feats(self, features):
        """Get part features from backbone features."""
        part_feats = self.keypoint_deconv_layers(features)
        return part_feats

    def _get_3d_smpl_feats(self, features, part_feats):
        """Get SMPL feature maps from backbone features."""
        smpl_feats = self.smpl_deconv_layers(features)
        return smpl_feats

    def _get_part_attention_map(self, part_feats, output):
        """Get the attention map from the part feature map."""
        heatmaps = self.keypoint_final_layer(part_feats)
        if self.use_heatmaps == 'part_segm':
            output['pred_segm_mask'] = heatmaps
            # remove the background channel
            heatmaps = heatmaps[:, 1:, :, :]
        else:
            pred_kp2d, _ = softargmax2d(heatmaps, self.temperature)
            output['pred_kp2d'] = pred_kp2d
            output['pred_heatmaps_2d'] = heatmaps
        return heatmaps

    def _get_final_preds(self, pose_feats, cam_shape_feats, init_pose,
                         init_shape, init_cam):
        """Get final predictions."""
        return self._pare_get_final_preds(pose_feats, cam_shape_feats,
                                          init_pose, init_shape, init_cam)

    def _pare_get_final_preds(self, pose_feats, cam_shape_feats, init_pose,
                              init_shape, init_cam):
        """Get final predictions."""
        pose_feats = pose_feats.unsqueeze(-1)  # [N, C, J, 1]

        if init_pose.shape[-1] == 6:
            # init_pose comes from a previous iteration
            init_pose = init_pose.transpose(2, 1).unsqueeze(-1)
        else:
            # init_pose comes from the mean pose
            init_pose = init_pose.reshape(init_pose.shape[0], 6,
                                          -1).unsqueeze(-1)

        shape_feats = cam_shape_feats
        shape_feats = torch.flatten(shape_feats, start_dim=1)

        pred_pose = self.pose_mlp(pose_feats)
        pred_cam = self.cam_mlp(shape_feats)
        pred_shape = self.shape_mlp(shape_feats)

        pred_pose = pred_pose.squeeze(-1).transpose(2, 1)  # N, J, 6
        return pred_pose, pred_shape, pred_cam
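
# End-to-end sketch (illustrative; 'data/smpl_mean_params.npz' is a
# placeholder path to an .npz file with 'pose', 'shape' and 'cam' keys):
#   head = PareHead(num_input_features=480,
#                   smpl_mean_params='data/smpl_mean_params.npz')
#   feats = torch.randn(2, 480, 56, 56)  # backbone feature map
#   out = head(feats)
#   out['pred_pose'].shape   # (2, 24, 3, 3) rotation matrices
#   out['pred_shape'].shape  # (2, 10) SMPL betas
#   out['pred_cam'].shape    # (2, 3) weak-perspective camera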