Spaces:
Running
on
L40S
Running
on
L40S
# Copyright (c) OpenMMLab. All rights reserved.
import os | |
from typing import Optional | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from detrsmpl.core.conventions.keypoints_mapping import convert_kps | |
from detrsmpl.utils.transforms import ( | |
aa_to_rotmat, | |
make_homegeneous_rotmat_batch, | |
) | |
class STAR(nn.Module):
    """STAR body model implemented as a ``torch.nn.Module``.

    The constructor loads the model data (joint regressor, skinning
    weights, blend shapes, template mesh, faces and kinematic tree) from an
    ``.npz`` file and registers it as buffers. Optional learnable
    parameters (``global_orient``, ``body_pose``, ``betas``, ``transl``)
    are created as member variables and used by :meth:`forward` whenever the
    corresponding argument is not supplied.
    """

    NUM_BODY_JOINTS = 24

    def __init__(self,
                 model_path: str,
                 gender: str = 'neutral',
                 keypoint_src: str = 'star',
                 keypoint_dst: str = 'human_data',
                 keypoint_approximate: bool = False,
                 create_global_orient: bool = True,
                 global_orient: Optional[torch.Tensor] = None,
                 create_body_pose: bool = True,
                 body_pose: Optional[torch.Tensor] = None,
                 num_betas: int = 10,
                 create_betas: bool = True,
                 betas: Optional[torch.Tensor] = None,
                 create_transl: bool = True,
                 transl: Optional[torch.Tensor] = None,
                 batch_size: int = 1,
                 dtype: torch.dtype = torch.float32) -> None:
        """STAR model constructor.

        Args:
            model_path: str
                The path to the folder or to the file where the model
                parameters are stored. When a folder is given, the file
                ``STAR_<GENDER>.npz`` inside it is loaded.
            gender: str, optional
                Which gender to load. One of 'male', 'female', 'neutral'.
            keypoint_src: str
                Source convention of keypoints. This convention is used for
                keypoints obtained from joint regressors. Keypoints then
                undergo conversion into keypoint_dst convention.
            keypoint_dst: destination convention of keypoints. This
                convention is used for keypoints in the output.
            keypoint_approximate: whether to use approximate matching in
                convention conversion for keypoints.
            create_global_orient: bool, optional
                Flag for creating a member variable for the global
                orientation of the body. (default = True)
            global_orient: torch.tensor, optional, Bx3
                The default value for the global orientation variable.
                (default = None)
            create_body_pose: bool, optional
                Flag for creating a member variable for the pose of the
                body. (default = True)
            body_pose: torch.tensor, optional, Bx(3*24)
                The default value for the body pose variable.
                (default = None)
            num_betas: int, optional
                Number of shape components to use
                (default = 10).
            create_betas: bool, optional
                Flag for creating a member variable for the shape space
                (default = True).
            betas: torch.tensor, optional, Bx10
                The default value for the shape member variable.
                (default = None)
            create_transl: bool, optional
                Flag for creating a member variable for the translation
                of the body. (default = True)
            transl: torch.tensor, optional, Bx3
                The default value for the transl variable.
                (default = None)
            batch_size: int, optional
                The batch size used for creating the member variables.
            dtype: torch.dtype, optional
                The data type for the created variables. The model buffers
                loaded from file are always stored as ``torch.float``.

        Raises:
            RuntimeError: If ``gender`` is invalid, ``model_path`` does not
                exist, the file name does not match ``gender``, or
                ``num_betas`` exceeds the number of shape components
                stored in the model file.
        """
        super().__init__()

        if gender not in ['male', 'female', 'neutral']:
            raise RuntimeError('Invalid gender! Should be one of '
                               '[\'male\', \'female\', or \'neutral\']!')
        self.gender = gender

        model_fname = 'STAR_{}.npz'.format(gender.upper())
        if not os.path.exists(model_path):
            raise RuntimeError('Path {} does not exist!'.format(model_path))
        if os.path.isdir(model_path):
            star_path = os.path.join(model_path, model_fname)
        else:
            actual_fname = os.path.basename(model_path)
            if actual_fname != model_fname:
                # Report the file that was actually supplied, not the
                # expected one, so the mismatch is visible to the user.
                raise RuntimeError(
                    f'Model filename ({actual_fname}) and gender '
                    f'({gender}) are incompatible!')
            star_path = model_path

        self.keypoint_src = keypoint_src
        self.keypoint_dst = keypoint_dst
        self.keypoint_approximate = keypoint_approximate
        self.num_betas = num_betas

        # NOTE(review): allow_pickle=True executes pickled payloads found
        # in the archive; only load model files from a trusted source.
        # The context manager closes the underlying file handle once all
        # arrays have been read.
        with np.load(star_path, allow_pickle=True) as star_model:
            J_regressor = star_model['J_regressor']
            skin_weights = star_model['weights']
            posedirs = star_model['posedirs']
            v_template = star_model['v_template']
            shapedirs = star_model['shapedirs']
            faces = star_model['f']
            kintree_table = star_model['kintree_table']

        if num_betas > shapedirs.shape[-1]:
            # Fail early: forward() reshapes shapedirs with num_betas
            # columns, which cannot work with fewer stored components.
            raise RuntimeError(
                f'num_betas ({num_betas}) exceeds the number of shape '
                f'components in the model file ({shapedirs.shape[-1]})!')

        # Model sparse joints regressor, regresses joints location from a
        # mesh
        self.register_buffer('J_regressor',
                             torch.tensor(J_regressor, dtype=torch.float))
        # Model skinning weights
        self.register_buffer('weights',
                             torch.tensor(skin_weights, dtype=torch.float))
        # Model pose corrective blend shapes, flattened to 93 feature
        # columns (the pose feature built in forward()).
        self.register_buffer(
            'posedirs',
            torch.tensor(posedirs.reshape((-1, 93)), dtype=torch.float))
        # Mean shape
        self.register_buffer('v_template',
                             torch.tensor(v_template, dtype=torch.float))
        # Shape corrective blend shapes, truncated to num_betas components
        self.register_buffer(
            'shapedirs',
            torch.tensor(shapedirs[:, :, :num_betas], dtype=torch.float))
        # Mesh triangles
        self.register_buffer('faces',
                             torch.from_numpy(faces.astype(np.int64)))
        self.f = faces
        # Kinematic tree of the model
        self.register_buffer(
            'kintree_table',
            torch.from_numpy(kintree_table.astype(np.int64)))

        # Row 1 of the kinematic tree holds joint ids, row 0 the id of each
        # joint's parent; map ids to column indices to build the parent
        # index of every non-root joint.
        num_columns = self.kintree_table.shape[1]
        id_to_col = {
            self.kintree_table[1, col].item(): col
            for col in range(num_columns)
        }
        self.register_buffer(
            'parent',
            torch.tensor([
                id_to_col[self.kintree_table[0, col].item()]
                for col in range(1, num_columns)
            ],
                         dtype=torch.int64))

        if create_global_orient:
            self.register_parameter(
                'global_orient',
                nn.Parameter(self._init_value(global_orient,
                                              [batch_size, 3], dtype),
                             requires_grad=True))

        if create_body_pose:
            self.register_parameter(
                'body_pose',
                nn.Parameter(self._init_value(
                    body_pose, [batch_size, self.NUM_BODY_JOINTS * 3],
                    dtype),
                             requires_grad=True))

        if create_betas:
            self.register_parameter(
                'betas',
                nn.Parameter(self._init_value(betas,
                                              [batch_size, self.num_betas],
                                              dtype),
                             requires_grad=True))

        if create_transl:
            # A tensor passed as transl has always been cast to dtype
            # (unlike the other parameters); keep that behaviour.
            self.register_parameter(
                'transl',
                nn.Parameter(self._init_value(transl, [batch_size, 3],
                                              dtype,
                                              cast_tensor=True),
                             requires_grad=True))

        self.verts = None
        self.J = None
        self.R = None

    @staticmethod
    def _init_value(value, shape, dtype, cast_tensor=False):
        """Build the initial data of a learnable member variable.

        Args:
            value: ``None``, a tensor, or array-like data.
            shape: Shape of the zero tensor used when ``value`` is None.
            dtype: dtype for zeros / array-like conversion.
            cast_tensor: Whether a tensor ``value`` is cast to ``dtype``.

        Returns:
            torch.Tensor: A detached tensor that does not share the
            autograd history of ``value``.
        """
        if value is None:
            return torch.zeros(shape, dtype=dtype)
        if torch.is_tensor(value):
            copied = value.clone().detach()
            return copied.to(dtype) if cast_tensor else copied
        return torch.tensor(value, dtype=dtype)

    def forward(self,
                global_orient: Optional[torch.Tensor] = None,
                body_pose: Optional[torch.Tensor] = None,
                betas: Optional[torch.Tensor] = None,
                transl: Optional[torch.Tensor] = None,
                return_verts: bool = True,
                return_full_pose: bool = True) -> dict:
        """Forward pass for the STAR model.

        Args:
            global_orient: torch.tensor, optional, shape Bx3
                Global orientation (rotation) of the body. If given, ignore
                the member variable and use it as the global rotation of
                the body. Useful if someone wishes to predicts this with an
                external model. (default=None)
            body_pose: torch.Tensor, shape Bx(J*3)
                Pose parameters for the STAR model. It should be a tensor
                that contains joint rotations in axis-angle format. If
                given, ignore the member variable and use it as the body
                parameters. (default=None)
            betas: torch.Tensor, shape Bx10
                Shape parameters for the STAR model. If given, ignore the
                member variable and use it as shape parameters.
                (default=None)
            transl: torch.Tensor, shape Bx3
                Translation vector for the STAR model. If given, ignore the
                member variable and use it as the translation of the body.
                (default=None)
            return_verts: bool, optional
                Whether to add ``vertices`` to the output. (default=True)
            return_full_pose: bool, optional
                Whether to add ``full_pose`` (``global_orient`` and
                ``body_pose`` concatenated along dim 1) to the output.
                (default=True)

        Returns:
            output: dict with keys ``global_orient``, ``body_pose``,
                ``joints``, ``joint_mask``, ``keypoints``, ``betas`` and,
                depending on the flags, ``vertices`` and ``full_pose``.
        """
        if global_orient is None:
            global_orient = self.global_orient
        if body_pose is None:
            body_pose = self.body_pose
        if betas is None:
            betas = self.betas
        # Translation is optional: fall back to the member variable when it
        # exists, otherwise no translation is applied.
        if transl is None:
            transl = getattr(self, 'transl', None)
        apply_transl = transl is not None

        batch_size = body_pose.shape[0]
        num_joints = self.NUM_BODY_JOINTS
        num_verts = self.v_template.shape[0]

        # Shape blend shapes: template + shapedirs @ betas.
        v_template = self.v_template[None, :]
        shapedirs = self.shapedirs.reshape(
            -1, self.num_betas)[None, :].expand(batch_size, -1, -1)
        beta = betas[:, :, None]
        v_shaped = torch.matmul(shapedirs, beta).view(-1, num_verts,
                                                      3) + v_template
        # Rest-pose joints regressed from the shaped mesh.
        J = torch.einsum('bik,ji->bjk', v_shaped, self.J_regressor)

        # Pose feature: per-joint normalized quaternions without the first
        # joint (first 4 values), followed by the second shape coefficient.
        pose_quat = self.normalize_quaternion(body_pose.reshape(
            -1, 3)).view(batch_size, -1)
        pose_feat = torch.cat((pose_quat[:, 4:], beta[:, 1]), 1)

        R = aa_to_rotmat(body_pose.reshape(-1, 3)).view(
            batch_size, num_joints, 3, 3)

        # Pose corrective blend shapes.
        posedirs = self.posedirs[None, :].expand(batch_size, -1, -1)
        v_posed = v_shaped + torch.matmul(
            posedirs, pose_feat[:, :, None]).view(-1, num_verts, 3)

        # Walk the kinematic chain: every joint transform is its parent's
        # global transform composed with the joint's local transform
        # (rotation plus offset from the parent joint).
        # Convert parents to Python ints once instead of indexing with
        # 0-dim tensors on every iteration.
        parents = self.parent.tolist()
        root_transform = make_homegeneous_rotmat_batch(
            torch.cat((R[:, 0], J[:, 0][:, :, None]), 2))
        results = [root_transform]
        for child, parent_idx in enumerate(parents, start=1):
            offset = J[:, child][:, :, None] - J[:, parent_idx][:, :, None]
            transform_i = make_homegeneous_rotmat_batch(
                torch.cat((R[:, child], offset), 2))
            results.append(torch.matmul(results[parent_idx], transform_i))

        results = torch.stack(results, dim=1)
        posed_joints = results[:, :, :3, 3]

        # NOTE(review): the `weights` buffer is never used in this method,
        # so the vertices returned below are not skinned -- confirm this is
        # intended.
        # NOTE(review): `global_orient` is only echoed in the output and in
        # `full_pose`; it does not influence joints or vertices here.
        if apply_transl:
            # Out-of-place adds: posed_joints is a view into `results`.
            posed_joints = posed_joints + transl[:, None, :]
            v_posed = v_posed + transl[:, None, :]

        joints, joint_mask = convert_kps(posed_joints,
                                         src=self.keypoint_src,
                                         dst=self.keypoint_dst,
                                         approximate=self.keypoint_approximate)
        if torch.is_tensor(joint_mask):
            joint_mask = joint_mask.to(device=joints.device,
                                       dtype=torch.uint8)
        else:
            joint_mask = torch.tensor(joint_mask,
                                      dtype=torch.uint8,
                                      device=joints.device)
        joint_mask = joint_mask.reshape(1, -1).expand(batch_size, -1)

        # NOTE(review): `joints` holds the joints before convention
        # conversion and `betas` keeps a trailing singleton dimension;
        # both are kept as-is for backward compatibility.
        output = dict(global_orient=global_orient,
                      body_pose=body_pose,
                      joints=posed_joints,
                      joint_mask=joint_mask,
                      keypoints=torch.cat([joints, joint_mask[:, :, None]],
                                          dim=-1),
                      betas=beta)

        if return_verts:
            output['vertices'] = v_posed
        if return_full_pose:
            output['full_pose'] = torch.cat([global_orient, body_pose], dim=1)

        return output

    def normalize_quaternion(self, theta: torch.Tensor) -> torch.Tensor:
        """Computes a normalized quaternion ([0,0,0,0] when the body is in
        rest pose) given joint angles.

        Args:
            theta (torch.Tensor): Axis-angle rotations, shape (N, 3),
                one row per joint rotation.

        Returns:
            quat (torch.Tensor): Shape (N, 4), laid out as
                ``[sin(a/2) * axis, cos(a/2) - 1]`` with ``a`` the
                rotation angle.
        """
        # The small offset keeps the norm away from zero so the division
        # below is defined for a zero rotation.
        l2_norm = torch.norm(theta + 1e-8, p=2, dim=1)
        angle = torch.unsqueeze(l2_norm, -1)
        axis = torch.div(theta, angle)
        half_angle = angle * 0.5
        v_cos = torch.cos(half_angle)
        v_sin = torch.sin(half_angle)
        quat = torch.cat([v_sin * axis, v_cos - 1], dim=1)
        return quat