Spaces:
Running
on
L40S
Running
on
L40S
# isort: skip_file | |
from abc import ABCMeta | |
import torch | |
from detrsmpl.data.datasets.pipelines.hybrik_transforms import heatmap2coord | |
from detrsmpl.utils.transforms import rotmat_to_quat | |
from ..backbones.builder import build_backbone | |
from ..body_models.builder import build_body_model | |
from ..heads.builder import build_head | |
from ..losses.builder import build_loss | |
from ..necks.builder import build_neck | |
from .base_architecture import BaseArchitecture | |
def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` for every parameter of the given networks.

    Args:
        nets (nn.Module | list[nn.Module] | tuple[nn.Module]): A single
            network or a list/tuple of networks. ``None`` entries are
            skipped.
        requires_grad (bool): Whether the networks require gradients or not.
            Default: False.
    """
    # Anything that is not a list/tuple is treated as one network.
    if not isinstance(nets, (list, tuple)):
        nets = [nets]
    for net in nets:
        if net is None:
            continue
        for param in net.parameters():
            param.requires_grad = requires_grad
class HybrIK_trainer(BaseArchitecture, metaclass=ABCMeta):
    """HybrIK_trainer architecture.

    Args:
        backbone (dict | None, optional): Backbone config dict. Default: None.
        neck (dict | None, optional): Neck config dict. Default: None
        head (dict | None, optional): Regressor config dict. Default: None.
        body_model (dict | None, optional): SMPL config dict. Default: None.
        loss_beta (dict | None, optional): Losses config dict for
            beta (shape parameters) estimation. Default: None
        loss_theta (dict | None, optional): Losses config dict for
            theta (pose parameters) estimation. Default: None
        loss_twist (dict | None, optional): Losses config dict
            for twist angles estimation. Default: None
        loss_uvd (dict | None, optional): Losses config dict
            for uvd joints estimation. Default: None
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
    """

    # Entries of ``kwargs`` that are forwarded to the head, in call order.
    _HEAD_INPUT_KEYS = (
        'trans_inv',
        'intrinsic_param',
        'joint_root',
        'depth_factor',
    )
    # Entries of ``kwargs`` that are kept as regression targets.
    _TARGET_KEYS = (
        'target_uvd_29',
        'target_xyz_24',
        'target_weight_24',
        'target_weight_29',
        'target_xyz_17',
        'target_weight_17',
        'target_theta',
        'target_beta',
        'target_smpl_weight',
        'target_theta_weight',
        'target_twist',
        'target_twist_weight',
    )

    def __init__(self,
                 backbone=None,
                 neck=None,
                 head=None,
                 body_model=None,
                 loss_beta=None,
                 loss_theta=None,
                 loss_twist=None,
                 loss_uvd=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        self.backbone = build_backbone(backbone)
        self.neck = build_neck(neck)
        self.head = build_head(head)
        self.smpl = build_body_model(body_model)
        self.loss_beta = build_loss(loss_beta)
        self.loss_theta = build_loss(loss_theta)
        self.loss_twist = build_loss(loss_twist)
        self.loss_uvd = build_loss(loss_uvd)
        # NOTE(review): the head is used unconditionally here and in both
        # forward passes, so a head config is effectively required.
        self.head._initialize()

    def _collect_labels(self, kwargs):
        """Pick the ground-truth tensors out of ``kwargs`` and move them to
        the GPU.

        Args:
            kwargs (dict): Dict with ground-truth. Every key listed in
                ``_HEAD_INPUT_KEYS`` and ``_TARGET_KEYS`` must be present.

        Returns:
            tuple: ``(head_inputs, targets)`` where ``head_inputs`` is the
                list ``[trans_inv, intrinsic_param, joint_root,
                depth_factor]`` and ``targets`` is a dict holding the
                remaining ``target_*`` tensors.
        """
        labels = {
            key: kwargs[key].cuda()
            for key in self._HEAD_INPUT_KEYS + self._TARGET_KEYS
        }
        head_inputs = [labels.pop(key) for key in self._HEAD_INPUT_KEYS]
        return head_inputs, labels

    def _extract_features(self, img):
        """Run the backbone (if any) and the neck (if any) on the input.

        Args:
            img (torch.Tensor | dict): Batch of images when a backbone is
                configured; otherwise a mapping whose ``'features'`` entry
                already holds the extracted features.

        Returns:
            The features to be fed into the head.
        """
        if self.backbone is not None:
            img = img.cuda().requires_grad_()
            # Only the first backbone output is used.
            features = self.backbone(img)[0]
        else:
            features = img['features']
        if self.neck is not None:
            features = self.neck(features)
        return features

    def forward_train(self, img, img_metas, **kwargs):
        """Train step function.

        In this function, train step is carried out
            with following the pipeline:
        1. extract features with the backbone
        2. feed the extracted features into the head to
            predict beta, theta, twist angle, and heatmap (uvd map)
        3. compute regression losses of the predictions
            and optimize backbone and head

        Args:
            img (torch.Tensor): Batch of data as input.
            img_metas (dict): Dict with image metas. Not used here.
            kwargs (dict): Dict with ground-truth
        Returns:
            losses (dict): Dict with the computed losses, keyed by
                loss name.
        """
        head_inputs, labels = self._collect_labels(kwargs)
        trans_inv, intrinsic_param, joint_root, depth_factor = head_inputs

        features = self._extract_features(img)
        predictions = self.head(features, trans_inv, intrinsic_param,
                                joint_root, depth_factor, self.smpl)
        return self.compute_losses(predictions, labels)

    def compute_losses(self, predictions, targets):
        """Compute regression losses for beta, theta, twist and uvd.

        Args:
            predictions (dict): Head output; reads ``pred_shape``,
                ``pred_pose``, ``pred_phi`` and ``pred_uvd_jts``.
            targets (dict): Ground-truth ``target_*`` tensors.
        Returns:
            losses (dict): One entry per configured (non-``None``) loss.
        """
        smpl_weight = targets['target_smpl_weight']
        losses = {}

        if self.loss_beta is not None:
            losses['loss_beta'] = self.loss_beta(
                predictions['pred_shape'] * smpl_weight,
                targets['target_beta'] * smpl_weight)

        if self.loss_theta is not None:
            theta_weight = smpl_weight * targets['target_theta_weight']
            # Rotation matrices -> quaternions, flattened to 96 values per
            # sample (presumably 24 joints x 4 -- TODO confirm).
            pred_pose = rotmat_to_quat(predictions['pred_pose']).reshape(
                -1, 96)
            losses['loss_theta'] = self.loss_theta(
                pred_pose * theta_weight,
                targets['target_theta'] * theta_weight)

        if self.loss_twist is not None:
            twist_weight = targets['target_twist_weight']
            losses['loss_twist'] = self.loss_twist(
                predictions['pred_phi'] * twist_weight,
                targets['target_twist'] * twist_weight)

        if self.loss_uvd is not None:
            pred_uvd = predictions['pred_uvd_jts']
            # Trim targets to the number of columns actually predicted.
            num_pred = pred_uvd.shape[1]
            target_uvd = targets['target_uvd_29'][:, :num_pred]
            target_uvd_weight = targets['target_weight_29'][:, :num_pred]
            # Both sides are scaled by 64 (presumably the heatmap
            # resolution used in ``forward_test`` -- TODO confirm).
            losses['loss_uvd'] = self.loss_uvd(
                64 * pred_uvd,
                64 * target_uvd,
                target_uvd_weight,
                avg_factor=target_uvd_weight.sum())

        return losses

    def forward_test(self, img, img_metas, **kwargs):
        """Test step function.

        In this function, test step is carried out
            with following the pipeline:
        1. extract features with the backbone
        2. feed the extracted features into the head to
            predict beta, theta, twist angle, and heatmap (uvd map)
        3. store predictions for evaluation

        Args:
            img (torch.Tensor): Batch of data as input.
            img_metas (list[dict]): Image metas; ``image_path`` is read
                from each entry.
            kwargs (dict): Dict with ground-truth, plus ``bbox`` and
                ``sample_idx``.
        Returns:
            all_preds (dict): Dict with vertices, smpl_pose, smpl_beta,
                xyz_17, uvd_jts, xyz_24, image_path and image_idx.
                ``uvd_jts`` is a list with one ``heatmap2coord`` result
                per sample of the batch.
        """
        head_inputs, _ = self._collect_labels(kwargs)
        trans_inv, intrinsic_param, joint_root, depth_factor = head_inputs
        bboxes = kwargs['bbox']

        # The head is called with a 2-D depth factor at test time.
        if depth_factor.dim() != 2:
            depth_factor = torch.unsqueeze(depth_factor, dim=1)

        features = self._extract_features(img)
        output = self.head(features, trans_inv, intrinsic_param, joint_root,
                           depth_factor, self.smpl)

        pred_uvd_jts = output['pred_uvd_jts']
        batch_num = pred_uvd_jts.shape[0]

        def to_numpy(tensor):
            return tensor.detach().cpu().numpy()

        pred_xyz_jts_24_struct = to_numpy(
            output['pred_xyz_jts_24_struct'].reshape(batch_num, 24, 3))
        pred_xyz_jts_17 = to_numpy(output['pred_xyz_jts_17'].reshape(
            batch_num, 17, 3))
        pred_mesh = to_numpy(output['pred_vertices'].reshape(
            batch_num, -1, 3))
        pred_pose = to_numpy(output['pred_pose'])
        pred_beta = to_numpy(output['pred_shape'])

        # ``heatmap2coord`` is fed CPU tensors, one sample at a time.
        pred_uvd_jts = pred_uvd_jts.detach().cpu().reshape(batch_num, -1, 3)
        # Keep the first 29 score columns (presumably one per uvd joint --
        # TODO confirm).
        pred_scores = output['maxvals'].detach().cpu()[:, :29]

        hm_shape = [64, 64]
        # Collect the coordinates of every sample in the batch.
        pose_coords_list = [
            heatmap2coord(pred_uvd_jts[i],
                          pred_scores[i],
                          hm_shape,
                          bboxes[i].tolist(),
                          mean_bbox_scale=None)[0]
            for i in range(batch_num)
        ]

        return {
            'vertices': pred_mesh,
            'smpl_pose': pred_pose,
            'smpl_beta': pred_beta,
            'xyz_17': pred_xyz_jts_17,
            'uvd_jts': pose_coords_list,
            'xyz_24': pred_xyz_jts_24_struct,
            'image_path': [meta['image_path'] for meta in img_metas],
            'image_idx': kwargs['sample_idx'],
        }