File size: 13,849 Bytes
d7e58f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
# Copyright (c) OpenMMLab. All rights reserved.

import os
from typing import Optional

import numpy as np
import torch
import torch.nn as nn

from detrsmpl.core.conventions.keypoints_mapping import convert_kps
from detrsmpl.utils.transforms import (
    aa_to_rotmat,
    make_homegeneous_rotmat_batch,
)


class STAR(nn.Module):
    """STAR parametric human body model.

    Loads model parameters from a ``STAR_{GENDER}.npz`` file and maps shape
    (``betas``) and pose (axis-angle) parameters to posed joints, converted
    keypoints and (optionally) vertices.
    """

    # Number of joints in the kinematic tree, including the root.
    NUM_BODY_JOINTS = 24

    def __init__(self,
                 model_path: str,
                 gender: str = 'neutral',
                 keypoint_src: str = 'star',
                 keypoint_dst: str = 'human_data',
                 keypoint_approximate: bool = False,
                 create_global_orient: bool = True,
                 global_orient: Optional[torch.Tensor] = None,
                 create_body_pose: bool = True,
                 body_pose: torch.Tensor = None,
                 num_betas: int = 10,
                 create_betas: bool = True,
                 betas: torch.Tensor = None,
                 create_transl: bool = True,
                 transl: torch.Tensor = None,
                 batch_size: int = 1,
                 dtype: torch.dtype = torch.float32) -> None:
        """STAR model constructor.

        Args:
            model_path: str
                The path to the folder or to the file where the model
                parameters are stored.
            gender: str, optional
                Which gender to load. One of ``'male'``, ``'female'``,
                ``'neutral'``.
            keypoint_src: str
                Source convention of keypoints. This convention is used for
                keypoints obtained from joint regressors. Keypoints then
                undergo conversion into keypoint_dst convention.
            keypoint_dst: destination convention of keypoints. This convention
                is used for keypoints in the output.
            keypoint_approximate: whether to use approximate matching in
                convention conversion for keypoints.
            create_global_orient: bool, optional
                Flag for creating a member variable for the global orientation
                of the body. (default = True)
            global_orient: torch.tensor, optional, Bx3
                The default value for the global orientation variable.
                (default = None)
            create_body_pose: bool, optional
                Flag for creating a member variable for the pose of the body.
                (default = True)
            body_pose: torch.tensor, optional, Bx(3*24)
                The default value for the body pose variable.
                (default = None)
            num_betas: int, optional
                Number of shape components to use
                (default = 10).
            create_betas: bool, optional
                Flag for creating a member variable for the shape space
                (default = True).
            betas: torch.tensor, optional, Bx10
                The default value for the shape member variable.
                (default = None)
            create_transl: bool, optional
                Flag for creating a member variable for the translation
                of the body. (default = True)
            transl: torch.tensor, optional, Bx3
                The default value for the transl variable.
                (default = None)
            batch_size: int, optional
                The batch size used for creating the member variables.
            dtype: torch.dtype, optional
                The data type for the created variables.

        Raises:
            RuntimeError: If ``gender`` is invalid, ``model_path`` does not
                exist, or the model filename does not match ``gender``.
        """
        # Initialize nn.Module first so every attribute / buffer / parameter
        # registration below goes through a fully set-up module.
        super(STAR, self).__init__()

        if gender not in ['male', 'female', 'neutral']:
            raise RuntimeError('Invalid gender! Should be one of '
                               '[\'male\', \'female\', or \'neutral\']!')
        self.gender = gender

        # Resolve the actual .npz path: model_path may be a directory
        # containing STAR_{GENDER}.npz, or the file itself.
        model_fname = 'STAR_{}.npz'.format(gender.upper())
        if not os.path.exists(model_path):
            raise RuntimeError('Path {} does not exist!'.format(model_path))
        elif os.path.isdir(model_path):
            star_path = os.path.join(model_path, model_fname)
        else:
            if os.path.split(model_path)[-1] != model_fname:
                raise RuntimeError(
                    f'Model filename ({model_fname}) and gender '
                    f'({gender}) are incompatible!')
            star_path = model_path

        self.keypoint_src = keypoint_src
        self.keypoint_dst = keypoint_dst
        self.keypoint_approximate = keypoint_approximate

        star_model = np.load(star_path, allow_pickle=True)
        J_regressor = star_model['J_regressor']
        self.num_betas = num_betas

        # Model sparse joints regressor, regresses joints location from a mesh
        self.register_buffer('J_regressor',
                             torch.tensor(J_regressor, dtype=torch.float))

        # Model skinning weights
        self.register_buffer(
            'weights', torch.tensor(star_model['weights'], dtype=torch.float))

        # Model pose corrective blend shapes; flattened to
        # (num_vertices * 3, 93) for a single batched matmul in forward().
        self.register_buffer(
            'posedirs',
            torch.tensor(star_model['posedirs'].reshape((-1, 93)),
                         dtype=torch.float))

        # Mean shape (rest-pose template vertices)
        self.register_buffer(
            'v_template',
            torch.tensor(star_model['v_template'], dtype=torch.float))

        # Shape corrective blend shapes, truncated to num_betas components
        self.register_buffer(
            'shapedirs',
            torch.tensor(star_model['shapedirs'][:, :, :num_betas],
                         dtype=torch.float))

        # Mesh triangles
        self.register_buffer(
            'faces', torch.from_numpy(star_model['f'].astype(np.int64)))
        self.f = star_model['f']

        # Kinematic tree of the model
        self.register_buffer(
            'kintree_table',
            torch.from_numpy(star_model['kintree_table'].astype(np.int64)))

        # Map joint id -> column index, then build the parent index of every
        # non-root joint (parent[i] is the parent of joint i + 1).
        id_to_col = {
            self.kintree_table[1, i].item(): i
            for i in range(self.kintree_table.shape[1])
        }
        self.register_buffer(
            'parent',
            torch.tensor([
                id_to_col[self.kintree_table[0, it].item()]
                for it in range(1, self.kintree_table.shape[1])
            ],
                         dtype=torch.int64))

        if create_global_orient:
            if global_orient is None:
                default_global_orient = torch.zeros([batch_size, 3],
                                                    dtype=dtype)
            else:
                if torch.is_tensor(global_orient):
                    default_global_orient = global_orient.clone().detach()
                else:
                    default_global_orient = torch.tensor(global_orient,
                                                         dtype=dtype)

            global_orient = nn.Parameter(default_global_orient,
                                         requires_grad=True)
            self.register_parameter('global_orient', global_orient)

        if create_body_pose:
            if body_pose is None:
                default_body_pose = torch.zeros(
                    [batch_size, self.NUM_BODY_JOINTS * 3], dtype=dtype)
            else:
                if torch.is_tensor(body_pose):
                    default_body_pose = body_pose.clone().detach()
                else:
                    default_body_pose = torch.tensor(body_pose, dtype=dtype)
            self.register_parameter(
                'body_pose', nn.Parameter(default_body_pose,
                                          requires_grad=True))

        if create_betas:
            if betas is None:
                default_betas = torch.zeros([batch_size, self.num_betas],
                                            dtype=dtype)
            else:
                if torch.is_tensor(betas):
                    default_betas = betas.clone().detach()
                else:
                    default_betas = torch.tensor(betas, dtype=dtype)

            self.register_parameter(
                'betas', nn.Parameter(default_betas, requires_grad=True))

        if create_transl:
            # nn.Parameter sets requires_grad; no need to request it on the
            # default tensor itself (the original passed it redundantly).
            if transl is None:
                default_transl = torch.zeros([batch_size, 3], dtype=dtype)
            else:
                if torch.is_tensor(transl):
                    default_transl = transl.clone().detach()
                else:
                    default_transl = torch.tensor(transl, dtype=dtype)
            self.register_parameter(
                'transl', nn.Parameter(default_transl, requires_grad=True))

        # Caches populated by downstream users; kept for API compatibility.
        self.verts = None
        self.J = None
        self.R = None

    def forward(self,
                global_orient: Optional[torch.Tensor] = None,
                body_pose: Optional[torch.Tensor] = None,
                betas: Optional[torch.Tensor] = None,
                transl: Optional[torch.Tensor] = None,
                return_verts: bool = True,
                return_full_pose: bool = True) -> torch.Tensor:
        """Forward pass for the STAR model.

        Args:
            global_orient: torch.tensor, optional, shape Bx3
                Global orientation (rotation) of the body. If given, ignore the
                member variable and use it as the global rotation of the body.
                Useful if someone wishes to predicts this with an external
                model. (default=None)
            body_pose: torch.Tensor, shape Bx(J*3)
                Pose parameters for the STAR model. It should be a tensor that
                contains joint rotations in axis-angle format. If given, ignore
                the member variable and use it as the body parameters.
                (default=None)
            betas: torch.Tensor, shape Bx10
                Shape parameters for the STAR model. If given, ignore the
                member variable and use it as shape parameters. (default=None)
            transl: torch.Tensor, shape Bx3
                Translation vector for the STAR model. If given, ignore the
                member variable and use it as the translation of the body.
                (default=None)
            return_verts: bool
                Whether to include 'vertices' in the output dict.
            return_full_pose: bool
                Whether to include 'full_pose' (global_orient concatenated
                with body_pose) in the output dict.
        Returns:
            output: Contains output parameters and attributes corresponding
            to other body models.
        """
        global_orient = (global_orient
                         if global_orient is not None else self.global_orient)
        body_pose = body_pose if body_pose is not None else self.body_pose
        betas = betas if betas is not None else self.betas
        # Translation is applied if given explicitly or created as a member.
        apply_transl = transl is not None or hasattr(self, 'transl')
        if transl is None and hasattr(self, 'transl'):
            transl = self.transl

        batch_size = body_pose.shape[0]
        v_template = self.v_template[None, :]
        # (B, 6890 * 3, num_betas): shape blend-shape basis per batch element.
        shapedirs = self.shapedirs.view(-1, self.num_betas)[None, :].expand(
            batch_size, -1, -1)
        beta = betas[:, :, None]
        # Shape-blended template: v_template + shapedirs @ betas.
        v_shaped = torch.matmul(shapedirs, beta).view(-1, 6890, 3) + v_template
        # Regress rest-pose joint locations from the shaped mesh.
        J = torch.einsum('bik,ji->bjk', [v_shaped, self.J_regressor])

        # STAR pose feature: normalized quaternions of every joint except the
        # root ([:, 4:] drops the first quaternion), concatenated with the
        # second shape coefficient (beta[:, 1]) -> 23 * 4 + 1 = 93 features,
        # matching the flattened posedirs' last dimension.
        pose_quat = self.normalize_quaternion(body_pose.view(-1, 3)).view(
            batch_size, -1)
        pose_feat = torch.cat((pose_quat[:, 4:], beta[:, 1]), 1)

        # Per-joint rotation matrices from axis-angle pose.
        # NOTE(review): global_orient is not folded into R here — the root
        # rotation comes from body_pose[:, :3]; confirm against callers.
        R = aa_to_rotmat(body_pose.view(-1, 3)).view(batch_size, 24, 3, 3)

        # Apply pose-corrective blend shapes.
        posedirs = self.posedirs[None, :].expand(batch_size, -1, -1)
        v_posed = v_shaped + torch.matmul(
            posedirs, pose_feat[:, :, None]).view(-1, 6890, 3)

        # Forward kinematics: compose each joint's local transform with its
        # parent's world transform along the kinematic tree.
        root_transform = make_homegeneous_rotmat_batch(
            torch.cat((R[:, 0], J[:, 0][:, :, None]), 2))
        results = [root_transform]
        for i in range(0, self.parent.shape[0]):
            transform_i = make_homegeneous_rotmat_batch(
                torch.cat((R[:, i + 1], J[:, i + 1][:, :, None] -
                           J[:, self.parent[i]][:, :, None]), 2))
            curr_res = torch.matmul(results[self.parent[i]], transform_i)
            results.append(curr_res)
        results = torch.stack(results, dim=1)
        # Posed joint positions = translation column of each world transform.
        posed_joints = results[:, :, :3, 3]

        if apply_transl:
            posed_joints += transl[:, None, :]
            v_posed += transl[:, None, :]

        joints, joint_mask = convert_kps(posed_joints,
                                         src=self.keypoint_src,
                                         dst=self.keypoint_dst,
                                         approximate=self.keypoint_approximate)

        joint_mask = torch.tensor(joint_mask,
                                  dtype=torch.uint8,
                                  device=joints.device)
        joint_mask = joint_mask.reshape(1, -1).expand(batch_size, -1)

        # NOTE(review): 'betas' holds beta (shape Bx(num_betas)x1), not the
        # Bx(num_betas) input; kept as-is for backward compatibility.
        output = dict(global_orient=global_orient,
                      body_pose=body_pose,
                      joints=posed_joints,
                      joint_mask=joint_mask,
                      keypoints=torch.cat([joints, joint_mask[:, :, None]],
                                          dim=-1),
                      betas=beta)

        if return_verts:
            # NOTE(review): these are blend-shaped (not LBS-skinned) vertices;
            # self.weights is registered but unused here — verify upstream.
            output['vertices'] = v_posed
        if return_full_pose:
            output['full_pose'] = torch.cat([global_orient, body_pose], dim=1)

        return output

    @classmethod
    def normalize_quaternion(cls, theta: torch.Tensor) -> torch.Tensor:
        """Computes a normalized quaternion ([0,0,0,0] when the body is in
        rest pose) given joint angles.

        Args:
            theta (torch.Tensor): A tensor of joint axis-angles,
                (batch size * number of joints) x 3.

        Returns:
            quat (torch.Tensor): Quaternions with 1 subtracted from the scalar
                part so the rest pose maps to the zero vector, shape Nx4.
        """
        # 1e-8 keeps the norm strictly positive for zero rotations,
        # avoiding a division by zero below.
        l1norm = torch.norm(theta + 1e-8, p=2, dim=1)
        angle = torch.unsqueeze(l1norm, -1)
        normalized = torch.div(theta, angle)
        angle = angle * 0.5
        v_cos = torch.cos(angle)
        v_sin = torch.sin(angle)
        # [sin(a/2) * axis, cos(a/2) - 1]: identity rotation -> all zeros.
        quat = torch.cat([v_sin * normalized, v_cos - 1], dim=1)
        return quat