import os import pickle from abc import abstractmethod from typing import List, Optional import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import build_activation_layer, initialize from mmcv.runner.base_module import BaseModule from detrsmpl.utils.geometry import rot6d_to_rotmat class IterativeRegression(nn.Module): """Regressor for ExPose Head.""" def __init__(self, module, mean_param, num_stages=1, append_params=True, learn_mean=False, detach_mean=False, dim=1, **kwargs): super(IterativeRegression, self).__init__() self.module = module self._num_stages = num_stages self.dim = dim if learn_mean: self.register_parameter( 'mean_param', nn.Parameter(mean_param, requires_grad=True)) else: self.register_buffer('mean_param', mean_param) self.append_params = append_params self.detach_mean = detach_mean def get_mean(self): """Get the initial mean param.""" return self.mean_param.clone() @property def num_stages(self): return self._num_stages def forward(self, features: torch.Tensor, cond: Optional[torch.Tensor] = None): ''' Computes deltas on top of condition iteratively Parameters ---------- features: torch.Tensor Input features ''' batch_size = features.shape[0] expand_shape = [batch_size] + [-1] * len(features.shape[1:]) parameters = [] deltas = [] module_input = features if cond is None: cond = self.mean_param.expand(*expand_shape).clone() # Detach mean if self.detach_mean: cond = cond.detach() if self.append_params: assert features is not None, ( 'Features are none even though append_params is True') module_input = torch.cat([module_input, cond], dim=self.dim) deltas.append(self.module(module_input)) num_params = deltas[-1].shape[1] parameters.append(cond[:, :num_params].clone() + deltas[-1]) for stage_idx in range(1, self.num_stages): module_input = torch.cat([features, parameters[stage_idx - 1]], dim=-1) params_upd = self.module(module_input) deltas.append(params_upd) parameters.append(parameters[stage_idx - 1] + params_upd) return parameters class MLP(nn.Module): """MLP Args: input_dim (int): Input dim of MLP. output_dim (int): Output dim of MLP. layers (List): Layer dims. activ_type (str): Activation layer type. dropout (float): Dropout. gain (float): Xavier init gain value. """ def __init__( self, input_dim: int, output_dim: int, layers: List[int] = [], activ_type: str = 'relu', dropout: float = 0.5, gain: float = 0.01, ): super(MLP, self).__init__() curr_input_dim = input_dim self.num_layers = len(layers) self.blocks = nn.ModuleList() for layer_idx, layer_dim in enumerate(layers): if activ_type == 'none': active = None else: active = build_activation_layer( cfg=dict(type=activ_type, inplace=True)) linear = nn.Linear(curr_input_dim, layer_dim, bias=True) curr_input_dim = layer_dim layer = [] layer.append(linear) if active is not None: layer.append(active) if dropout > 0.0: layer.append(nn.Dropout(dropout)) block = nn.Sequential(*layer) self.add_module('layer_{:03d}'.format(layer_idx), block) self.blocks.append(block) self.output_layer = nn.Linear(curr_input_dim, output_dim) initialize(self.output_layer, init_cfg=dict(type='Xavier', gain=gain, distribution='uniform')) def forward(self, module_input): curr_input = module_input for block in self.blocks: curr_input = block(curr_input) return self.output_layer(curr_input) class ContinuousRotReprDecoder: """ExPose Decoder Decode latent representation to rotation. Args: num_angles (int): Joint num. dtype: dtype. mean (torch.tensor): Mean value for params. """ def __init__(self, num_angles, dtype=torch.float32, mean=None): self.num_angles = num_angles self.dtype = dtype if isinstance(mean, dict): mean = mean.get('cont_rot_repr', None) if mean is None: mean = torch.tensor([1.0, 0.0, 0.0, 1.0, 0.0, 0.0], dtype=self.dtype).unsqueeze(dim=0).expand( self.num_angles, -1).contiguous().view(-1) if not torch.is_tensor(mean): mean = torch.tensor(mean) mean = mean.reshape(-1, 6) if mean.shape[0] < self.num_angles: mean = mean.repeat(self.num_angles // mean.shape[0] + 1, 1).contiguous() mean = mean[:self.num_angles] elif mean.shape[0] > self.num_angles: mean = mean[:self.num_angles] mean = mean.reshape(-1) self.mean = mean def get_mean(self): return self.mean.clone() def get_dim_size(self): return self.num_angles * 6 def __call__(self, module_input): batch_size = module_input.shape[0] reshaped_input = module_input.view(-1, 6) rot_mats = rot6d_to_rotmat(reshaped_input) # aa = rot6d_to_aa(reshaped_input) # return aa.view(batch_size,-1,3) return rot_mats.view(batch_size, -1, 3, 3) class ExPoseHead(BaseModule): """General Head for ExPose.""" def __init__(self, init_cfg=None): super().__init__(init_cfg) def load_regressor(self, input_feat_dim: int = 2048, param_mean: torch.Tensor = None, regressor_cfg: dict = None): """Build regressor for ExPose Head.""" param_dim = param_mean.numel() regressor = MLP(input_feat_dim + param_dim, param_dim, **regressor_cfg) self.regressor = IterativeRegression(regressor, param_mean, num_stages=3) def load_param_decoder(self, mean_poses_dict): """Build decoders for each pose.""" start = 0 mean_lst = [] self.pose_param_decoders = {} for pose_param in self.pose_param_conf: pose_name = pose_param['name'] num_angles = pose_param['num_angles'] if pose_param['use_mean']: pose_decoder = ContinuousRotReprDecoder( num_angles, dtype=torch.float32, mean=mean_poses_dict.get(pose_name, None)) else: pose_decoder = ContinuousRotReprDecoder(num_angles, dtype=torch.float32, mean=None) self.pose_param_decoders['{}_decoder'.format( pose_name)] = pose_decoder pose_dim = pose_decoder.get_dim_size() pose_mean = pose_decoder.get_mean() if pose_param['rotate_axis_x']: pose_mean[3] = -1 idxs = list(range(start, start + pose_dim)) idxs = torch.tensor(idxs, dtype=torch.long) self.register_buffer('{}_idxs'.format(pose_name), idxs) start += pose_dim mean_lst.append(pose_mean.view(-1)) return start, mean_lst def get_camera_param(self, camera_cfg): """Build camera param.""" camera_pos_scale = camera_cfg.get('pos_func') if camera_pos_scale == 'softplus': camera_scale_func = F.softplus elif camera_pos_scale == 'exp': camera_scale_func = torch.exp elif camera_pos_scale == 'none' or camera_pos_scale == 'None': def func(x): return x camera_scale_func = func mean_scale = camera_cfg.get('mean_scale', 0.9) if camera_pos_scale == 'softplus': mean_scale = np.log(np.exp(mean_scale) - 1) elif camera_pos_scale == 'exp': mean_scale = np.log(mean_scale) camera_mean = torch.tensor([mean_scale, 0.0, 0.0], dtype=torch.float32) camera_param_dim = 3 return camera_mean, camera_param_dim, camera_scale_func def flat_params_to_dict(self, param_tensor): """Turn param tensors to dict.""" smplx_dict = {} raw_dict = {} for pose_param in self.pose_param_conf: pose_name = pose_param['name'] pose_idxs = getattr(self, f'{pose_name}_idxs') decoder = self.pose_param_decoders[f'{pose_name}_decoder'] pose = torch.index_select(param_tensor, 1, pose_idxs) raw_dict[f'raw_{pose_name}'] = pose.clone() smplx_dict[pose_name] = decoder(pose) return smplx_dict, raw_dict def get_mean(self, name, batch_size): """Get mean value of params.""" mean_param = self.regressor.get_mean().view(-1) if name is None: return mean_param.reshape(1, -1).expand(batch_size, -1) idxs = getattr(self, f'{name}_idxs') return mean_param[idxs].reshape(1, -1).expand(batch_size, -1) def get_num_betas(self): return self.num_betas def get_num_expression_coeffs(self): return self.num_expression_coeffs @abstractmethod def forward(self, features): pass class ExPoseBodyHead(ExPoseHead): """Head for ExPose Body Model.""" def __init__(self, init_cfg=None, num_betas: int = 10, num_expression_coeffs: int = 10, mean_pose_path: str = '', shape_mean_path: str = '', pose_param_conf: list = None, input_feat_dim: int = 2048, regressor_cfg: dict = None, camera_cfg: dict = None): super().__init__(init_cfg) self.num_betas = num_betas self.num_expression_coeffs = num_expression_coeffs # poses self.pose_param_conf = pose_param_conf mean_poses_dict = {} if os.path.exists(mean_pose_path): with open(mean_pose_path, 'rb') as f: mean_poses_dict = pickle.load(f) start, mean_lst = self.load_param_decoder(mean_poses_dict) # shape if os.path.exists(shape_mean_path): shape_mean = torch.from_numpy( np.load(shape_mean_path, allow_pickle=True)).to(dtype=torch.float32).reshape( 1, -1)[:, :num_betas].reshape(-1) else: shape_mean = torch.zeros([num_betas], dtype=torch.float32) shape_idxs = list(range(start, start + num_betas)) self.register_buffer('shape_idxs', torch.tensor(shape_idxs, dtype=torch.long)) start += num_betas mean_lst.append(shape_mean.view(-1)) # expression expression_mean = torch.zeros([num_expression_coeffs], dtype=torch.float32) expression_idxs = list(range(start, start + num_expression_coeffs)) self.register_buffer('expression_idxs', torch.tensor(expression_idxs, dtype=torch.long)) start += num_expression_coeffs mean_lst.append(expression_mean.view(-1)) # camera mean, dim, scale_func = self.get_camera_param(camera_cfg) self.camera_scale_func = scale_func camera_idxs = list(range(start, start + dim)) self.register_buffer('camera_idxs', torch.tensor(camera_idxs, dtype=torch.long)) start += dim mean_lst.append(mean) param_mean = torch.cat(mean_lst).view(1, -1) self.load_regressor(input_feat_dim, param_mean, regressor_cfg) def forward(self, features): """Forward function of ExPose Body Head. Args: features (List[torch.tensor]) : Output of restnet. cond : Initial params. If none, use the mean params. """ body_parameters = self.regressor(features)[-1] params_dict, raw_dict = self.flat_params_to_dict(body_parameters) params_dict['betas'] = torch.index_select(body_parameters, 1, self.shape_idxs) params_dict['expression'] = torch.index_select(body_parameters, 1, self.expression_idxs) camera_params = torch.index_select(body_parameters, 1, self.camera_idxs) scale = camera_params[:, 0:1] translation = camera_params[:, 1:3] scale = self.camera_scale_func(scale) camera_params = torch.cat([scale, translation], dim=1) return { 'pred_param': params_dict, 'pred_cam': camera_params, 'pred_raw': raw_dict } class ExPoseHandHead(ExPoseHead): """Head for ExPose Hand Model.""" def __init__(self, init_cfg=None, num_betas: int = 10, mean_pose_path: str = '', pose_param_conf: list = None, input_feat_dim: int = 2048, regressor_cfg: dict = None, camera_cfg: dict = None): super().__init__(init_cfg) self.num_betas = num_betas # poses self.pose_param_conf = pose_param_conf mean_poses_dict = {} if os.path.exists(mean_pose_path): with open(mean_pose_path, 'rb') as f: mean_poses_dict = pickle.load(f) start, mean_lst = self.load_param_decoder(mean_poses_dict) shape_mean = torch.zeros([num_betas], dtype=torch.float32) shape_idxs = list(range(start, start + num_betas)) self.register_buffer('shape_idxs', torch.tensor(shape_idxs, dtype=torch.long)) start += num_betas mean_lst.append(shape_mean.view(-1)) # camera mean, dim, scale_func = self.get_camera_param(camera_cfg) self.camera_scale_func = scale_func camera_idxs = list(range(start, start + dim)) self.register_buffer('camera_idxs', torch.tensor(camera_idxs, dtype=torch.long)) start += dim mean_lst.append(mean) param_mean = torch.cat(mean_lst).view(1, -1) self.load_regressor(input_feat_dim, param_mean, regressor_cfg) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) def forward(self, features, cond=None): """Forward function of ExPose Hand Head. Args: features (List[torch.tensor]) : Output of restnet. cond : Initial params. If none, use the mean params. """ batch_size = features[-1].size(0) features = self.avgpool(features[-1]).view(batch_size, -1) hand_parameters = self.regressor(features, cond=cond)[-1] params_dict, raw_dict = self.flat_params_to_dict(hand_parameters) params_dict['betas'] = torch.index_select(hand_parameters, 1, self.shape_idxs) camera_params = torch.index_select(hand_parameters, 1, self.camera_idxs) scale = camera_params[:, 0:1] translation = camera_params[:, 1:3] scale = self.camera_scale_func(scale) camera_params = torch.cat([scale, translation], dim=1) return { 'pred_param': params_dict, 'pred_cam': camera_params, 'pred_raw': raw_dict } class ExPoseFaceHead(ExPoseHead): """Head for ExPose Face Model.""" def __init__(self, init_cfg=None, num_betas: int = 10, num_expression_coeffs: int = 10, pose_param_conf: list = None, mean_pose_path: str = '', input_feat_dim: int = 2048, regressor_cfg: dict = None, camera_cfg: dict = None): super().__init__(init_cfg) self.num_betas = num_betas self.num_expression_coeffs = num_expression_coeffs # poses self.pose_param_conf = pose_param_conf mean_poses_dict = {} if os.path.exists(mean_pose_path): with open(mean_pose_path, 'rb') as f: mean_poses_dict = pickle.load(f) start, mean_lst = self.load_param_decoder(mean_poses_dict) # shape shape_mean = torch.zeros([num_betas], dtype=torch.float32) shape_idxs = list(range(start, start + num_betas)) self.register_buffer('shape_idxs', torch.tensor(shape_idxs, dtype=torch.long)) start += num_betas mean_lst.append(shape_mean.view(-1)) # expression expression_mean = torch.zeros([num_expression_coeffs], dtype=torch.float32) expression_idxs = list(range(start, start + num_expression_coeffs)) self.register_buffer('expression_idxs', torch.tensor(expression_idxs, dtype=torch.long)) start += num_expression_coeffs mean_lst.append(expression_mean.view(-1)) # camera mean, dim, scale_func = self.get_camera_param(camera_cfg) self.camera_scale_func = scale_func camera_idxs = list(range(start, start + dim)) self.register_buffer('camera_idxs', torch.tensor(camera_idxs, dtype=torch.long)) start += dim mean_lst.append(mean) param_mean = torch.cat(mean_lst).view(1, -1) self.load_regressor(input_feat_dim, param_mean, regressor_cfg) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) def forward(self, features, cond=None): """Forward function of ExPose Face Head. Args: features (List[torch.tensor]) : Output of restnet. cond : Initial params. If none, use the mean params. """ batch_size = features[-1].size(0) features = self.avgpool(features[-1]).view(batch_size, -1) head_parameters = self.regressor(features, cond=cond)[-1] params_dict, raw_dict = self.flat_params_to_dict(head_parameters) params_dict['betas'] = torch.index_select(head_parameters, 1, self.shape_idxs) params_dict['expression'] = torch.index_select(head_parameters, 1, self.expression_idxs) camera_params = torch.index_select(head_parameters, 1, self.camera_idxs) scale = camera_params[:, 0:1] translation = camera_params[:, 1:3] scale = self.camera_scale_func(scale) camera_params = torch.cat([scale, translation], dim=1) return { 'pred_param': params_dict, 'pred_cam': camera_params, 'pred_raw': raw_dict }