"""This script is modified from [PARE](https://github.com/ mkocabas/PARE/tree/master/pare/models/layers). Original license please see docs/additional_licenses.md. """ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from mmcv.runner.base_module import BaseModule from torch.nn.modules.utils import _pair from detrsmpl.utils.geometry import rot6d_to_rotmat class LocallyConnected2d(nn.Module): """Locally Connected Layer. Args: in_channels (int): the in channel of the features. out_channels (int): the out channel of the features. output_size (List[int]): the output size of the features. kernel_size (int): the size of the kernel. stride (int): the stride of the kernel. Returns: attended_features (torch.Tensor): attended feature maps """ def __init__(self, in_channels, out_channels, output_size, kernel_size, stride, bias=False): super(LocallyConnected2d, self).__init__() output_size = _pair(output_size) self.weight = nn.Parameter( torch.randn(1, out_channels, in_channels, output_size[0], output_size[1], kernel_size**2), requires_grad=True, ) if bias: self.bias = nn.Parameter(torch.randn(1, out_channels, output_size[0], output_size[1]), requires_grad=True) else: self.register_parameter('bias', None) self.kernel_size = _pair(kernel_size) self.stride = _pair(stride) def forward(self, x): _, c, h, w = x.size() kh, kw = self.kernel_size dh, dw = self.stride x = x.unfold(2, kh, dh).unfold(3, kw, dw) x = x.contiguous().view(*x.size()[:-2], -1) # Sum in in_channel and kernel_size dims out = (x.unsqueeze(1) * self.weight).sum([2, -1]) if self.bias is not None: out += self.bias return out class KeypointAttention(nn.Module): """Keypoint Attention Layer. Args: use_conv (bool): whether to use conv for the attended feature map. Default: False in_channels (List[int]): the in channel of shape_cam features and pose features. Default: (256, 64) out_channels (List[int]): the out channel of shape_cam features and pose features. 


class KeypointAttention(nn.Module):
    """Keypoint Attention Layer.

    Args:
        use_conv (bool): whether to use conv for the attended feature map.
            Default: False
        in_channels (List[int]): the in channels of the pose features and
            the shape_cam features. Default: (256, 64)
        out_channels (List[int]): the out channels of the pose features and
            the shape_cam features. Default: (256, 64)
    Returns:
        attended_features (torch.Tensor): attended feature maps
    """
    def __init__(self,
                 use_conv=False,
                 in_channels=(256, 64),
                 out_channels=(256, 64),
                 act='softmax',
                 use_scale=False):
        super(KeypointAttention, self).__init__()
        self.use_conv = use_conv
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.act = act
        self.use_scale = use_scale
        if use_conv:
            self.conv1x1_pose = nn.Conv1d(in_channels[0],
                                          out_channels[0],
                                          kernel_size=1)
            self.conv1x1_shape_cam = nn.Conv1d(in_channels[1],
                                               out_channels[1],
                                               kernel_size=1)

    def forward(self, features, heatmaps):
        batch_size, num_joints, height, width = heatmaps.shape

        if self.use_scale:
            scale = 1.0 / np.sqrt(height * width)
            heatmaps = heatmaps * scale

        if self.act == 'softmax':
            normalized_heatmap = F.softmax(heatmaps.reshape(
                batch_size, num_joints, -1), dim=-1)
        elif self.act == 'sigmoid':
            normalized_heatmap = torch.sigmoid(
                heatmaps.reshape(batch_size, num_joints, -1))
        features = features.reshape(batch_size, -1, height * width)

        attended_features = torch.matmul(normalized_heatmap,
                                         features.transpose(2, 1))
        attended_features = attended_features.transpose(2, 1)

        if self.use_conv:
            if attended_features.shape[1] == self.in_channels[0]:
                attended_features = self.conv1x1_pose(attended_features)
            else:
                attended_features = self.conv1x1_shape_cam(attended_features)

        return attended_features
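
# Example (a sketch with assumed shapes, not part of the original file):
# keypoint attention soft-pools feature maps with per-joint heatmaps,
#     attended[b, c, j] = sum_p softmax(heatmaps[b, j])[p] * features[b, c, p]
# where p runs over the H * W pixels:
#     >>> attn = KeypointAttention(in_channels=(256, 64),
#     ...                          out_channels=(256, 64))
#     >>> feats = torch.randn(2, 256, 56, 56)
#     >>> heatmaps = torch.randn(2, 24, 56, 56)
#     >>> attn(feats, heatmaps).shape
#     torch.Size([2, 256, 24])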


def interpolate(feat, uv):
    """
    Args:
        feat (torch.Tensor): [B, C, H, W] image features
        uv (torch.Tensor): [B, 2, N] uv coordinates in the image plane,
            range [-1, 1]
    Returns:
        samples[:, :, :, 0] (torch.Tensor): [B, C, N] image features at
            the uv coordinates
    """
    if uv.shape[-1] != 2:
        uv = uv.transpose(1, 2)  # [B, N, 2]
    uv = uv.unsqueeze(2)  # [B, N, 1, 2]
    # NOTE: for newer PyTorch, it seems that training results are degraded
    # due to an implementation diff in F.grid_sample; for old versions,
    # simply remove the align_corners argument.
    # Compare (major, minor) so that e.g. torch 2.x also takes the
    # align_corners branch.
    major, minor = (int(v) for v in torch.__version__.split('.')[:2])
    if (major, minor) < (1, 4):
        samples = torch.nn.functional.grid_sample(feat, uv)  # [B, C, N, 1]
    else:
        samples = torch.nn.functional.grid_sample(
            feat, uv, align_corners=True)  # [B, C, N, 1]
    return samples[:, :, :, 0]  # [B, C, N]


def _softmax(tensor, temperature, dim=-1):
    return F.softmax(tensor * temperature, dim=dim)


def softargmax2d(
    heatmaps,
    temperature=None,
    normalize_keypoints=True,
):
    """Softargmax layer for heatmaps."""
    dtype, device = heatmaps.dtype, heatmaps.device
    if temperature is None:
        temperature = torch.tensor(1.0, dtype=dtype, device=device)
    batch_size, num_channels, height, width = heatmaps.shape
    x = torch.arange(0, width, device=device, dtype=dtype).reshape(
        1, 1, 1, width).expand(batch_size, -1, height, -1)
    y = torch.arange(0, height, device=device, dtype=dtype).reshape(
        1, 1, height, 1).expand(batch_size, -1, -1, width)
    # Should be Bx2xHxW
    points = torch.cat([x, y], dim=1)
    normalized_heatmap = _softmax(heatmaps.reshape(batch_size, num_channels,
                                                   -1),
                                  temperature=temperature.reshape(1, -1, 1),
                                  dim=-1)

    # Should be BxJx2
    keypoints = (
        normalized_heatmap.reshape(batch_size, -1, 1, height * width) *
        points.reshape(batch_size, 1, 2, -1)).sum(dim=-1)

    if normalize_keypoints:
        # Normalize keypoints to [-1, 1]
        keypoints[:, :, 0] = (keypoints[:, :, 0] / (width - 1) * 2 - 1)
        keypoints[:, :, 1] = (keypoints[:, :, 1] / (height - 1) * 2 - 1)

    return keypoints, normalized_heatmap.reshape(batch_size, -1, height,
                                                 width)
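
# Worked example (assumed values, not part of the original file):
# softargmax2d returns the softmax-weighted average of pixel coordinates,
# E[x] = sum_i softmax(h)_i * x_i, a differentiable surrogate for argmax:
#     >>> heatmaps = torch.zeros(1, 1, 5, 5)
#     >>> heatmaps[0, 0, 1, 3] = 10.0  # sharp peak at (x=3, y=1)
#     >>> kp, _ = softargmax2d(heatmaps)
#     >>> kp  # approximately (3 / 4 * 2 - 1, 1 / 4 * 2 - 1)
#     tensor([[[ 0.4994, -0.4994]]])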


class PareHead(BaseModule):
    def __init__(
        self,
        num_joints=24,
        num_input_features=480,
        softmax_temp=1.0,
        num_deconv_layers=3,
        num_deconv_filters=(256, 256, 256),
        num_deconv_kernels=(4, 4, 4),
        num_camera_params=3,
        num_features_smpl=64,
        final_conv_kernel=1,
        pose_mlp_num_layers=1,
        shape_mlp_num_layers=1,
        pose_mlp_hidden_size=256,
        shape_mlp_hidden_size=256,
        bn_momentum=0.1,
        use_heatmaps='part_segm',
        use_keypoint_attention=False,
        use_postconv_keypoint_attention=False,
        keypoint_attention_act='softmax',  # softmax, sigmoid
        use_scale_keypoint_attention=False,
        backbone='hrnet_w32-conv',  # hrnet, resnet
        smpl_mean_params=None,
        deconv_with_bias=False,
    ):
        """PARE parameters regressor head. This class is modified from
        [PARE](https://github.com/mkocabas/PARE/blob/master/pare/models/head/pare_head.py).
        Original license please see docs/additional_licenses.md.

        Args:
            num_joints (int): Number of joints, should be 24 for smpl.
            num_input_features (int): Number of input featuremap channels.
            softmax_temp (float): Softmax temperature.
            num_deconv_layers (int): Number of deconvolution layers.
            num_deconv_filters (List[int]): Number of filters for each
                deconvolution layer,
                len(num_deconv_filters) == num_deconv_layers.
            num_deconv_kernels (List[int]): Kernel size for each
                deconvolution layer,
                len(num_deconv_kernels) == num_deconv_layers.
            num_camera_params (int): Number of predicted camera parameter
                dimension.
            num_features_smpl (int): Number of feature map channels for the
                smpl branch.
            final_conv_kernel (int): Kernel size for the final convolution
                layers.
            pose_mlp_num_layers (int): Number of MLP layers for pose
                parameter regression.
            shape_mlp_num_layers (int): Number of MLP layers for shape
                parameter regression.
            pose_mlp_hidden_size (int): Hidden size for pose MLP layers.
            shape_mlp_hidden_size (int): Hidden size for shape MLP layers.
            bn_momentum (float): Momentum for batch normalization.
            use_heatmaps (str): Types of heat maps to use.
            use_keypoint_attention (bool): Whether to use attention based on
                heat maps.
            use_postconv_keypoint_attention (bool): Whether to apply a 1x1
                conv to the attended features.
            keypoint_attention_act (str): Types of activation function for
                attention layers.
            use_scale_keypoint_attention (bool): Whether to scale the
                attention according to the size of the attention map.
            deconv_with_bias (bool): Whether to use bias in the
                deconvolution layers.
            backbone (str): Types of the backbone.
            smpl_mean_params (str): File name of the mean SMPL parameters.
        """
        super(PareHead, self).__init__()
        self.backbone = backbone
        self.num_joints = num_joints
        self.deconv_with_bias = deconv_with_bias
        self.use_heatmaps = use_heatmaps
        self.pose_mlp_num_layers = pose_mlp_num_layers
        self.shape_mlp_num_layers = shape_mlp_num_layers
        self.pose_mlp_hidden_size = pose_mlp_hidden_size
        self.shape_mlp_hidden_size = shape_mlp_hidden_size
        self.use_keypoint_attention = use_keypoint_attention

        self.num_input_features = num_input_features
        self.bn_momentum = bn_momentum
        if self.use_heatmaps == 'part_segm':
            self.use_keypoint_attention = True

        if backbone.startswith('hrnet'):
            self.keypoint_deconv_layers = self._make_conv_layer(
                num_deconv_layers,
                num_deconv_filters,
                (3, ) * num_deconv_layers,
            )
            self.num_input_features = num_input_features
            self.smpl_deconv_layers = self._make_conv_layer(
                num_deconv_layers,
                num_deconv_filters,
                (3, ) * num_deconv_layers,
            )
        else:
            # part branch that estimates 2d keypoints
            conv_fn = self._make_deconv_layer
            self.keypoint_deconv_layers = conv_fn(
                num_deconv_layers,
                num_deconv_filters,
                num_deconv_kernels,
            )
            # reset inplanes to 2048 -> final resnet layer
            self.num_input_features = num_input_features
            self.smpl_deconv_layers = conv_fn(
                num_deconv_layers,
                num_deconv_filters,
                num_deconv_kernels,
            )

        pose_mlp_inp_dim = num_deconv_filters[-1]
        smpl_final_dim = num_features_smpl
        shape_mlp_inp_dim = num_joints * smpl_final_dim

        self.keypoint_final_layer = nn.Conv2d(
            in_channels=num_deconv_filters[-1],
            out_channels=num_joints + 1 if self.use_heatmaps in
            ('part_segm', 'part_segm_pool') else num_joints,
            kernel_size=final_conv_kernel,
            stride=1,
            padding=1 if final_conv_kernel == 3 else 0,
        )

        self.smpl_final_layer = nn.Conv2d(
            in_channels=num_deconv_filters[-1],
            out_channels=smpl_final_dim,
            kernel_size=final_conv_kernel,
            stride=1,
            padding=1 if final_conv_kernel == 3 else 0,
        )

        # temperature for softargmax function
        self.register_buffer('temperature', torch.tensor(softmax_temp))

        mean_params = np.load(smpl_mean_params)
        init_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
        init_shape = torch.from_numpy(
            mean_params['shape'][:].astype('float32')).unsqueeze(0)
        init_cam = torch.from_numpy(mean_params['cam']).unsqueeze(0)
        self.register_buffer('init_pose', init_pose)
        self.register_buffer('init_shape', init_shape)
        self.register_buffer('init_cam', init_cam)

        self.pose_mlp_inp_dim = pose_mlp_inp_dim
        self.shape_mlp_inp_dim = shape_mlp_inp_dim

        self.shape_mlp = self._get_shape_mlp(output_size=10)
        self.cam_mlp = self._get_shape_mlp(output_size=num_camera_params)

        self.pose_mlp = self._get_pose_mlp(num_joints=num_joints,
                                           output_size=6)

        self.keypoint_attention = KeypointAttention(
            use_conv=use_postconv_keypoint_attention,
            in_channels=(self.pose_mlp_inp_dim, smpl_final_dim),
            out_channels=(self.pose_mlp_inp_dim, smpl_final_dim),
            act=keypoint_attention_act,
            use_scale=use_scale_keypoint_attention,
        )

    def _get_shape_mlp(self, output_size):
        """mlp layers for shape regression."""
        if self.shape_mlp_num_layers == 1:
            return nn.Linear(self.shape_mlp_inp_dim, output_size)

        module_list = []
        for i in range(self.shape_mlp_num_layers):
            if i == 0:
                module_list.append(
                    nn.Linear(self.shape_mlp_inp_dim,
                              self.shape_mlp_hidden_size))
            elif i == self.shape_mlp_num_layers - 1:
                module_list.append(
                    nn.Linear(self.shape_mlp_hidden_size, output_size))
            else:
                module_list.append(
                    nn.Linear(self.shape_mlp_hidden_size,
                              self.shape_mlp_hidden_size))
        return nn.Sequential(*module_list)
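
    # Worked example (assumed defaults, not part of the original file): with
    # num_joints=24 and num_features_smpl=64, shape_mlp_inp_dim is
    # 24 * 64 = 1536, so _get_shape_mlp(output_size=10) with
    # shape_mlp_num_layers=1 is a single nn.Linear(1536, 10) for the betas,
    # and with shape_mlp_num_layers=3 it becomes
    #     nn.Sequential(nn.Linear(1536, 256), nn.Linear(256, 256),
    #                   nn.Linear(256, 10))
    # (note that no activations are inserted between the linear layers).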

    def _get_pose_mlp(self, num_joints, output_size):
        """mlp layers for pose regression."""
        if self.pose_mlp_num_layers == 1:
            return LocallyConnected2d(
                in_channels=self.pose_mlp_inp_dim,
                out_channels=output_size,
                output_size=[num_joints, 1],
                kernel_size=1,
                stride=1,
            )

        module_list = []
        for i in range(self.pose_mlp_num_layers):
            if i == 0:
                module_list.append(
                    LocallyConnected2d(
                        in_channels=self.pose_mlp_inp_dim,
                        out_channels=self.pose_mlp_hidden_size,
                        output_size=[num_joints, 1],
                        kernel_size=1,
                        stride=1,
                    ))
            elif i == self.pose_mlp_num_layers - 1:
                module_list.append(
                    LocallyConnected2d(
                        in_channels=self.pose_mlp_hidden_size,
                        out_channels=output_size,
                        output_size=[num_joints, 1],
                        kernel_size=1,
                        stride=1,
                    ))
            else:
                module_list.append(
                    LocallyConnected2d(
                        in_channels=self.pose_mlp_hidden_size,
                        out_channels=self.pose_mlp_hidden_size,
                        output_size=[num_joints, 1],
                        kernel_size=1,
                        stride=1,
                    ))
        return nn.Sequential(*module_list)

    def _get_deconv_cfg(self, deconv_kernel):
        """get deconv padding and output padding according to kernel size."""
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        else:
            raise ValueError(
                f'Unsupported deconv kernel size: {deconv_kernel}')

        return deconv_kernel, padding, output_padding

    def _make_conv_layer(self, num_layers, num_filters, num_kernels):
        """make convolution layers."""
        assert num_layers == len(num_filters), \
            'ERROR: num_conv_layers is different from len(num_conv_filters)'
        assert num_layers == len(num_kernels), \
            'ERROR: num_conv_layers is different from len(num_conv_kernels)'
        layers = []
        for i in range(num_layers):
            kernel, padding, output_padding = \
                self._get_deconv_cfg(num_kernels[i])

            planes = num_filters[i]
            layers.append(
                nn.Conv2d(in_channels=self.num_input_features,
                          out_channels=planes,
                          kernel_size=kernel,
                          stride=1,
                          padding=padding,
                          bias=self.deconv_with_bias))
            layers.append(nn.BatchNorm2d(planes, momentum=self.bn_momentum))
            layers.append(nn.ReLU(inplace=True))
            self.num_input_features = planes

        return nn.Sequential(*layers)

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        """make deconvolution layers."""
        assert num_layers == len(num_filters), \
            'ERROR: num_deconv_layers is different from ' \
            'len(num_deconv_filters)'
        assert num_layers == len(num_kernels), \
            'ERROR: num_deconv_layers is different from ' \
            'len(num_deconv_kernels)'

        layers = []
        for i in range(num_layers):
            kernel, padding, output_padding = \
                self._get_deconv_cfg(num_kernels[i])

            planes = num_filters[i]
            layers.append(
                nn.ConvTranspose2d(in_channels=self.num_input_features,
                                   out_channels=planes,
                                   kernel_size=kernel,
                                   stride=2,
                                   padding=padding,
                                   output_padding=output_padding,
                                   bias=self.deconv_with_bias))
            layers.append(nn.BatchNorm2d(planes, momentum=self.bn_momentum))
            layers.append(nn.ReLU(inplace=True))
            # if self.use_self_attention:
            #     layers.append(SelfAttention(planes))
            self.num_input_features = planes

        return nn.Sequential(*layers)
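
    # Worked example (not part of the original file): nn.ConvTranspose2d
    # produces
    #     out = (in - 1) * stride - 2 * padding + kernel_size
    #           + output_padding
    # so each deconv stage above (kernel=4, stride=2, padding=1,
    # output_padding=0) doubles the resolution:
    #     out = (in - 1) * 2 - 2 + 4 + 0 = 2 * in
    # e.g. three stages upsample a 7x7 ResNet feature map to 56x56.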

    def forward(self, features):
        batch_size = features.shape[0]

        init_pose = self.init_pose.expand(batch_size, -1)  # N, Jx6
        init_shape = self.init_shape.expand(batch_size, -1)
        init_cam = self.init_cam.expand(batch_size, -1)

        output = {}

        part_feats = self._get_2d_branch_feats(features)

        part_attention = self._get_part_attention_map(part_feats, output)

        smpl_feats = self._get_3d_smpl_feats(features, part_feats)

        point_local_feat, cam_shape_feats = self._get_local_feats(
            smpl_feats, part_attention, output)

        pred_pose, pred_shape, pred_cam = self._get_final_preds(
            point_local_feat, cam_shape_feats, init_pose, init_shape,
            init_cam)

        pred_rotmat = rot6d_to_rotmat(pred_pose).reshape(batch_size, 24, 3, 3)

        output.update({
            'pred_pose': pred_rotmat,
            'pred_cam': pred_cam,
            'pred_shape': pred_shape,
        })
        return output

    def _get_local_feats(self, smpl_feats, part_attention, output):
        """get keypoints and camera features from backbone features."""
        cam_shape_feats = self.smpl_final_layer(smpl_feats)  # 1x1 conv

        if self.use_keypoint_attention:
            point_local_feat = self.keypoint_attention(
                smpl_feats, part_attention)
            cam_shape_feats = self.keypoint_attention(
                cam_shape_feats, part_attention)
        else:
            point_local_feat = interpolate(smpl_feats, output['pred_kp2d'])
            cam_shape_feats = interpolate(cam_shape_feats,
                                          output['pred_kp2d'])
        return point_local_feat, cam_shape_feats

    def _get_2d_branch_feats(self, features):
        """get part features from backbone features."""
        part_feats = self.keypoint_deconv_layers(features)
        return part_feats

    def _get_3d_smpl_feats(self, features, part_feats):
        """get smpl feature maps from backbone features."""
        smpl_feats = self.smpl_deconv_layers(features)
        return smpl_feats

    def _get_part_attention_map(self, part_feats, output):
        """get attention map from part feature map."""
        heatmaps = self.keypoint_final_layer(part_feats)

        if self.use_heatmaps == 'part_segm':
            output['pred_segm_mask'] = heatmaps
            # remove the background channel
            heatmaps = heatmaps[:, 1:, :, :]
        else:
            pred_kp2d, _ = softargmax2d(heatmaps, self.temperature)
            output['pred_kp2d'] = pred_kp2d
            output['pred_heatmaps_2d'] = heatmaps
        return heatmaps

    def _get_final_preds(self, pose_feats, cam_shape_feats, init_pose,
                         init_shape, init_cam):
        """get final preds."""
        return self._pare_get_final_preds(pose_feats, cam_shape_feats,
                                          init_pose, init_shape, init_cam)

    def _pare_get_final_preds(self, pose_feats, cam_shape_feats, init_pose,
                              init_shape, init_cam):
        """get final preds."""
        pose_feats = pose_feats.unsqueeze(-1)  # N, C, J, 1

        if init_pose.shape[-1] == 6:
            # This means init_pose comes from a previous iteration
            init_pose = init_pose.transpose(2, 1).unsqueeze(-1)
        else:
            # This means init pose comes from mean pose
            init_pose = init_pose.reshape(init_pose.shape[0], 6,
                                          -1).unsqueeze(-1)

        shape_feats = cam_shape_feats
        shape_feats = torch.flatten(shape_feats, start_dim=1)

        pred_pose = self.pose_mlp(pose_feats)
        pred_cam = self.cam_mlp(shape_feats)
        pred_shape = self.shape_mlp(shape_feats)

        pred_pose = pred_pose.squeeze(-1).transpose(2, 1)  # N, J, 6

        return pred_pose, pred_shape, pred_cam
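
# Usage sketch (assumed shapes and file path, not part of the original
# file): with an HRNet-W32 backbone the head consumes a (N, 480, 56, 56)
# feature map and regresses per-joint 6D rotations, SMPL betas and a
# weak-perspective camera:
#     >>> head = PareHead(num_input_features=480,
#     ...                 backbone='hrnet_w32-conv',
#     ...                 smpl_mean_params='data/smpl_mean_params.npz')
#     >>> out = head(torch.randn(2, 480, 56, 56))
#     >>> out['pred_pose'].shape  # rotation matrices from rot6d_to_rotmat
#     torch.Size([2, 24, 3, 3])
#     >>> out['pred_shape'].shape, out['pred_cam'].shape
#     (torch.Size([2, 10]), torch.Size([2, 3]))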