Spaces:
Running
on
L40S
Running
on
L40S
import numpy as np | |
import torch | |
import torch.nn as nn | |
from mmcv.runner.base_module import BaseModule | |
from detrsmpl.utils.geometry import rot6d_to_rotmat | |
class HMRHead(BaseModule):
    """Iterative regression head for pose, shape and camera parameters.

    Starting from an initial estimate, the head repeatedly concatenates the
    input feature with the current pose/shape/camera estimate, passes it
    through two fully connected layers with dropout, and adds the decoded
    residuals to the estimate.

    Args:
        feat_dim (int): Size of the (pooled) input feature vector.
        smpl_mean_params (str | None): Path given to ``np.load``. The loaded
            object must provide the keys ``'pose'``, ``'shape'`` and
            ``'cam'``, which become the initial estimates. When ``None``,
            pose and shape start at zero and the camera at ``[1, 0, 0]``.
        npose (int): Size of the pose vector. It is split into groups of six
            values per joint when converted to rotation matrices, so it is
            expected to be a multiple of 6 (default 144 -> 24 joints).
        nbeta (int): Size of the shape vector.
        ncam (int): Size of the camera vector.
        hdim (int): Width of the two hidden fully connected layers.
        init_cfg (dict | None): Forwarded unchanged to ``BaseModule``.
    """

    def __init__(self,
                 feat_dim,
                 smpl_mean_params=None,
                 npose=144,
                 nbeta=10,
                 ncam=3,
                 hdim=1024,
                 init_cfg=None):
        super(HMRHead, self).__init__(init_cfg=init_cfg)

        # The regressor sees the feature together with the current estimate.
        self.fc1 = nn.Linear(feat_dim + npose + nbeta + ncam, hdim)
        self.drop1 = nn.Dropout()
        self.fc2 = nn.Linear(hdim, hdim)
        self.drop2 = nn.Dropout()

        # One decoder per parameter group; each predicts a residual update.
        self.decpose = nn.Linear(hdim, npose)
        self.decshape = nn.Linear(hdim, nbeta)
        self.deccam = nn.Linear(hdim, ncam)

        # Small gain keeps the initial residuals close to zero, so the first
        # predictions stay near the initial estimate.
        nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
        nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
        nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)

        if smpl_mean_params is None:
            init_pose = torch.zeros([1, npose])
            init_shape = torch.zeros([1, nbeta])
            init_cam = torch.FloatTensor([[1, 0, 0]])
        else:
            mean_params = np.load(smpl_mean_params)
            init_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
            init_shape = torch.from_numpy(
                mean_params['shape'][:].astype('float32')).unsqueeze(0)
            init_cam = torch.from_numpy(mean_params['cam']).unsqueeze(0)

        # Buffers follow the module across devices and are saved in the
        # state dict, but are not trained.
        self.register_buffer('init_pose', init_pose)
        self.register_buffer('init_shape', init_shape)
        self.register_buffer('init_cam', init_cam)

    def forward(self,
                x,
                init_pose=None,
                init_shape=None,
                init_cam=None,
                n_iter=3):
        """Regress pose, shape and camera from a feature.

        Args:
            x (Tensor | list | tuple): Input feature. For a list/tuple only
                the last element is used. A 4-D tensor is averaged over its
                last two dimensions; a 3-D tensor is treated as a temporal
                feature ``(B, T, L)`` and flattened to ``(B * T, L)``; any
                other rank is used as is (presumably ``(N, feat_dim)``).
            init_pose (Tensor | None): Initial pose estimate; defaults to the
                registered ``init_pose`` buffer expanded to the batch size.
            init_shape (Tensor | None): Initial shape estimate, same rule.
            init_cam (Tensor | None): Initial camera estimate, same rule.
            n_iter (int): Number of refinement iterations.

        Returns:
            dict: ``'pred_pose'`` holds rotation matrices of shape
            ``(N, J, 3, 3)``, ``'pred_shape'`` has shape ``(N, nbeta)`` and
            ``'pred_cam'`` has shape ``(N, ncam)``. For temporal input the
            leading ``N`` is replaced by ``(B, T)``. ``J`` is the pose size
            divided by 6.
        """
        # hmr head only support one layer feature
        if isinstance(x, (list, tuple)):
            x = x[-1]

        output_seq = False
        if len(x.shape) == 4:
            # use feature from the last layer of the backbone
            # apply global average pooling on the feature map
            x = x.mean(dim=-1).mean(dim=-1)
        elif len(x.shape) == 3:
            # temporal feature
            output_seq = True
            B, T, L = x.shape
            # reshape (not view) so non-contiguous inputs are accepted too
            x = x.reshape(B * T, L)

        batch_size = x.shape[0]
        if init_pose is None:
            init_pose = self.init_pose.expand(batch_size, -1)
        if init_shape is None:
            init_shape = self.init_shape.expand(batch_size, -1)
        if init_cam is None:
            init_cam = self.init_cam.expand(batch_size, -1)

        pred_pose = init_pose
        pred_shape = init_shape
        pred_cam = init_cam
        for _ in range(n_iter):
            xc = torch.cat([x, pred_pose, pred_shape, pred_cam], 1)
            xc = self.fc1(xc)
            xc = self.drop1(xc)
            xc = self.fc2(xc)
            xc = self.drop2(xc)
            # residual updates of the current estimate
            pred_pose = self.decpose(xc) + pred_pose
            pred_shape = self.decshape(xc) + pred_shape
            pred_cam = self.deccam(xc) + pred_cam

        # Derive sizes from the tensors instead of hard-coding 24 / 10 / 3,
        # so non-default npose / nbeta / ncam work as well.
        n_joints = pred_pose.shape[-1] // 6
        pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, n_joints, 3,
                                                      3)

        if output_seq:
            pred_rotmat = pred_rotmat.view(B, T, n_joints, 3, 3)
            pred_shape = pred_shape.view(B, T, pred_shape.shape[-1])
            pred_cam = pred_cam.view(B, T, pred_cam.shape[-1])

        output = {
            'pred_pose': pred_rotmat,
            'pred_shape': pred_shape,
            'pred_cam': pred_cam
        }
        return output