ttxskk
update
d7e58f0
raw
history blame
7.14 kB
import numpy as np
import torch
from smplx import FLAME as _FLAME
from smplx import FLAMELayer as _FLAMELayer
from detrsmpl.core.conventions.keypoints_mapping import (
convert_kps,
get_keypoint_num,
)
class FLAME(_FLAME):
    """Extension of the official FLAME implementation.

    Wraps ``smplx.FLAME`` so that the regressed joints are converted from
    the ``keypoint_src`` convention into the ``keypoint_dst`` convention, and
    so that ``forward`` returns a plain ``dict`` instead of the parent's
    output object.
    """
    head_pose_keys = {'global_orient', 'jaw_pose'}
    full_pose_keys = {
        'global_orient', 'neck_pose', 'jaw_pose', 'leye_pose', 'reye_pose'
    }
    NUM_VERTS = 5023
    NUM_FACES = 9976

    def __init__(self,
                 *args,
                 keypoint_src: str = 'flame',
                 keypoint_dst: str = 'human_data',
                 keypoint_approximate: bool = False,
                 **kwargs) -> None:
        """
        Args:
            *args: extra arguments for FLAME initialization.
            keypoint_src: source convention of keypoints. This convention
                is used for keypoints obtained from joint regressors.
                Keypoints then undergo conversion into keypoint_dst
                convention.
            keypoint_dst: destination convention of keypoints. This convention
                is used for keypoints in the output.
            keypoint_approximate: whether to use approximate matching in
                convention conversion for keypoints.
            **kwargs: extra keyword arguments for FLAME initialization.
        Returns:
            None
        """
        super(FLAME, self).__init__(*args, **kwargs)
        self.keypoint_src = keypoint_src
        self.keypoint_dst = keypoint_dst
        self.keypoint_approximate = keypoint_approximate
        self.num_verts = self.get_num_verts()
        self.num_faces = self.get_num_faces()
        # Number of keypoints in the *destination* convention, i.e. the
        # number of joints that appear in the output of ``forward``.
        self.num_joints = get_keypoint_num(convention=self.keypoint_dst)

    def forward(self,
                *args,
                return_verts: bool = True,
                return_full_pose: bool = False,
                **kwargs) -> dict:
        """Forward function.

        Args:
            *args: extra arguments for FLAME
            return_verts: whether to return vertices
            return_full_pose: whether to return full pose parameters
            **kwargs: extra arguments for FLAME
        Returns:
            output: dict with keys ``global_orient``, ``neck_pose``,
                ``jaw_pose``, ``joints`` (converted to ``keypoint_dst``),
                ``joint_mask`` (shape (batch_size, num_joints), one row per
                sample), ``keypoints`` (``joints`` with the mask appended as
                an extra last-dim channel), ``betas`` and ``expression``;
                plus ``vertices`` if ``return_verts`` and ``full_pose`` if
                ``return_full_pose``.
        """
        # Pass the flags through: the parent only fills ``full_pose`` when
        # ``return_full_pose`` is set, so not forwarding it would make
        # ``output['full_pose']`` always None.
        flame_output = super(FLAME, self).forward(
            *args,
            return_verts=return_verts,
            return_full_pose=return_full_pose,
            **kwargs)
        joints, joint_mask = convert_kps(flame_output.joints,
                                         src=self.keypoint_src,
                                         dst=self.keypoint_dst,
                                         approximate=self.keypoint_approximate)
        if isinstance(joint_mask, np.ndarray):
            joint_mask = torch.tensor(joint_mask,
                                      dtype=torch.uint8,
                                      device=joints.device)
        batch_size = joints.shape[0]
        # NOTE(review): assumes convert_kps returns a per-joint mask without
        # a batch dimension; it is broadcast to every sample here.
        joint_mask = joint_mask.reshape(1, -1).expand(batch_size, -1)
        # Cast the mask to the joints' dtype so the concatenation does not
        # rely on torch.cat's mixed-dtype promotion.
        keypoints = torch.cat(
            [joints, joint_mask[:, :, None].to(joints.dtype)], dim=-1)
        output = dict(global_orient=flame_output.global_orient,
                      neck_pose=flame_output.neck_pose,
                      jaw_pose=flame_output.jaw_pose,
                      joints=joints,
                      joint_mask=joint_mask,
                      keypoints=keypoints,
                      betas=flame_output.betas,
                      expression=flame_output.expression)
        if return_verts:
            output['vertices'] = flame_output.vertices
        if return_full_pose:
            output['full_pose'] = flame_output.full_pose
        return output
class FLAMELayer(_FLAMELayer):
    """Extension of the official FLAMELayer implementation.

    Wraps ``smplx.FLAMELayer`` so that the regressed joints are converted
    from the ``keypoint_src`` convention into the ``keypoint_dst``
    convention, and so that ``forward`` returns a plain ``dict`` instead of
    the parent's output object.
    """
    head_pose_keys = {'global_orient', 'jaw_pose'}
    full_pose_keys = {
        'global_orient', 'neck_pose', 'jaw_pose', 'leye_pose', 'reye_pose'
    }
    NUM_VERTS = 5023
    NUM_FACES = 9976

    def __init__(self,
                 *args,
                 keypoint_src: str = 'flame',
                 keypoint_dst: str = 'human_data',
                 keypoint_approximate: bool = False,
                 **kwargs) -> None:
        """
        Args:
            *args: extra arguments for FLAME initialization.
            keypoint_src: source convention of keypoints. This convention
                is used for keypoints obtained from joint regressors.
                Keypoints then undergo conversion into keypoint_dst
                convention.
            keypoint_dst: destination convention of keypoints. This convention
                is used for keypoints in the output.
            keypoint_approximate: whether to use approximate matching in
                convention conversion for keypoints.
            **kwargs: extra keyword arguments for FLAME initialization.
        Returns:
            None
        """
        super(FLAMELayer, self).__init__(*args, **kwargs)
        self.keypoint_src = keypoint_src
        self.keypoint_dst = keypoint_dst
        self.keypoint_approximate = keypoint_approximate
        self.num_verts = self.get_num_verts()
        self.num_faces = self.get_num_faces()
        # Number of keypoints in the *destination* convention, i.e. the
        # number of joints that appear in the output of ``forward``.
        self.num_joints = get_keypoint_num(convention=self.keypoint_dst)

    def forward(self,
                *args,
                return_verts: bool = True,
                return_full_pose: bool = False,
                **kwargs) -> dict:
        """Forward function.

        Args:
            *args: extra arguments for FLAME
            return_verts: whether to return vertices
            return_full_pose: whether to return full pose parameters
            **kwargs: extra arguments for FLAME
        Returns:
            output: dict with keys ``global_orient``, ``neck_pose``,
                ``jaw_pose``, ``joints`` (converted to ``keypoint_dst``),
                ``joint_mask`` (shape (batch_size, num_joints), one row per
                sample), ``keypoints`` (``joints`` with the mask appended as
                an extra last-dim channel), ``betas`` and ``expression``;
                plus ``vertices`` if ``return_verts`` and ``full_pose`` if
                ``return_full_pose``.
        """
        # Pass the flags through: the parent only fills ``full_pose`` when
        # ``return_full_pose`` is set, so not forwarding it would make
        # ``output['full_pose']`` always None.
        flame_output = super(FLAMELayer, self).forward(
            *args,
            return_verts=return_verts,
            return_full_pose=return_full_pose,
            **kwargs)
        joints, joint_mask = convert_kps(flame_output.joints,
                                         src=self.keypoint_src,
                                         dst=self.keypoint_dst,
                                         approximate=self.keypoint_approximate)
        if isinstance(joint_mask, np.ndarray):
            joint_mask = torch.tensor(joint_mask,
                                      dtype=torch.uint8,
                                      device=joints.device)
        batch_size = joints.shape[0]
        # NOTE(review): assumes convert_kps returns a per-joint mask without
        # a batch dimension; it is broadcast to every sample here.
        joint_mask = joint_mask.reshape(1, -1).expand(batch_size, -1)
        # Cast the mask to the joints' dtype so the concatenation does not
        # rely on torch.cat's mixed-dtype promotion.
        keypoints = torch.cat(
            [joints, joint_mask[:, :, None].to(joints.dtype)], dim=-1)
        output = dict(global_orient=flame_output.global_orient,
                      neck_pose=flame_output.neck_pose,
                      jaw_pose=flame_output.jaw_pose,
                      joints=joints,
                      joint_mask=joint_mask,
                      keypoints=keypoints,
                      betas=flame_output.betas,
                      expression=flame_output.expression)
        if return_verts:
            output['vertices'] = flame_output.vertices
        if return_full_pose:
            output['full_pose'] = flame_output.full_pose
        return output