from typing import List import torch import torch.nn.functional as F from smplx.utils import find_joint_kin_chain from detrsmpl.core.conventions.keypoints_mapping import ( get_keypoint_idx, get_keypoint_idxs_by_part, ) from detrsmpl.utils.geometry import weak_perspective_projection class SMPLXHandMergeFunc(): """This function use predictions from hand model to update the hand params (right_hand_pose, left_hand_pose, wrist_pose) in predictions from body model.""" def __init__(self, body_model, convention='smplx'): self.body_model = body_model self.convention = convention self.left_hand_idxs = get_keypoint_idxs_by_part( 'left_hand', self.convention) self.left_wrist_idx = get_keypoint_idx('left_wrist', self.convention) self.left_hand_idxs.append(self.left_wrist_idx) self.left_wrist_kin_chain = find_joint_kin_chain( self.left_wrist_idx, self.body_model.parents) self.right_hand_idxs = get_keypoint_idxs_by_part( 'right_hand', self.convention) self.right_wrist_idx = get_keypoint_idx('right_wrist', self.convention) self.right_hand_idxs.append(self.right_wrist_idx) self.right_wrist_kin_chain = find_joint_kin_chain( self.right_wrist_idx, self.body_model.parents) def __call__(self, body_predictions, hand_predictions): """Function Args: body_predictions (dict): The prediction from body model. hand_predictions (dict): The prediction from hand model. Returns: dict: Merged prediction. """ pred_param = body_predictions['pred_param'] global_orient = pred_param['global_orient'] body_pose = pred_param['body_pose'] pred_cam = body_predictions['pred_cam'] batch_size = pred_cam.shape[0] device = pred_cam.device hands_from_body_idxs = torch.arange(0, 2 * batch_size, dtype=torch.long, device=device) right_hand_from_body_idxs = hands_from_body_idxs[:batch_size] left_hand_from_body_idxs = hands_from_body_idxs[batch_size:] parent_rots = [] right_wrist_parent_rot = find_joint_global_rotation( self.right_wrist_kin_chain[1:], global_orient, body_pose) left_wrist_parent_rot = find_joint_global_rotation( self.left_wrist_kin_chain[1:], global_orient, body_pose) left_to_right_wrist_parent_rot = flip_rotmat(left_wrist_parent_rot) parent_rots += [right_wrist_parent_rot, left_to_right_wrist_parent_rot] parent_rots = torch.cat(parent_rots, dim=0) wrist_pose_from_hand = hand_predictions['pred_param']['global_orient'] # Undo the rotation of the parent joints to make the wrist rotation # relative again wrist_pose_from_hand = torch.matmul( parent_rots.reshape(-1, 3, 3).transpose(1, 2), wrist_pose_from_hand.reshape(-1, 3, 3)) right_hand_wrist = wrist_pose_from_hand[right_hand_from_body_idxs] left_hand_wrist = flip_rotmat( wrist_pose_from_hand[left_hand_from_body_idxs]) right_hand_pose = hand_predictions['pred_param']['right_hand_pose'][ right_hand_from_body_idxs] left_hand_pose = flip_rotmat( hand_predictions['pred_param']['right_hand_pose'] [left_hand_from_body_idxs]) body_predictions['pred_param']['right_hand_pose'] = right_hand_pose body_predictions['pred_param']['left_hand_pose'] = left_hand_pose body_predictions['pred_param']['body_pose'][:, self.right_wrist_idx - 1] = right_hand_wrist body_predictions['pred_param']['body_pose'][:, self.left_wrist_idx - 1] = left_hand_wrist return body_predictions class SMPLXFaceMergeFunc(): """This function use predictions from face model to update the face params (jaw_pose, expression) in predictions from body model.""" def __init__(self, body_model, convention='smplx', num_expression_coeffs=10): self.body_model = body_model self.convention = convention self.num_expression_coeffs = num_expression_coeffs def __call__(self, body_predictions, face_predictions): """Function Args: body_predictions (dict): The prediction from body model. face_predictions (dict): The prediction from face model. Returns: dict: Merged prediction. """ body_predictions['pred_param']['jaw_pose'] = face_predictions[ 'pred_param']['jaw_pose'] body_predictions['pred_param']['expression'] = face_predictions[ 'pred_param']['expression'][:, :self.num_expression_coeffs] return body_predictions def points_to_bbox(points, bbox_scale_factor: float = 1.0): """Get scaled bounding box from keypoints 2D.""" min_coords, _ = torch.min(points, dim=1) xmin, ymin = min_coords[:, 0], min_coords[:, 1] max_coords, _ = torch.max(points, dim=1) xmax, ymax = max_coords[:, 0], max_coords[:, 1] center = torch.stack([xmax + xmin, ymax + ymin], dim=-1) * 0.5 width = (xmax - xmin) height = (ymax - ymin) # Convert the bounding box to a square box size = torch.max(width, height) * bbox_scale_factor return center, size def get_crop_info(points, img_metas, scale_factor: float = 1.0, crop_size: int = 256): """Get the transformation of points on the cropped image to the points on the original image.""" device = points.device dtype = points.dtype batch_size = points.shape[0] # Get the image to crop transformations and bounding box sizes crop_transforms = [] img_bbox_sizes = [] for img_meta in img_metas: crop_transforms.append(img_meta['crop_transform']) img_bbox_sizes.append(img_meta['scale'].max()) img_bbox_sizes = torch.tensor(img_bbox_sizes, dtype=dtype, device=device) crop_transforms = torch.tensor(crop_transforms, dtype=dtype, device=device) crop_transforms = torch.cat([ crop_transforms, torch.tensor([0.0, 0.0, 1.0], dtype=dtype, device=device).expand( [batch_size, 1, 3]) ], dim=1) inv_crop_transforms = torch.inverse(crop_transforms) # center on the cropped body image center_body_crop, bbox_size = points_to_bbox( points, bbox_scale_factor=scale_factor) orig_bbox_size = bbox_size / crop_size * img_bbox_sizes # Compute the center of the crop in the original image center = (torch.einsum( 'bij,bj->bi', [inv_crop_transforms[:, :2, :2], center_body_crop]) + inv_crop_transforms[:, :2, 2]) return { 'center': center.reshape(-1, 2), 'orig_bbox_size': orig_bbox_size, # 'bbox_size': bbox_size.reshape(-1), 'inv_crop_transforms': inv_crop_transforms, # 'center_body_crop': 2 * center_body_crop / (crop_size-1) - 1, } def concat_images(images: List[torch.Tensor]): """Concat images of different size.""" sizes = [img.shape[1:] for img in images] H, W = [max(s) for s in zip(*sizes)] batch_size = len(images) batched_shape = (batch_size, images[0].shape[0], H, W) batched = torch.zeros(batched_shape, device=images[0].device, dtype=images[0].dtype) for ii, img in enumerate(images): shape = img.shape batched[ii, :shape[0], :shape[1], :shape[2]] = img return batched def flip_rotmat(pose_rotmat): """Flip function. Flip rotmat. """ rot_mats = pose_rotmat.reshape(-1, 9).clone() rot_mats[:, [1, 2, 3, 6]] *= -1 return rot_mats.view_as(pose_rotmat) def find_joint_global_rotation(kin_chain, root_pose, body_pose): """Computes the absolute rotation of a joint from the kinematic chain.""" # Create a single vector with all the poses parents_pose = torch.cat([root_pose, body_pose], dim=1)[:, kin_chain] output_pose = parents_pose[:, 0] for idx in range(1, parents_pose.shape[1]): output_pose = torch.bmm(parents_pose[:, idx], output_pose) return output_pose class CropSampler(): """This function crops the HD images using bilinear interpolation.""" def __init__(self, crop_size: int = 256) -> None: """Uses bilinear sampling to extract square crops. This module expects a high resolution image as input and a bounding box, described by its' center and size. It then proceeds to extract a sub-image using the provided information through bilinear interpolation. Parameters ---------- crop_size: int The desired size for the crop. """ super(CropSampler, self).__init__() self.crop_size = crop_size x = torch.arange(0, crop_size, dtype=torch.float32) / (crop_size - 1) grid_y, grid_x = torch.meshgrid(x, x) points = torch.stack([grid_y.flatten(), grid_x.flatten()], axis=1) self.grid = points.unsqueeze(dim=0) def _sample_padded(self, full_imgs, sampling_grid): """""" # Get the sub-images using bilinear interpolation return F.grid_sample(full_imgs, sampling_grid, align_corners=True) def __call__(self, full_imgs, center, bbox_size): """Crops the HD images using the provided bounding boxes. Parameters ---------- full_imgs: ImageList An image list structure with the full resolution images center: torch.Tensor A Bx2 tensor that contains the coordinates of the center of the bounding box that will be cropped from the original image bbox_size: torch.Tensor A size B tensor that contains the size of the corp Returns ------- cropped_images: torch.Tensoror The images cropped from the high resolution input sampling_grid: torch.Tensor The grid used to sample the crops """ batch_size, _, H, W = full_imgs.shape self.grid = self.grid.to(device=full_imgs.device) transforms = torch.eye(3, dtype=full_imgs.dtype, device=full_imgs.device).reshape( 1, 3, 3).expand(batch_size, -1, -1).contiguous() hd_to_crop = torch.eye(3, dtype=full_imgs.dtype, device=full_imgs.device).reshape( 1, 3, 3).expand(batch_size, -1, -1).contiguous() # Create the transformation that maps crop pixels to image coordinates, # i.e. pixel (0, 0) from the crop_size x crop_size grid gets mapped to # the top left of the bounding box, pixel # (crop_size - 1, crop_size - 1) to the bottom right corner of the # bounding box transforms[:, 0, 0] = bbox_size # / (self.crop_size - 1) transforms[:, 1, 1] = bbox_size # / (self.crop_size - 1) transforms[:, 0, 2] = center[:, 0] - bbox_size * 0.5 transforms[:, 1, 2] = center[:, 1] - bbox_size * 0.5 hd_to_crop[:, 0, 0] = 2 * (self.crop_size - 1) / bbox_size hd_to_crop[:, 1, 1] = 2 * (self.crop_size - 1) / bbox_size hd_to_crop[:, 0, 2] = -(center[:, 0] - bbox_size * 0.5) * hd_to_crop[:, 0, 0] - 1 hd_to_crop[:, 1, 2] = -(center[:, 1] - bbox_size * 0.5) * hd_to_crop[:, 1, 1] - 1 size_bbox_sizer = torch.eye(3, dtype=full_imgs.dtype, device=full_imgs.device).reshape( 1, 3, 3).expand(batch_size, -1, -1).contiguous() # Normalize the coordinates to [-1, 1] for the grid_sample function size_bbox_sizer[:, 0, 0] = 2.0 / (W - 1) size_bbox_sizer[:, 1, 1] = 2.0 / (H - 1) size_bbox_sizer[:, :2, 2] = -1 # full_transform = transforms full_transform = torch.bmm(size_bbox_sizer, transforms) batch_grid = self.grid.expand(batch_size, -1, -1) # Convert the grid to image coordinates using the transformations above sampling_grid = ( torch.bmm(full_transform[:, :2, :2], batch_grid.transpose(1, 2)) + full_transform[:, :2, [2]]).transpose(1, 2) sampling_grid = sampling_grid.reshape(-1, self.crop_size, self.crop_size, 2).transpose(1, 2) out_images = self._sample_padded(full_imgs, sampling_grid) return { 'images': out_images, 'sampling_grid': sampling_grid.reshape(batch_size, -1, 2), 'transform': transforms, 'hd_to_crop': hd_to_crop, } class SMPLXHandCropFunc(): """This function crop hand image from the original image. Use the output keypoints predicted by the body model to locate the hand position. """ def __init__(self, model_head, body_model, convention='smplx', img_res=256, scale_factor=2.0, crop_size=224, condition_hand_wrist_pose=True, condition_hand_shape=False, condition_hand_finger_pose=True): self.model_head = model_head self.body_model = body_model self.img_res = img_res self.convention = convention self.left_hand_idxs = get_keypoint_idxs_by_part( 'left_hand', self.convention) left_wrist_idx = get_keypoint_idx('left_wrist', self.convention) self.left_hand_idxs.append(left_wrist_idx) self.left_wrist_kin_chain = find_joint_kin_chain( left_wrist_idx, self.body_model.parents) self.right_hand_idxs = get_keypoint_idxs_by_part( 'right_hand', self.convention) right_wrist_idx = get_keypoint_idx('right_wrist', self.convention) self.right_hand_idxs.append(right_wrist_idx) self.right_wrist_kin_chain = find_joint_kin_chain( right_wrist_idx, self.body_model.parents) self.scale_factor = scale_factor self.hand_cropper = CropSampler(crop_size) self.condition_hand_wrist_pose = condition_hand_wrist_pose self.condition_hand_shape = condition_hand_shape self.condition_hand_finger_pose = condition_hand_finger_pose def build_hand_mean(self, global_orient, body_pose, betas, left_hand_pose, raw_right_hand_pose, batch_size): """Builds the initial point for the iterative regressor of the hand.""" hand_mean = [] # if self.condition_hand_on_body: # Convert the absolute pose to the latent representation if self.condition_hand_wrist_pose: # Compute the absolute pose of the right wrist right_wrist_pose_abs = find_joint_global_rotation( self.right_wrist_kin_chain, global_orient, body_pose) right_wrist_pose = right_wrist_pose_abs[:, :3, :2].contiguous( ).reshape(batch_size, -1) # Compute the absolute rotation for the left wrist left_wrist_pose_abs = find_joint_global_rotation( self.left_wrist_kin_chain, global_orient, body_pose) # Flip the left wrist to the right left_to_right_wrist_pose = flip_rotmat(left_wrist_pose_abs) # Convert to the latent representation left_to_right_wrist_pose = left_to_right_wrist_pose[:, :3, : 2].contiguous( ).reshape( batch_size, -1) else: right_wrist_pose = self.model_head.get_mean('global_orient', batch_size=batch_size) left_to_right_wrist_pose = self.model_head.get_mean( 'global_orient', batch_size=batch_size) # Convert the pose of the left hand to the right hand and project # it to the encoder space left_to_right_hand_pose = flip_rotmat( left_hand_pose)[:, :, :3, :2].contiguous().reshape(batch_size, -1) right_hand_pose = raw_right_hand_pose.reshape(batch_size, -1) camera_mean = self.model_head.get_mean('camera', batch_size=batch_size) shape_condition = (betas if self.condition_hand_shape else self.model_head.get_mean('shape', batch_size=batch_size)) right_finger_pose_condition = ( right_hand_pose if self.condition_hand_finger_pose else self.model_head.get_mean('right_hand_pose', batch_size=batch_size)) right_hand_mean = torch.cat([ right_wrist_pose, right_finger_pose_condition, shape_condition, camera_mean ], dim=1) left_finger_pose_condition = ( left_to_right_hand_pose if self.condition_hand_finger_pose else self.model_head.get_mean('right_hand_pose', batch_size=batch_size)) # Should be Bx31 left_hand_mean = torch.cat([ left_to_right_wrist_pose, left_finger_pose_condition, shape_condition, camera_mean ], dim=1) hand_mean += [right_hand_mean, left_hand_mean] hand_mean = torch.cat(hand_mean, dim=0) return hand_mean def __call__(self, body_predictions, img_metas): """Function Args: body_predictions (dict): The prediction from body model. img_metas (dict): Information of the input images. Returns: all_hand_imgs (torch.tensor): Cropped hand images. hand_mean (torch.tensor): Mean value of hand params. crop_info (dict): Hand crop transforms. """ pred_param = body_predictions['pred_param'] pred_cam = body_predictions['pred_cam'] pred_raw = body_predictions['pred_raw'] pred_output = self.body_model(**pred_param) pred_keypoints3d = pred_output['joints'] pred_keypoints2d = weak_perspective_projection( pred_keypoints3d, scale=pred_cam[:, 0], translation=pred_cam[:, 1:3]) # concat ori_img full_images = [] for img_meta in img_metas: full_images.append(img_meta['ori_img'].to(device=pred_cam.device)) full_imgs = concat_images(full_images) # left hand left_hand_joints = (pred_keypoints2d[:, self.left_hand_idxs] * 0.5 + 0.5) * (self.img_res - 1) left_hand_points_to_crop = get_crop_info(left_hand_joints, img_metas, self.scale_factor, self.img_res) left_hand_center = left_hand_points_to_crop['center'] left_hand_orig_bbox_size = left_hand_points_to_crop['orig_bbox_size'] left_hand_inv_crop_transforms = left_hand_points_to_crop[ 'inv_crop_transforms'] left_hand_cropper_out = self.hand_cropper(full_imgs, left_hand_center, left_hand_orig_bbox_size) left_hand_crops = left_hand_cropper_out['images'] # left_hand_points = left_hand_cropper_out['sampling_grid'] left_hand_crop_transform = left_hand_cropper_out['transform'] # right hand right_hand_joints = (pred_keypoints2d[:, self.right_hand_idxs] * 0.5 + 0.5) * (self.img_res - 1) right_hand_points_to_crop = get_crop_info(right_hand_joints, img_metas, self.scale_factor, self.img_res) right_hand_center = right_hand_points_to_crop['center'] right_hand_orig_bbox_size = right_hand_points_to_crop['orig_bbox_size'] # right_hand_inv_crop_transforms = right_hand_points_to_crop[ # 'inv_crop_transforms'] right_hand_cropper_out = self.hand_cropper(full_imgs, right_hand_center, right_hand_orig_bbox_size) right_hand_crops = right_hand_cropper_out['images'] # right_hand_points = right_hand_cropper_out['sampling_grid'] right_hand_crop_transform = right_hand_cropper_out['transform'] # concat all_hand_imgs = [] all_hand_imgs.append(right_hand_crops) all_hand_imgs.append(torch.flip(left_hand_crops, dims=(-1, ))) # [right_hand , left hand] all_hand_imgs = torch.cat(all_hand_imgs, dim=0) hand_mean = self.build_hand_mean(pred_param['global_orient'], pred_param['body_pose'], pred_param['betas'], pred_param['left_hand_pose'], pred_raw['raw_right_hand_pose'], batch_size=full_imgs.shape[0]) crop_info = dict( hand_inv_crop_transforms=left_hand_inv_crop_transforms, left_hand_crop_transform=left_hand_crop_transform, right_hand_crop_transform=right_hand_crop_transform) return all_hand_imgs, hand_mean, crop_info class SMPLXFaceCropFunc(): """This function crop face image from the original image. Use the output keypoints predicted by the facce model to locate the face position. """ def __init__(self, model_head, body_model, convention='smplx', img_res=256, scale_factor=2.0, crop_size=256, num_betas=10, num_expression_coeffs=10, condition_face_neck_pose=False, condition_face_jaw_pose=True, condition_face_shape=False, condition_face_expression=True): self.model_head = model_head self.body_model = body_model self.img_res = img_res self.convention = convention self.num_betas = num_betas self.num_expression_coeffs = num_expression_coeffs self.face_idx = get_keypoint_idxs_by_part('head', self.convention) neck_idx = get_keypoint_idx('neck', self.convention) self.neck_kin_chain = find_joint_kin_chain(neck_idx, self.body_model.parents) self.condition_face_neck_pose = condition_face_neck_pose self.condition_face_jaw_pose = condition_face_jaw_pose self.condition_face_shape = condition_face_shape self.condition_face_expression = condition_face_expression self.scale_factor = scale_factor self.face_cropper = CropSampler(crop_size) def build_face_mean(self, global_orient, body_pose, betas, raw_jaw_pose, expression, batch_size): """Builds the initial point for the iterative regressor of the face.""" face_mean = [] # Compute the absolute pose of the right wrist neck_pose_abs = find_joint_global_rotation(self.neck_kin_chain, global_orient, body_pose) # Convert the absolute neck pose to offsets neck_pose = neck_pose_abs[:, :3, :2].contiguous().reshape( batch_size, -1) camera_mean = self.model_head.get_mean('camera', batch_size=batch_size) neck_pose_condition = (neck_pose if self.condition_face_neck_pose else self.model_head.get_mean('global_orient', batch_size=batch_size)) jaw_pose_condition = (raw_jaw_pose.reshape(batch_size, -1) if self.condition_face_jaw_pose else self.model_head.get_mean('jaw_pose', batch_size=batch_size)) face_num_betas = self.model_head.get_num_betas() shape_padding_size = face_num_betas - self.num_betas betas_condition = ( F.pad(betas.reshape(batch_size, -1), (0, shape_padding_size)) if self.condition_face_shape else self.model_head.get_mean('shape', batch_size=batch_size)) face_num_expression_coeffs = self.model_head.get_num_expression_coeffs( ) expr_padding_size = face_num_expression_coeffs \ - self.num_expression_coeffs expression_condition = ( F.pad(expression.reshape(batch_size, -1), (0, expr_padding_size)) if self.condition_face_expression else self.model_head.get_mean('expression', batch_size=batch_size)) # Should be Bx(Head pose params) face_mean.append( torch.cat([ neck_pose_condition, jaw_pose_condition, betas_condition, expression_condition, camera_mean.reshape(batch_size, -1), ], dim=1)) face_mean = torch.cat(face_mean, dim=0) return face_mean def __call__(self, body_predictions, img_metas): """Function Args: body_predictions (dict): The prediction from body model. img_metas (dict): Information of the input images. Returns: all_face_imgs (torch.tensor): Cropped face images. face_mean (torch.tensor): Mean value of face params. crop_info (dict): Face crop transforms. """ pred_param = body_predictions['pred_param'] pred_cam = body_predictions['pred_cam'] pred_raw = body_predictions['pred_raw'] pred_output = self.body_model(**pred_param) pred_keypoints3d = pred_output['joints'] pred_keypoints2d = weak_perspective_projection( pred_keypoints3d, scale=pred_cam[:, 0], translation=pred_cam[:, 1:3]) # concat ori_img full_images = [] for img_meta in img_metas: full_images.append(img_meta['ori_img'].to(device=pred_cam.device)) full_imgs = concat_images(full_images) face_joints = (pred_keypoints2d[:, self.face_idx] * 0.5 + 0.5) * (self.img_res - 1) face_points_to_crop = get_crop_info(face_joints, img_metas, self.scale_factor, self.img_res) face_center = face_points_to_crop['center'] face_orig_bbox_size = face_points_to_crop['orig_bbox_size'] face_inv_crop_transforms = face_points_to_crop['inv_crop_transforms'] face_cropper_out = self.face_cropper(full_imgs, face_center, face_orig_bbox_size) face_crops = face_cropper_out['images'] # face_points = face_cropper_out['sampling_grid'] face_crop_transform = face_cropper_out['transform'] all_face_imgs = [face_crops] all_face_imgs = torch.cat(all_face_imgs, dim=0) face_mean = self.build_face_mean(pred_param['global_orient'], pred_param['body_pose'], pred_param['betas'], pred_raw['raw_jaw_pose'], pred_param['expression'], batch_size=full_imgs.shape[0]) crop_info = dict(face_inv_crop_transforms=face_inv_crop_transforms, face_crop_transform=face_crop_transform) return all_face_imgs, face_mean, crop_info