|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional |
|
from torch import Tensor |
|
import smplx |
|
|
|
from .base import Datastruct, dataclass, Transform |
|
|
|
from .rots2rfeats import Rots2Rfeats |
|
from .rots2joints import Rots2Joints |
|
from .joints2jfeats import Joints2Jfeats |
|
|
|
|
|
class SMPLTransform(Transform):
    """Transform that keeps three converters and builds ``SMPLDatastruct``s.

    The converters are stored as plain attributes and handed to every
    datastruct created through :meth:`Datastruct`.
    """

    def __init__(self, rots2rfeats: Rots2Rfeats,
                 rots2joints: Rots2Joints,
                 joints2jfeats: Joints2Jfeats,
                 **kwargs):
        """Store the converters; extra keyword arguments are ignored."""
        self.rots2rfeats, self.rots2joints, self.joints2jfeats = (
            rots2rfeats, rots2joints, joints2jfeats)

    def Datastruct(self, **kwargs):
        """Return an ``SMPLDatastruct`` bound to this transform.

        ``kwargs`` are forwarded unchanged to the ``SMPLDatastruct``
        constructor, next to the three converters and ``transforms=self``.
        """
        converters = {
            "_rots2rfeats": self.rots2rfeats,
            "_rots2joints": self.rots2joints,
            "_joints2jfeats": self.joints2jfeats,
        }
        return SMPLDatastruct(transforms=self, **converters, **kwargs)

    def __repr__(self):
        """Return the fixed representation string of this transform."""
        return "SMPLTransform()"
|
|
|
|
|
class RotIdentityTransform(Transform):
    """Transform without any state that builds ``RotTransDatastruct``s."""

    def __init__(self, **kwargs):
        """Accept and ignore any keyword arguments; nothing is stored."""
        pass

    def Datastruct(self, **kwargs):
        """Return a ``RotTransDatastruct`` built from ``kwargs``."""
        datastruct = RotTransDatastruct(**kwargs)
        return datastruct

    def __repr__(self):
        """Return the fixed representation string of this transform."""
        return "RotIdentityTransform()"
|
|
|
|
|
@dataclass
class RotTransDatastruct(Datastruct):
    """Datastruct with two tensor fields, ``rots`` and ``trans``."""

    rots: Tensor
    trans: Tensor

    # Default transform: a single instance created when the class is
    # defined and shared by every datastruct that does not pass its own.
    transforms: RotIdentityTransform = RotIdentityTransform()

    def __post_init__(self):
        """Register the names of the data-carrying fields."""
        keys = ["rots", "trans"]
        self.datakeys = keys

    def __len__(self):
        """Return ``len`` of the ``rots`` field."""
        size = len(self.rots)
        return size
|
|
|
|
|
@dataclass
class SMPLDatastruct(Datastruct):
    """Datastruct that derives its representations on demand and caches them.

    Each public property (``rots``, ``rfeats``, ``joints``, ``jfeats``,
    ``vertices``) returns the matching trailing-underscore field when it is
    already set. Otherwise it computes the value with one of the stored
    converters, saves it in that field and returns it.
    """

    transforms: SMPLTransform
    _rots2rfeats: Rots2Rfeats
    _rots2joints: Rots2Joints
    _joints2jfeats: Joints2Jfeats

    features: Optional[Tensor] = None
    rots_: Optional[RotTransDatastruct] = None
    rfeats_: Optional[Tensor] = None
    joints_: Optional[Tensor] = None
    jfeats_: Optional[Tensor] = None
    vertices_: Optional[Tensor] = None

    def __post_init__(self):
        """Register the data keys and seed ``rfeats_`` from ``features``."""
        self.datakeys = [
            'features',
            'rots_',
            'rfeats_',
            'joints_',
            'jfeats_',
            'vertices_',
        ]

        # ``features`` only fills ``rfeats_`` when the latter was not given.
        if self.rfeats_ is None and self.features is not None:
            self.rfeats_ = self.features

    @property
    def rots(self):
        """Return ``rots_``, computing it from ``rfeats`` when missing."""
        if self.rots_ is None:
            # Without cached rotations, the rotation features are required.
            assert self.rfeats_ is not None

            rfeats = self.rfeats
            self._rots2rfeats.to(rfeats.device)
            self.rots_ = self._rots2rfeats.inverse(rfeats)
        return self.rots_

    @property
    def rfeats(self):
        """Return ``rfeats_``, computing it from ``rots`` when missing."""
        if self.rfeats_ is None:
            # Without cached rotation features, the rotations are required.
            assert self.rots_ is not None

            rots = self.rots
            self._rots2rfeats.to(rots.device)
            self.rfeats_ = self._rots2rfeats(rots)
        return self.rfeats_

    @property
    def joints(self):
        """Return ``joints_``, computing it from ``rots`` when missing."""
        if self.joints_ is None:
            rots = self.rots
            self._rots2joints.to(rots.device)
            self.joints_ = self._rots2joints(rots)
        return self.joints_

    @property
    def jfeats(self):
        """Return ``jfeats_``, computing it from ``joints`` when missing."""
        if self.jfeats_ is None:
            joints = self.joints
            self._joints2jfeats.to(joints.device)
            self.jfeats_ = self._joints2jfeats(joints)
        return self.jfeats_

    @property
    def vertices(self):
        """Return ``vertices_``, computing it from ``rots`` when missing.

        Uses the same converter as ``joints``, called with
        ``jointstype="vertices"``.
        """
        if self.vertices_ is None:
            rots = self.rots
            self._rots2joints.to(rots.device)
            self.vertices_ = self._rots2joints(rots, jointstype="vertices")
        return self.vertices_

    def __len__(self):
        """Return ``len`` of ``rfeats`` (computed first if necessary)."""
        return len(self.rfeats)
|
|
|
|
|
def get_body_model(model_type, gender, batch_size, device='cpu', ext='pkl'):
    """Create an SMPL-family body model with ``smplx.create``.

    The model file is looked up at
    ``data/smpl_models/{model_type}/{MODEL_TYPE}_{GENDER}.{ext}``.

    Args:
        model_type: Model family such as ``smpl``, ``smplx`` or ``smplh``
            (see the smplx documentation for the supported values).
        gender: ``male``, ``female`` or ``neutral``. A non-string value is
            converted with ``str(gender.astype(str))``, so it must provide
            ``astype`` (presumably a NumPy scalar/array -- confirm against
            callers).
        batch_size: A positive integer, forwarded to ``smplx.create``.
        device: Target device. Any CUDA device (``'cuda'``, ``'cuda:1'``,
            a CUDA ``torch.device`` ...) moves the model there; for any
            other value the model is returned as created.
        ext: File extension of the model file. Ignored for the ``neutral``
            gender, which always uses ``npz``.

    Returns:
        The body model returned by ``smplx.create``.
    """
    mtype = model_type.upper()

    # Normalise non-string genders first so the comparison and ``upper()``
    # below always operate on a plain ``str``.
    if not isinstance(gender, str):
        gender = str(gender.astype(str))
    # The neutral model is loaded from an .npz file regardless of ``ext``.
    if gender == 'neutral':
        ext = 'npz'
    gender = gender.upper()

    body_model_path = f'data/smpl_models/{model_type}/{mtype}_{gender}.{ext}'

    # Pass the requested model type; the previous code passed the builtin
    # ``type`` here by mistake.
    body_model = smplx.create(body_model_path, model_type=model_type,
                              gender=gender, ext=ext,
                              use_pca=False,
                              num_pca_comps=12,
                              create_global_orient=True,
                              create_body_pose=True,
                              create_betas=True,
                              create_left_hand_pose=True,
                              create_right_hand_pose=True,
                              create_expression=True,
                              create_jaw_pose=True,
                              create_leye_pose=True,
                              create_reye_pose=True,
                              create_transl=True,
                              batch_size=batch_size)

    # Honour every CUDA device spelling, not only the exact string 'cuda'.
    if str(device).startswith('cuda'):
        return body_model.to(device)
    return body_model
|
|
|
|