File size: 5,831 Bytes
4409449 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2020 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]
from typing import Optional
from torch import Tensor
import smplx
from .base import Datastruct, dataclass, Transform
from .rots2rfeats import Rots2Rfeats
from .rots2joints import Rots2Joints
from .joints2jfeats import Joints2Jfeats
class SMPLTransform(Transform):
    """Transform that bundles the SMPL representation converters.

    The three converters are stored on the instance and handed to every
    ``SMPLDatastruct`` this transform builds, so the datastruct can derive
    rfeats / joints / jfeats / vertices lazily from whatever it was given.
    """

    def __init__(self, rots2rfeats: Rots2Rfeats,
                 rots2joints: Rots2Joints,
                 joints2jfeats: Joints2Jfeats,
                 **kwargs):
        # Any extra keyword arguments are accepted and ignored.
        self.rots2rfeats = rots2rfeats
        self.rots2joints = rots2joints
        self.joints2jfeats = joints2jfeats

    def Datastruct(self, **kwargs):
        """Build an ``SMPLDatastruct`` wired to this transform's converters.

        ``kwargs`` are forwarded as the datastruct's data fields
        (e.g. ``features``, ``rots_``, ``joints_``).
        """
        converters = dict(_rots2rfeats=self.rots2rfeats,
                          _rots2joints=self.rots2joints,
                          _joints2jfeats=self.joints2jfeats)
        return SMPLDatastruct(transforms=self, **converters, **kwargs)

    def __repr__(self):
        return "SMPLTransform()"
class RotIdentityTransform(Transform):
    """Pass-through transform: its datastruct just stores rots and trans."""

    def __init__(self, **kwargs):
        # Nothing to configure; keyword arguments are accepted and ignored.
        pass

    def Datastruct(self, **kwargs):
        """Wrap the given fields in a ``RotTransDatastruct``."""
        return RotTransDatastruct(**kwargs)

    def __repr__(self):
        return "RotIdentityTransform()"
@dataclass
class RotTransDatastruct(Datastruct):
    """Datastruct holding a rotation tensor and a translation tensor."""

    rots: Tensor
    trans: Tensor
    # A single default instance is shared by all datastructs; this is safe
    # because RotIdentityTransform.__init__ stores no per-instance state.
    transforms: RotIdentityTransform = RotIdentityTransform()

    def __post_init__(self):
        # Names of the data fields; presumably consumed by the Datastruct
        # base class (iteration / device moves) -- confirm in .base.
        keys = ["rots", "trans"]
        self.datakeys = keys

    def __len__(self):
        # Length is taken from the rotations tensor.
        rotations = self.rots
        return len(rotations)
@dataclass
class SMPLDatastruct(Datastruct):
    """SMPL motion data with lazily derived, cached representations.

    Any of the ``*_`` fields may be supplied up front. Each public property
    (``rots``, ``rfeats``, ``joints``, ``jfeats``, ``vertices``) returns the
    cached value when present. Otherwise it computes the value with the
    matching converter, stores it in the ``*_`` field, and returns it.
    ``features`` seeds ``rfeats_`` when ``rfeats_`` is not given.
    """

    transforms: SMPLTransform
    _rots2rfeats: Rots2Rfeats
    _rots2joints: Rots2Joints
    _joints2jfeats: Joints2Jfeats

    features: Optional[Tensor] = None
    rots_: Optional[RotTransDatastruct] = None
    rfeats_: Optional[Tensor] = None
    joints_: Optional[Tensor] = None
    jfeats_: Optional[Tensor] = None
    vertices_: Optional[Tensor] = None

    def __post_init__(self):
        self.datakeys = ['features', 'rots_', 'rfeats_',
                         'joints_', 'jfeats_', 'vertices_']
        # ``features`` is the starting point: treat it as the rfeats unless
        # rfeats_ was provided explicitly.
        if self.rfeats_ is None and self.features is not None:
            self.rfeats_ = self.features

    @property
    def rots(self):
        """Rotations, computed with ``Rots2Rfeats.inverse`` if not cached."""
        if self.rots_ is None:
            # Without cached rots, rfeats is the only possible source.
            assert self.rfeats_ is not None
            source = self.rfeats
            # Put the converter on the same device as its input.
            self._rots2rfeats.to(source.device)
            self.rots_ = self._rots2rfeats.inverse(source)
        return self.rots_

    @property
    def rfeats(self):
        """Rotation features, computed from ``rots`` if not cached."""
        if self.rfeats_ is None:
            # Without cached rfeats, rots must have been provided.
            assert self.rots_ is not None
            # rots_ is a RotTransDatastruct; ``.device`` is assumed to come
            # from the Datastruct base class.
            source = self.rots
            self._rots2rfeats.to(source.device)
            self.rfeats_ = self._rots2rfeats(source)
        return self.rfeats_

    @property
    def joints(self):
        """Joint positions, computed from ``rots`` if not cached."""
        if self.joints_ is None:
            source = self.rots
            self._rots2joints.to(source.device)
            self.joints_ = self._rots2joints(source)
        return self.joints_

    @property
    def jfeats(self):
        """Joint features, computed from ``joints`` if not cached."""
        if self.jfeats_ is None:
            source = self.joints
            self._joints2jfeats.to(source.device)
            self.jfeats_ = self._joints2jfeats(source)
        return self.jfeats_

    @property
    def vertices(self):
        """Mesh vertices, computed by Rots2Joints with jointstype="vertices"."""
        if self.vertices_ is None:
            source = self.rots
            self._rots2joints.to(source.device)
            self.vertices_ = self._rots2joints(source, jointstype="vertices")
        return self.vertices_

    def __len__(self):
        # Length is measured on the rotation features (derived if needed).
        return len(self.rfeats)
def get_body_model(model_type, gender, batch_size, device='cpu', ext='pkl'):
    """Create an SMPL-family body model with ``smplx.create``.

    The model file is loaded from the relative path
    ``data/smpl_models/{model_type}/{MODEL_TYPE}_{GENDER}.{ext}``.

    Args:
        model_type: body model family, e.g. 'smpl', 'smplh' or 'smplx'
            (refer to the smplx tutorial). Used verbatim as the
            sub-directory and upper-cased as the file-name prefix.
        gender: 'male', 'female' or 'neutral'. A non-``str`` value is
            converted with ``str(gender.astype(str))``. This is presumably a
            numpy string array/scalar read from a dataset; confirm with
            callers.
        batch_size: a positive integer, forwarded to ``smplx.create``.
        device: target device for the model. Accepts anything
            ``torch.nn.Module.to`` accepts, e.g. 'cpu', 'cuda', 'cuda:1',
            or a ``torch.device``.
        ext: file extension for gendered models. Ignored for 'neutral',
            which always loads an ``npz`` file.

    Returns:
        The body model returned by ``smplx.create``, moved to ``device``.
    """
    mtype = model_type.upper()

    # Normalise gender to a plain string *before* comparing, so a non-str
    # 'neutral' takes the neutral branch instead of crashing on ``.upper()``.
    if not isinstance(gender, str):
        gender = str(gender.astype(str))
    if gender == 'neutral':
        # Neutral models are always loaded from .npz, whatever ext was given.
        ext = 'npz'
    gender = gender.upper()

    body_model_path = f'data/smpl_models/{model_type}/{mtype}_{gender}.{ext}'
    body_model = smplx.create(body_model_path,
                              # Pass the requested family, not the builtin
                              # ``type``.
                              model_type=model_type,
                              gender=gender, ext=ext,
                              use_pca=False,
                              num_pca_comps=12,
                              create_global_orient=True,
                              create_body_pose=True,
                              create_betas=True,
                              create_left_hand_pose=True,
                              create_right_hand_pose=True,
                              create_expression=True,
                              create_jaw_pose=True,
                              create_leye_pose=True,
                              create_reye_pose=True,
                              create_transl=True,
                              batch_size=batch_size)
    # ``Module.to`` handles 'cpu', 'cuda', 'cuda:N' and torch.device alike.
    # Previously only the exact string 'cuda' moved the model off the CPU.
    return body_model.to(device)
|