from typing import List

import torch
import torch.nn.functional as F
from smplx.utils import find_joint_kin_chain

from detrsmpl.core.conventions.keypoints_mapping import (
    get_keypoint_idx,
    get_keypoint_idxs_by_part,
)
from detrsmpl.utils.geometry import weak_perspective_projection


class SMPLXHandMergeFunc():
    """This function uses predictions from the hand model to update the hand
    params (right_hand_pose, left_hand_pose, wrist_pose) in predictions from
    the body model."""
    def __init__(self, body_model, convention='smplx'):
        self.body_model = body_model
        self.convention = convention
        self.left_hand_idxs = get_keypoint_idxs_by_part(
            'left_hand', self.convention)
        self.left_wrist_idx = get_keypoint_idx('left_wrist', self.convention)
        self.left_hand_idxs.append(self.left_wrist_idx)
        self.left_wrist_kin_chain = find_joint_kin_chain(
            self.left_wrist_idx, self.body_model.parents)

        self.right_hand_idxs = get_keypoint_idxs_by_part(
            'right_hand', self.convention)
        self.right_wrist_idx = get_keypoint_idx('right_wrist', self.convention)
        self.right_hand_idxs.append(self.right_wrist_idx)
        self.right_wrist_kin_chain = find_joint_kin_chain(
            self.right_wrist_idx, self.body_model.parents)

    def __call__(self, body_predictions, hand_predictions):
        """Merge hand predictions into the body predictions.

        Args:
            body_predictions (dict): The prediction from the body model.
            hand_predictions (dict): The prediction from the hand model.

        Returns:
            dict: Merged prediction.
        """
        pred_param = body_predictions['pred_param']
        global_orient = pred_param['global_orient']
        body_pose = pred_param['body_pose']
        pred_cam = body_predictions['pred_cam']
        batch_size = pred_cam.shape[0]
        device = pred_cam.device

        # The hand model batch is ordered as [right hands, flipped left hands]
        hands_from_body_idxs = torch.arange(0,
                                            2 * batch_size,
                                            dtype=torch.long,
                                            device=device)
        right_hand_from_body_idxs = hands_from_body_idxs[:batch_size]
        left_hand_from_body_idxs = hands_from_body_idxs[batch_size:]

        parent_rots = []
        right_wrist_parent_rot = find_joint_global_rotation(
            self.right_wrist_kin_chain[1:], global_orient, body_pose)

        left_wrist_parent_rot = find_joint_global_rotation(
            self.left_wrist_kin_chain[1:], global_orient, body_pose)
        left_to_right_wrist_parent_rot = flip_rotmat(left_wrist_parent_rot)

        parent_rots += [right_wrist_parent_rot, left_to_right_wrist_parent_rot]
        parent_rots = torch.cat(parent_rots, dim=0)

        wrist_pose_from_hand = hand_predictions['pred_param']['global_orient']
        # Undo the rotation of the parent joints to make the wrist rotation
        # relative again
        wrist_pose_from_hand = torch.matmul(
            parent_rots.reshape(-1, 3, 3).transpose(1, 2),
            wrist_pose_from_hand.reshape(-1, 3, 3))

        right_hand_wrist = wrist_pose_from_hand[right_hand_from_body_idxs]
        left_hand_wrist = flip_rotmat(
            wrist_pose_from_hand[left_hand_from_body_idxs])
        right_hand_pose = hand_predictions['pred_param']['right_hand_pose'][
            right_hand_from_body_idxs]
        left_hand_pose = flip_rotmat(
            hand_predictions['pred_param']['right_hand_pose']
            [left_hand_from_body_idxs])

        body_predictions['pred_param']['right_hand_pose'] = right_hand_pose
        body_predictions['pred_param']['left_hand_pose'] = left_hand_pose
        # body_pose excludes the global (pelvis) rotation, hence the -1 offset
        body_predictions['pred_param']['body_pose'][:, self.right_wrist_idx -
                                                    1] = right_hand_wrist
        body_predictions['pred_param']['body_pose'][:, self.left_wrist_idx -
                                                    1] = left_hand_wrist

        return body_predictions
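

# Usage sketch (illustrative only; `body_preds` and `hand_preds` are
# hypothetical names for the two model outputs, not part of this module):
#   merge_hands = SMPLXHandMergeFunc(body_model)
#   body_preds = merge_hands(body_preds, hand_preds)
# The call overwrites the right/left hand poses and both wrist rotations in
# `body_preds['pred_param']` in place and returns the same dict.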


class SMPLXFaceMergeFunc():
    """This function uses predictions from the face model to update the face
    params (jaw_pose, expression) in predictions from the body model."""
    def __init__(self,
                 body_model,
                 convention='smplx',
                 num_expression_coeffs=10):
        self.body_model = body_model
        self.convention = convention
        self.num_expression_coeffs = num_expression_coeffs

    def __call__(self, body_predictions, face_predictions):
        """Merge face predictions into the body predictions.

        Args:
            body_predictions (dict): The prediction from the body model.
            face_predictions (dict): The prediction from the face model.

        Returns:
            dict: Merged prediction.
        """
        body_predictions['pred_param']['jaw_pose'] = face_predictions[
            'pred_param']['jaw_pose']
        # Keep only the first num_expression_coeffs expression coefficients
        body_predictions['pred_param']['expression'] = face_predictions[
            'pred_param']['expression'][:, :self.num_expression_coeffs]
        return body_predictions


def points_to_bbox(points, bbox_scale_factor: float = 1.0):
    """Get a scaled, square bounding box from 2D keypoints."""
    min_coords, _ = torch.min(points, dim=1)
    xmin, ymin = min_coords[:, 0], min_coords[:, 1]
    max_coords, _ = torch.max(points, dim=1)
    xmax, ymax = max_coords[:, 0], max_coords[:, 1]

    center = torch.stack([xmax + xmin, ymax + ymin], dim=-1) * 0.5

    width = (xmax - xmin)
    height = (ymax - ymin)

    # Convert the bounding box to a square box
    size = torch.max(width, height) * bbox_scale_factor
    return center, size
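

# Shape sketch (illustrative; `kp2d` is a hypothetical name): for keypoints of
# shape (B, K, 2) in pixel coordinates,
#   center, size = points_to_bbox(kp2d, bbox_scale_factor=1.2)
# returns `center` of shape (B, 2) and `size` of shape (B,), the side length of
# the scaled square box enclosing the keypoints of each sample.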


def get_crop_info(points,
                  img_metas,
                  scale_factor: float = 1.0,
                  crop_size: int = 256):
    """Get the transformation from points on the cropped image to points on
    the original image."""
    device = points.device
    dtype = points.dtype
    batch_size = points.shape[0]
    # Get the image-to-crop transformations and bounding box sizes
    crop_transforms = []
    img_bbox_sizes = []
    for img_meta in img_metas:
        crop_transforms.append(img_meta['crop_transform'])
        img_bbox_sizes.append(img_meta['scale'].max())

    img_bbox_sizes = torch.tensor(img_bbox_sizes, dtype=dtype, device=device)

    crop_transforms = torch.tensor(crop_transforms, dtype=dtype, device=device)
    crop_transforms = torch.cat([
        crop_transforms,
        torch.tensor([0.0, 0.0, 1.0], dtype=dtype, device=device).expand(
            [batch_size, 1, 3])
    ], dim=1)

    inv_crop_transforms = torch.inverse(crop_transforms)

    # Center of the part bounding box on the cropped body image
    center_body_crop, bbox_size = points_to_bbox(
        points, bbox_scale_factor=scale_factor)

    orig_bbox_size = bbox_size / crop_size * img_bbox_sizes

    # Compute the center of the crop in the original image
    center = (torch.einsum(
        'bij,bj->bi', [inv_crop_transforms[:, :2, :2], center_body_crop]) +
              inv_crop_transforms[:, :2, 2])

    return {
        'center': center.reshape(-1, 2),
        'orig_bbox_size': orig_bbox_size,
        # 'bbox_size': bbox_size.reshape(-1),
        'inv_crop_transforms': inv_crop_transforms,
        # 'center_body_crop': 2 * center_body_crop / (crop_size-1) - 1,
    }
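

# Note (description, not in the original module): `center` is the box center in
# original-image pixel coordinates, `orig_bbox_size` rescales the box side from
# the body-crop resolution back to the original resolution, and
# `inv_crop_transforms` is the 3x3 affine that maps body-crop pixels back to
# the original image.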


def concat_images(images: List[torch.Tensor]):
    """Concatenate images of different sizes into a zero-padded batch."""
    sizes = [img.shape[1:] for img in images]
    H, W = [max(s) for s in zip(*sizes)]
    batch_size = len(images)

    batched_shape = (batch_size, images[0].shape[0], H, W)
    batched = torch.zeros(batched_shape,
                          device=images[0].device,
                          dtype=images[0].dtype)
    for ii, img in enumerate(images):
        shape = img.shape
        batched[ii, :shape[0], :shape[1], :shape[2]] = img
    return batched
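

# Example sketch (illustrative): two (3, H, W) images of different sizes are
# placed top-left into a common (2, 3, max_H, max_W) tensor, padded with zeros:
#   batched = concat_images([torch.rand(3, 480, 640), torch.rand(3, 720, 1280)])
#   # batched.shape == (2, 3, 720, 1280)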


def flip_rotmat(pose_rotmat):
    """Flip rotation matrices left <-> right.

    Mirrors each rotation across the YZ plane, e.g. to map left-hand poses to
    right-hand poses.
    """
    rot_mats = pose_rotmat.reshape(-1, 9).clone()
    rot_mats[:, [1, 2, 3, 6]] *= -1
    return rot_mats.view_as(pose_rotmat)
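

# Note (description, not in the original module): negating the flattened
# entries 1, 2, 3 and 6, i.e. R[0, 1], R[0, 2], R[1, 0] and R[2, 0], is the
# same as computing M @ R @ M with M = diag(-1, 1, 1): the rotation is
# reflected across the YZ plane, turning a left-side pose into the
# corresponding right-side pose.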


def find_joint_global_rotation(kin_chain, root_pose, body_pose):
    """Computes the absolute rotation of a joint from the kinematic chain."""
    # Gather the rotations of every joint along the kinematic chain
    parents_pose = torch.cat([root_pose, body_pose], dim=1)[:, kin_chain]

    output_pose = parents_pose[:, 0]
    for idx in range(1, parents_pose.shape[1]):
        output_pose = torch.bmm(parents_pose[:, idx], output_pose)
    return output_pose
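

# Note (description, not in the original module): `kin_chain` is expected to
# be ordered from the joint up to the root (as returned by
# smplx.utils.find_joint_kin_chain), so ancestors are left-multiplied and the
# result is R_root @ ... @ R_parent @ R_joint, the joint's global rotation.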


class CropSampler():
    """This class crops the HD images using bilinear interpolation."""
    def __init__(self, crop_size: int = 256) -> None:
        """Uses bilinear sampling to extract square crops.

        This module expects a high resolution image as input and a bounding
        box, described by its center and size. It then proceeds to extract
        a sub-image using the provided information through bilinear
        interpolation.

        Parameters
        ----------
        crop_size: int
            The desired size for the crop.
        """
        super(CropSampler, self).__init__()

        self.crop_size = crop_size
        x = torch.arange(0, crop_size, dtype=torch.float32) / (crop_size - 1)
        grid_y, grid_x = torch.meshgrid(x, x)

        points = torch.stack([grid_y.flatten(), grid_x.flatten()], dim=1)

        self.grid = points.unsqueeze(dim=0)

    def _sample_padded(self, full_imgs, sampling_grid):
        """Sample crops from the full images via bilinear interpolation."""
        # Get the sub-images using bilinear interpolation
        return F.grid_sample(full_imgs, sampling_grid, align_corners=True)

    def __call__(self, full_imgs, center, bbox_size):
        """Crops the HD images using the provided bounding boxes.

        Parameters
        ----------
        full_imgs: ImageList
            An image list structure with the full resolution images
        center: torch.Tensor
            A Bx2 tensor that contains the coordinates of the center of
            the bounding box that will be cropped from the original
            image
        bbox_size: torch.Tensor
            A size B tensor that contains the size of the crop

        Returns
        -------
        dict
            `images`: the crops sampled from the high resolution input,
            `sampling_grid`: the grid used to sample the crops,
            `transform`: crop-pixel to full-image affine transforms,
            `hd_to_crop`: full-image to normalized crop transforms.
        """
        batch_size, _, H, W = full_imgs.shape
        self.grid = self.grid.to(device=full_imgs.device)
        transforms = torch.eye(3,
                               dtype=full_imgs.dtype,
                               device=full_imgs.device).reshape(
                                   1, 3, 3).expand(batch_size, -1,
                                                   -1).contiguous()
        hd_to_crop = torch.eye(3,
                               dtype=full_imgs.dtype,
                               device=full_imgs.device).reshape(
                                   1, 3, 3).expand(batch_size, -1,
                                                   -1).contiguous()

        # Create the transformation that maps crop pixels to image coordinates,
        # i.e. pixel (0, 0) from the crop_size x crop_size grid gets mapped to
        # the top left of the bounding box, pixel
        # (crop_size - 1, crop_size - 1) to the bottom right corner of the
        # bounding box
        transforms[:, 0, 0] = bbox_size  # / (self.crop_size - 1)
        transforms[:, 1, 1] = bbox_size  # / (self.crop_size - 1)
        transforms[:, 0, 2] = center[:, 0] - bbox_size * 0.5
        transforms[:, 1, 2] = center[:, 1] - bbox_size * 0.5

        hd_to_crop[:, 0, 0] = 2 * (self.crop_size - 1) / bbox_size
        hd_to_crop[:, 1, 1] = 2 * (self.crop_size - 1) / bbox_size
        hd_to_crop[:, 0, 2] = -(center[:, 0] -
                                bbox_size * 0.5) * hd_to_crop[:, 0, 0] - 1
        hd_to_crop[:, 1, 2] = -(center[:, 1] -
                                bbox_size * 0.5) * hd_to_crop[:, 1, 1] - 1

        size_bbox_sizer = torch.eye(3,
                                    dtype=full_imgs.dtype,
                                    device=full_imgs.device).reshape(
                                        1, 3, 3).expand(batch_size, -1,
                                                        -1).contiguous()

        # Normalize the coordinates to [-1, 1] for the grid_sample function
        size_bbox_sizer[:, 0, 0] = 2.0 / (W - 1)
        size_bbox_sizer[:, 1, 1] = 2.0 / (H - 1)
        size_bbox_sizer[:, :2, 2] = -1

        # full_transform = transforms
        full_transform = torch.bmm(size_bbox_sizer, transforms)

        batch_grid = self.grid.expand(batch_size, -1, -1)
        # Convert the grid to image coordinates using the transformations above
        sampling_grid = (
            torch.bmm(full_transform[:, :2, :2], batch_grid.transpose(1, 2)) +
            full_transform[:, :2, [2]]).transpose(1, 2)
        sampling_grid = sampling_grid.reshape(-1, self.crop_size,
                                              self.crop_size,
                                              2).transpose(1, 2)

        out_images = self._sample_padded(full_imgs, sampling_grid)

        return {
            'images': out_images,
            'sampling_grid': sampling_grid.reshape(batch_size, -1, 2),
            'transform': transforms,
            'hd_to_crop': hd_to_crop,
        }
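

# Usage sketch (illustrative; the tensor values are hypothetical): crop a
# 224x224 patch around a box center from a full-resolution image batch.
#   cropper = CropSampler(crop_size=224)
#   out = cropper(full_imgs,                       # (1, 3, H, W) float tensor
#                 center=torch.tensor([[320., 240.]]),
#                 bbox_size=torch.tensor([200.]))
#   crops = out['images']                          # (1, 3, 224, 224)
# `out['transform']` maps crop pixels to full-image pixels, while
# `out['hd_to_crop']` maps full-image pixels into the crop's [-1, 1] range.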


class SMPLXHandCropFunc():
    """This function crops hand images from the original image.

    The keypoints predicted by the body model are used to locate the hands.
    """
    def __init__(self,
                 model_head,
                 body_model,
                 convention='smplx',
                 img_res=256,
                 scale_factor=2.0,
                 crop_size=224,
                 condition_hand_wrist_pose=True,
                 condition_hand_shape=False,
                 condition_hand_finger_pose=True):
        self.model_head = model_head
        self.body_model = body_model
        self.img_res = img_res

        self.convention = convention
        self.left_hand_idxs = get_keypoint_idxs_by_part(
            'left_hand', self.convention)
        left_wrist_idx = get_keypoint_idx('left_wrist', self.convention)
        self.left_hand_idxs.append(left_wrist_idx)
        self.left_wrist_kin_chain = find_joint_kin_chain(
            left_wrist_idx, self.body_model.parents)

        self.right_hand_idxs = get_keypoint_idxs_by_part(
            'right_hand', self.convention)
        right_wrist_idx = get_keypoint_idx('right_wrist', self.convention)
        self.right_hand_idxs.append(right_wrist_idx)
        self.right_wrist_kin_chain = find_joint_kin_chain(
            right_wrist_idx, self.body_model.parents)

        self.scale_factor = scale_factor
        self.hand_cropper = CropSampler(crop_size)

        self.condition_hand_wrist_pose = condition_hand_wrist_pose
        self.condition_hand_shape = condition_hand_shape
        self.condition_hand_finger_pose = condition_hand_finger_pose

    def build_hand_mean(self, global_orient, body_pose, betas, left_hand_pose,
                        raw_right_hand_pose, batch_size):
        """Builds the initial point for the iterative regressor of the hand."""
        hand_mean = []

        # if self.condition_hand_on_body:
        # Convert the absolute pose to the latent representation
        if self.condition_hand_wrist_pose:
            # Compute the absolute pose of the right wrist
            right_wrist_pose_abs = find_joint_global_rotation(
                self.right_wrist_kin_chain, global_orient, body_pose)
            right_wrist_pose = right_wrist_pose_abs[
                :, :3, :2].contiguous().reshape(batch_size, -1)

            # Compute the absolute rotation for the left wrist
            left_wrist_pose_abs = find_joint_global_rotation(
                self.left_wrist_kin_chain, global_orient, body_pose)

            # Flip the left wrist to the right
            left_to_right_wrist_pose = flip_rotmat(left_wrist_pose_abs)

            # Convert to the latent representation
            left_to_right_wrist_pose = left_to_right_wrist_pose[
                :, :3, :2].contiguous().reshape(batch_size, -1)
        else:
            right_wrist_pose = self.model_head.get_mean('global_orient',
                                                        batch_size=batch_size)
            left_to_right_wrist_pose = self.model_head.get_mean(
                'global_orient', batch_size=batch_size)

        # Convert the pose of the left hand to the right hand and project
        # it to the encoder space
        left_to_right_hand_pose = flip_rotmat(
            left_hand_pose)[:, :, :3, :2].contiguous().reshape(batch_size, -1)
        right_hand_pose = raw_right_hand_pose.reshape(batch_size, -1)

        camera_mean = self.model_head.get_mean('camera', batch_size=batch_size)

        shape_condition = (betas if self.condition_hand_shape else
                           self.model_head.get_mean('shape',
                                                    batch_size=batch_size))
        right_finger_pose_condition = (
            right_hand_pose if self.condition_hand_finger_pose else
            self.model_head.get_mean('right_hand_pose', batch_size=batch_size))
        right_hand_mean = torch.cat([
            right_wrist_pose, right_finger_pose_condition, shape_condition,
            camera_mean
        ], dim=1)

        left_finger_pose_condition = (
            left_to_right_hand_pose if self.condition_hand_finger_pose else
            self.model_head.get_mean('right_hand_pose', batch_size=batch_size))
        # Should be Bx31
        left_hand_mean = torch.cat([
            left_to_right_wrist_pose, left_finger_pose_condition,
            shape_condition, camera_mean
        ], dim=1)

        hand_mean += [right_hand_mean, left_hand_mean]
        hand_mean = torch.cat(hand_mean, dim=0)

        return hand_mean
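
    # Note (description, not in the original module): `[:, :3, :2]` keeps the
    # first two columns of each 3x3 rotation matrix, i.e. the continuous 6D
    # rotation representation used as the regressor's latent pose. The
    # concatenated mean is ordered as [all right hands, all flipped left
    # hands] along the batch dimension, matching the crop order in __call__.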

    def __call__(self, body_predictions, img_metas):
        """Crop hand images based on the body predictions.

        Args:
            body_predictions (dict): The prediction from the body model.
            img_metas (dict): Information of the input images.

        Returns:
            all_hand_imgs (torch.Tensor): Cropped hand images.
            hand_mean (torch.Tensor): Mean value of hand params.
            crop_info (dict): Hand crop transforms.
        """
        pred_param = body_predictions['pred_param']
        pred_cam = body_predictions['pred_cam']
        pred_raw = body_predictions['pred_raw']

        pred_output = self.body_model(**pred_param)
        pred_keypoints3d = pred_output['joints']
        pred_keypoints2d = weak_perspective_projection(
            pred_keypoints3d,
            scale=pred_cam[:, 0],
            translation=pred_cam[:, 1:3])

        # concat ori_img
        full_images = []
        for img_meta in img_metas:
            full_images.append(img_meta['ori_img'].to(device=pred_cam.device))
        full_imgs = concat_images(full_images)

        # left hand (keypoints are mapped from [-1, 1] to body-crop pixels)
        left_hand_joints = (pred_keypoints2d[:, self.left_hand_idxs] * 0.5 +
                            0.5) * (self.img_res - 1)
        left_hand_points_to_crop = get_crop_info(left_hand_joints, img_metas,
                                                 self.scale_factor,
                                                 self.img_res)
        left_hand_center = left_hand_points_to_crop['center']
        left_hand_orig_bbox_size = left_hand_points_to_crop['orig_bbox_size']
        left_hand_inv_crop_transforms = left_hand_points_to_crop[
            'inv_crop_transforms']
        left_hand_cropper_out = self.hand_cropper(full_imgs, left_hand_center,
                                                  left_hand_orig_bbox_size)
        left_hand_crops = left_hand_cropper_out['images']
        # left_hand_points = left_hand_cropper_out['sampling_grid']
        left_hand_crop_transform = left_hand_cropper_out['transform']

        # right hand
        right_hand_joints = (pred_keypoints2d[:, self.right_hand_idxs] * 0.5 +
                             0.5) * (self.img_res - 1)
        right_hand_points_to_crop = get_crop_info(right_hand_joints, img_metas,
                                                  self.scale_factor,
                                                  self.img_res)
        right_hand_center = right_hand_points_to_crop['center']
        right_hand_orig_bbox_size = right_hand_points_to_crop['orig_bbox_size']
        # right_hand_inv_crop_transforms = right_hand_points_to_crop[
        #     'inv_crop_transforms']
        right_hand_cropper_out = self.hand_cropper(full_imgs,
                                                   right_hand_center,
                                                   right_hand_orig_bbox_size)
        right_hand_crops = right_hand_cropper_out['images']
        # right_hand_points = right_hand_cropper_out['sampling_grid']
        right_hand_crop_transform = right_hand_cropper_out['transform']

        # concat: [right hands, horizontally flipped left hands]
        all_hand_imgs = []
        all_hand_imgs.append(right_hand_crops)
        all_hand_imgs.append(torch.flip(left_hand_crops, dims=(-1, )))
        all_hand_imgs = torch.cat(all_hand_imgs, dim=0)

        hand_mean = self.build_hand_mean(pred_param['global_orient'],
                                         pred_param['body_pose'],
                                         pred_param['betas'],
                                         pred_param['left_hand_pose'],
                                         pred_raw['raw_right_hand_pose'],
                                         batch_size=full_imgs.shape[0])

        crop_info = dict(
            hand_inv_crop_transforms=left_hand_inv_crop_transforms,
            left_hand_crop_transform=left_hand_crop_transform,
            right_hand_crop_transform=right_hand_crop_transform)

        return all_hand_imgs, hand_mean, crop_info
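
    # Note (description, not in the original module): left-hand crops are
    # flipped horizontally before being stacked after the right-hand crops, so
    # a single right-hand network can process both hands; SMPLXHandMergeFunc
    # flips the resulting left-hand predictions back when merging.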


class SMPLXFaceCropFunc():
    """This function crops face images from the original image.

    The keypoints predicted by the body model are used to locate the face.
    """
    def __init__(self,
                 model_head,
                 body_model,
                 convention='smplx',
                 img_res=256,
                 scale_factor=2.0,
                 crop_size=256,
                 num_betas=10,
                 num_expression_coeffs=10,
                 condition_face_neck_pose=False,
                 condition_face_jaw_pose=True,
                 condition_face_shape=False,
                 condition_face_expression=True):
        self.model_head = model_head
        self.body_model = body_model
        self.img_res = img_res
        self.convention = convention
        self.num_betas = num_betas
        self.num_expression_coeffs = num_expression_coeffs

        self.face_idx = get_keypoint_idxs_by_part('head', self.convention)
        neck_idx = get_keypoint_idx('neck', self.convention)
        self.neck_kin_chain = find_joint_kin_chain(neck_idx,
                                                   self.body_model.parents)

        self.condition_face_neck_pose = condition_face_neck_pose
        self.condition_face_jaw_pose = condition_face_jaw_pose
        self.condition_face_shape = condition_face_shape
        self.condition_face_expression = condition_face_expression

        self.scale_factor = scale_factor
        self.face_cropper = CropSampler(crop_size)

    def build_face_mean(self, global_orient, body_pose, betas, raw_jaw_pose,
                        expression, batch_size):
        """Builds the initial point for the iterative regressor of the face."""
        face_mean = []

        # Compute the absolute pose of the neck
        neck_pose_abs = find_joint_global_rotation(self.neck_kin_chain,
                                                   global_orient, body_pose)
        # Convert the absolute neck pose to the latent (6D) representation
        neck_pose = neck_pose_abs[:, :3, :2].contiguous().reshape(
            batch_size, -1)

        camera_mean = self.model_head.get_mean('camera', batch_size=batch_size)

        neck_pose_condition = (neck_pose if self.condition_face_neck_pose else
                               self.model_head.get_mean('global_orient',
                                                        batch_size=batch_size))
        jaw_pose_condition = (raw_jaw_pose.reshape(batch_size, -1)
                              if self.condition_face_jaw_pose else
                              self.model_head.get_mean('jaw_pose',
                                                       batch_size=batch_size))

        # Pad body betas/expression to the sizes expected by the face head
        face_num_betas = self.model_head.get_num_betas()
        shape_padding_size = face_num_betas - self.num_betas
        betas_condition = (
            F.pad(betas.reshape(batch_size, -1),
                  (0, shape_padding_size)) if self.condition_face_shape else
            self.model_head.get_mean('shape', batch_size=batch_size))

        face_num_expression_coeffs = \
            self.model_head.get_num_expression_coeffs()
        expr_padding_size = face_num_expression_coeffs \
            - self.num_expression_coeffs
        expression_condition = (
            F.pad(expression.reshape(batch_size, -1),
                  (0, expr_padding_size)) if self.condition_face_expression
            else self.model_head.get_mean('expression', batch_size=batch_size))

        # Should be Bx(Head pose params)
        face_mean.append(
            torch.cat([
                neck_pose_condition,
                jaw_pose_condition,
                betas_condition,
                expression_condition,
                camera_mean.reshape(batch_size, -1),
            ], dim=1))
        face_mean = torch.cat(face_mean, dim=0)
        return face_mean

    def __call__(self, body_predictions, img_metas):
        """Crop face images based on the body predictions.

        Args:
            body_predictions (dict): The prediction from the body model.
            img_metas (dict): Information of the input images.

        Returns:
            all_face_imgs (torch.Tensor): Cropped face images.
            face_mean (torch.Tensor): Mean value of face params.
            crop_info (dict): Face crop transforms.
        """
        pred_param = body_predictions['pred_param']
        pred_cam = body_predictions['pred_cam']
        pred_raw = body_predictions['pred_raw']

        pred_output = self.body_model(**pred_param)
        pred_keypoints3d = pred_output['joints']
        pred_keypoints2d = weak_perspective_projection(
            pred_keypoints3d,
            scale=pred_cam[:, 0],
            translation=pred_cam[:, 1:3])

        # concat ori_img
        full_images = []
        for img_meta in img_metas:
            full_images.append(img_meta['ori_img'].to(device=pred_cam.device))
        full_imgs = concat_images(full_images)

        # Map keypoints from [-1, 1] to pixel coordinates of the body crop
        face_joints = (pred_keypoints2d[:, self.face_idx] * 0.5 +
                       0.5) * (self.img_res - 1)
        face_points_to_crop = get_crop_info(face_joints, img_metas,
                                            self.scale_factor, self.img_res)
        face_center = face_points_to_crop['center']
        face_orig_bbox_size = face_points_to_crop['orig_bbox_size']
        face_inv_crop_transforms = face_points_to_crop['inv_crop_transforms']

        face_cropper_out = self.face_cropper(full_imgs, face_center,
                                             face_orig_bbox_size)
        face_crops = face_cropper_out['images']
        # face_points = face_cropper_out['sampling_grid']
        face_crop_transform = face_cropper_out['transform']

        all_face_imgs = [face_crops]
        all_face_imgs = torch.cat(all_face_imgs, dim=0)

        face_mean = self.build_face_mean(pred_param['global_orient'],
                                         pred_param['body_pose'],
                                         pred_param['betas'],
                                         pred_raw['raw_jaw_pose'],
                                         pred_param['expression'],
                                         batch_size=full_imgs.shape[0])
        crop_info = dict(face_inv_crop_transforms=face_inv_crop_transforms,
                         face_crop_transform=face_crop_transform)
        return all_face_imgs, face_mean, crop_info
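

# End-to-end sketch (illustrative; `body_model`, `hand_head`, `body_preds` and
# `img_metas` are assumed to exist elsewhere, e.g. in the model that owns these
# helpers, and the hand head call signature below is hypothetical):
#   hand_crop = SMPLXHandCropFunc(hand_head, body_model, img_res=256)
#   hand_imgs, hand_mean, _ = hand_crop(body_preds, img_metas)
#   hand_preds = hand_head(hand_imgs, cond=hand_mean)      # hypothetical call
#   body_preds = SMPLXHandMergeFunc(body_model)(body_preds, hand_preds)
# The face path (SMPLXFaceCropFunc + SMPLXFaceMergeFunc) follows the same
# crop -> predict -> merge pattern.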