File size: 17,167 Bytes
d7e58f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
"""This script is based on the release codes:

"HybrIK: A Hybrid Analytical-Neural Inverse Kinematics Solution for 3D Human
Pose and Shape Estimation. CVPR 2021"
(https://github.com/Jeff-sjtu/HybrIK).
"""

from __future__ import absolute_import, division, print_function

import torch

from detrsmpl.utils.transforms import aa_to_rotmat


def batch_inverse_kinematics_transform(pose_skeleton,
                                       global_orient,
                                       phis,
                                       rest_pose,
                                       children,
                                       parents,
                                       dtype=torch.float32,
                                       train=False,
                                       leaf_thetas=None):
    """Applies inverse kinematics transform to joints in a batch.

    Joints are visited in index order (1 .. J-1). For each joint the local
    rotation is recovered from the bone directions (swing) and from ``phis``
    (twist around the bone axis); joints flagged with ``children == -3`` are
    solved with an SVD over all of their children, and joints flagged with
    ``children == -1`` (leaves) only receive a rotation when ``leaf_thetas``
    is given.

    Args:
        pose_skeleton (torch.tensor):
            Locations of estimated pose skeleton with shape (Bx29x3)
        global_orient (torch.tensor|none):
            Tensor of global rotation matrices with shape (Bx1x3x3).
            Kept for interface compatibility; this function does not read
            it, the root orientation is always estimated from the skeleton.
        phis (torch.tensor):
            Rotation on bone axis parameters with shape (Bx23x2), stored as
            (cos, sin) pairs; they are re-normalized here.
        rest_pose (torch.tensor):
            Locations of rest (Template) pose with shape (Bx29x3)
        children (List[int]): indexes of kinematic children with len 29.
            ``-1`` marks a leaf joint, ``-3`` marks a joint with three
            children.
        parents (torch.tensor): indexes of kinematic parents with len 29.
            Must expose ``.shape`` (the joint count is ``parents.shape[0]``)
            and be usable as an index into the joint dimension.
        dtype (torch.dtype, optional):
            Data type of the created tensors. Default: torch.float32
        train (bool):
            If True, use the naive solver (bone of joint ``i`` itself and the
            analytic pelvis orientation). If False, use the adaptive solver
            (bone towards ``children[i]`` measured from the already rotated
            template, SVD pelvis orientation). Default: False
        leaf_thetas (torch.tensor, optional):
            Rotation matrixes for 5 leaf joints (Bx5x3x3). Default: None

    Returns:
        rot_mats (torch.tensor):
            Local rotation matrices stacked along dim 1. Leaf joints only
            contribute an entry when ``leaf_thetas`` is provided, so the
            shape is (Bx29x3x3) with leaf rotations and shorter without.
        rotate_rest_pose (torch.tensor):
            Locations of rotated rest/ template pose with shape (Bx29x3).
            Rows of leaf joints stay zero when ``leaf_thetas`` is None.
    """
    batch_size = pose_skeleton.shape[0]
    device = pose_skeleton.device
    num_joints = parents.shape[0]

    rel_rest_pose = rest_pose.clone()
    # vec_t_k = t_k - t_pa(k)
    rel_rest_pose[:, 1:] -= rest_pose[:, parents[1:]].clone()
    rel_rest_pose = torch.unsqueeze(rel_rest_pose, dim=-1)

    # rotate the T pose
    rotate_rest_pose = torch.zeros_like(rel_rest_pose)
    # set up the root
    rotate_rest_pose[:, 0] = rel_rest_pose[:, 0]

    # detached copy: only used to estimate directions, not for gradients
    rel_pose_skeleton = torch.unsqueeze(pose_skeleton.clone(), dim=-1).detach()
    rel_pose_skeleton[:, 1:] -= rel_pose_skeleton[:, parents[1:]].clone()
    rel_pose_skeleton[:, 0] = rel_rest_pose[:, 0]

    # the predicted final pose
    final_pose_skeleton = torch.unsqueeze(pose_skeleton.clone(), dim=-1)
    if train:
        # parent-relative bone vectors
        final_pose_skeleton[:, 1:] -= \
            final_pose_skeleton[:, parents[1:]].clone()
        final_pose_skeleton[:, 0] = rel_rest_pose[:, 0]
    else:
        # keep absolute locations, but move the root onto the template root
        final_pose_skeleton += \
            rel_rest_pose[:, 0:1] - final_pose_skeleton[:, 0:1]

    assert phis.dim() == 3
    phis = phis / (torch.norm(phis, dim=2, keepdim=True) + 1e-8)

    if train:
        global_orient_mat = batch_get_pelvis_orient(rel_pose_skeleton.clone(),
                                                    rel_rest_pose.clone(),
                                                    parents, children, dtype)
    else:
        global_orient_mat = batch_get_pelvis_orient_svd(
            rel_pose_skeleton.clone(), rel_rest_pose.clone(), parents,
            children, dtype)

    # NOTE: both lists are indexed by joint id via ``parents[i]``; this
    # relies on every joint that acts as a parent having been appended in
    # index order before its children are visited.
    rot_mat_chain = [global_orient_mat]
    rot_mat_local = [global_orient_mat]
    # leaf nodes rot_mats
    if leaf_thetas is not None:
        leaf_cnt = 0
        leaf_rot_mats = leaf_thetas.view([batch_size, 5, 3, 3])

    # loop-invariant pieces of the Rodrigues formula
    zeros = torch.zeros((batch_size, 1, 1), dtype=dtype, device=device)
    ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)

    for i in range(1, num_joints):
        if children[i] == -1:
            # leaf nodes
            if leaf_thetas is not None:
                rot_mat = leaf_rot_mats[:, leaf_cnt, :, :]
                leaf_cnt += 1

                rotate_rest_pose[:, i] = rotate_rest_pose[:, parents[
                    i]] + torch.matmul(rot_mat_chain[parents[i]],
                                       rel_rest_pose[:, i])

                rot_mat_chain.append(
                    torch.matmul(rot_mat_chain[parents[i]], rot_mat))
                rot_mat_local.append(rot_mat)
        elif children[i] == -3:
            # three children
            rotate_rest_pose[:, i] = rotate_rest_pose[:, parents[i]] + \
                torch.matmul(rot_mat_chain[parents[i]], rel_rest_pose[:, i])

            # every joint whose parent is the current joint
            spine_child = [
                c for c in range(1, num_joints) if parents[c] == i
            ]

            children_final_loc = []
            children_rest_loc = []
            for c in spine_child:
                temp = final_pose_skeleton[:, c] - rotate_rest_pose[:, i]
                children_final_loc.append(temp)

                children_rest_loc.append(rel_rest_pose[:, c].clone())

            rot_mat = batch_get_3children_orient_svd(children_final_loc,
                                                     children_rest_loc,
                                                     rot_mat_chain[parents[i]],
                                                     spine_child, dtype)

            rot_mat_chain.append(
                torch.matmul(rot_mat_chain[parents[i]], rot_mat))
            rot_mat_local.append(rot_mat)
        else:
            # Naive Hybrik
            if train:
                # i: the index of k-th joint
                child_rest_loc = rel_rest_pose[:, i]
                child_final_loc = final_pose_skeleton[:, i]

            # q_pa(k) = q_pa^2(k) + R_pa(k)(t_pa(k) - t_pa^2(k))
            rotate_rest_pose[:, i] = rotate_rest_pose[:, parents[i]] + \
                torch.matmul(rot_mat_chain[parents[i]], rel_rest_pose[:, i])
            # Adaptive HybrIK
            if not train:
                # children[i]: the index of k-th joint
                child_rest_loc = rel_rest_pose[:, children[i]]
                child_final_loc = final_pose_skeleton[:, children[
                    i]] - rotate_rest_pose[:, i]

                # predicted bone direction rescaled to the template length
                orig_vec = rel_pose_skeleton[:, children[i]]
                template_vec = rel_rest_pose[:, children[i]]
                norm_t = torch.norm(template_vec, dim=1, keepdim=True)
                orig_vec = orig_vec * norm_t / torch.norm(
                    orig_vec, dim=1, keepdim=True)

                # where the accumulated target drifts more than 15 / 1000
                # from the rescaled prediction, fall back to the prediction
                diff = torch.norm(child_final_loc - orig_vec,
                                  dim=1,
                                  keepdim=True)
                big_diff_idx = torch.where(diff > 15 / 1000)[0]

                child_final_loc[big_diff_idx] = orig_vec[big_diff_idx]

            # train: vec_p_k = R_pa(k).T * (p_k - p_pa(k))
            # test: vec_p_k = R_pa(k).T * (p_k - q_pa(k))
            child_final_loc = torch.matmul(
                rot_mat_chain[parents[i]].transpose(1, 2), child_final_loc)

            # (B, 1, 1)
            child_final_norm = torch.norm(child_final_loc, dim=1, keepdim=True)
            child_rest_norm = torch.norm(child_rest_loc, dim=1, keepdim=True)

            # vec_n
            axis = torch.cross(child_rest_loc, child_final_loc, dim=1)
            axis_norm = torch.norm(axis, dim=1, keepdim=True)

            # (B, 1, 1)
            cos = torch.sum(
                child_rest_loc * child_final_loc, dim=1,
                keepdim=True) / (child_rest_norm * child_final_norm + 1e-8)
            sin = axis_norm / (child_rest_norm * child_final_norm + 1e-8)

            # (B, 3, 1)
            axis = axis / (axis_norm + 1e-8)

            # Convert location revolve to rot_mat by rodrigues
            # (B, 1, 1)
            rx, ry, rz = torch.split(axis, 1, dim=1)

            K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros],
                          dim=1).view((batch_size, 3, 3))
            rot_mat_loc = ident + sin * K + (1 - cos) * torch.bmm(K, K)

            # Convert spin to rot_mat
            # (B, 3, 1)
            spin_axis = child_rest_loc / child_rest_norm
            # (B, 1, 1)
            rx, ry, rz = torch.split(spin_axis, 1, dim=1)
            K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros],
                          dim=1).view((batch_size, 3, 3))
            # (B, 1, 1)
            cos, sin = torch.split(phis[:, i - 1], 1, dim=1)
            cos = torch.unsqueeze(cos, dim=2)
            sin = torch.unsqueeze(sin, dim=2)
            rot_mat_spin = ident + sin * K + (1 - cos) * torch.bmm(K, K)
            # swing (bone direction) composed with twist (around the bone)
            rot_mat = torch.matmul(rot_mat_loc, rot_mat_spin)

            rot_mat_chain.append(
                torch.matmul(rot_mat_chain[parents[i]], rot_mat))
            rot_mat_local.append(rot_mat)

    # (B, K + 1, 3, 3)
    rot_mats = torch.stack(rot_mat_local, dim=1)

    return rot_mats, rotate_rest_pose.squeeze(-1)


def batch_get_pelvis_orient_svd(rel_pose_skeleton, rel_rest_pose, parents,
                                children, dtype):
    """Get pelvis orientation svd for batch data.

    The rotation that best maps the template bones of the pelvis' children
    onto the predicted bones is obtained from the SVD of their correlation
    matrix. Samples whose correlation matrix is entirely zero get the
    identity rotation instead.

    Args:
        rel_pose_skeleton (torch.tensor):
            Parent-relative pose skeleton; ``rel_pose_skeleton[:, j]`` must
            be a (Bx3x1) column vector (the caller in this file passes
            Bx29x3x1).
        rel_rest_pose (torch.tensor):
            Parent-relative rest/ template pose, same layout as
            ``rel_pose_skeleton``.
        parents (torch.tensor): indexes of kinematic parents with len 29.
            Must expose ``.shape``.
        children (List[int]): list of indexes of kinematic children with len 29
        dtype (torch.dtype, optional):
            Unused; created tensors follow the dtype of the inputs.

    Returns:
        rot_mat (torch.tensor):
            Rotation matrix of pelvis with shape (Bx3x3)
    """
    # children[0] first, then every other joint whose parent is the root
    pelvis_child = [int(children[0])]
    for i in range(1, parents.shape[0]):
        if parents[i] == 0 and i not in pelvis_child:
            pelvis_child.append(i)

    # (B, 3, n_children): one column per pelvis child
    rest_mat = torch.cat([rel_rest_pose[:, child] for child in pelvis_child],
                         dim=2)
    target_mat = torch.cat(
        [rel_pose_skeleton[:, child] for child in pelvis_child], dim=2)
    S = rest_mat.bmm(target_mat.transpose(1, 2))

    # Sum of magnitudes, not the plain sum: a non-zero matrix whose entries
    # cancel out must still be solved instead of being replaced by identity.
    non_zero = S.abs().sum(dim=(1, 2)) != 0

    # start from identity everywhere, in the dtype of the inputs
    rot_mat = torch.eye(
        3, dtype=S.dtype, device=S.device).repeat(S.shape[0], 1, 1)

    # skip the decomposition entirely when no sample has a usable matrix
    if bool(non_zero.any()):
        U, _, V = torch.svd(S[non_zero].reshape(-1, 3, 3))
        rot_mat[non_zero] = torch.bmm(V, U.transpose(1, 2))

    assert torch.sum(torch.isnan(rot_mat)) == 0, ('rot_mat', rot_mat)

    return rot_mat


def batch_get_pelvis_orient(rel_pose_skeleton, rel_rest_pose, parents,
                            children, dtype):
    """Get pelvis orientation for batch data.

    The orientation is built in two steps: first the rotation that aligns
    the template spine bone (``children[0]``) with the predicted one, then a
    rotation about the predicted spine direction that aligns the mean of the
    remaining pelvis children.

    Args:
        rel_pose_skeleton (torch.tensor):
            Parent-relative pose skeleton with shape (BxJx3x1)
        rel_rest_pose (torch.tensor):
            Parent-relative rest/ template pose with shape (BxJx3x1)
        parents (torch.tensor): indexes of kinematic parents with len J.
            Must expose ``.shape``.
        children (List[int]): list of indexes of kinematic children with
            len J; ``children[0]`` is required to be 3.
        dtype (torch.dtype, optional):
            Data type of the created tensors, the default is torch.float32

    Returns:
        rot_mat (torch.tensor):
            Rotation matrix of pelvis with shape (Bx3x3)
    """
    batch_size = rel_pose_skeleton.shape[0]
    device = rel_pose_skeleton.device

    assert children[0] == 3
    pelvis_child = [int(children[0])]
    for i in range(1, parents.shape[0]):
        if parents[i] == 0 and i not in pelvis_child:
            pelvis_child.append(i)

    spine_final_loc = rel_pose_skeleton[:, int(children[0])].clone()
    spine_rest_loc = rel_rest_pose[:, int(children[0])].clone()

    # (B, 1, 1)
    vec_final_norm = torch.norm(spine_final_loc, dim=1, keepdim=True)
    vec_rest_norm = torch.norm(spine_rest_loc, dim=1, keepdim=True)

    # unit vector along the predicted spine
    spine_norm = spine_final_loc / (vec_final_norm + 1e-8)

    # (B, 3, 1)
    axis = torch.cross(spine_rest_loc, spine_final_loc, dim=1)
    axis_norm = torch.norm(axis, dim=1, keepdim=True)
    axis = axis / (axis_norm + 1e-8)
    cos_spine = torch.sum(
        spine_rest_loc * spine_final_loc, dim=1, keepdim=True) / (
            vec_rest_norm * vec_final_norm + 1e-8)
    # rounding can push the ratio marginally outside [-1, 1], where arccos
    # would return NaN
    angle = torch.arccos(torch.clamp(cos_spine, -1.0, 1.0))
    # drop only the trailing singleton so a batch of one keeps shape (1, 3)
    axis_angle = (angle * axis).squeeze(-1)
    # aa to rotmat
    rot_mat_spine = aa_to_rotmat(axis_angle)

    assert torch.sum(torch.isnan(rot_mat_spine)) == 0, ('rot_mat_spine',
                                                        rot_mat_spine)
    # mean location of the pelvis children other than the spine
    center_final_loc = 0
    center_rest_loc = 0
    for child in pelvis_child:
        if child == int(children[0]):
            continue
        center_final_loc = center_final_loc + rel_pose_skeleton[:,
                                                                child].clone()
        center_rest_loc = center_rest_loc + rel_rest_pose[:, child].clone()
    center_final_loc = center_final_loc / (len(pelvis_child) - 1)
    center_rest_loc = center_rest_loc / (len(pelvis_child) - 1)

    # bring the template center into the spine-aligned frame
    center_rest_loc = torch.matmul(rot_mat_spine, center_rest_loc)

    # remove the component along the spine: only the part perpendicular to
    # it determines the remaining rotation
    center_final_loc = center_final_loc - torch.sum(
        center_final_loc * spine_norm, dim=1, keepdim=True) * spine_norm
    center_rest_loc = center_rest_loc - torch.sum(
        center_rest_loc * spine_norm, dim=1, keepdim=True) * spine_norm

    center_final_loc_norm = torch.norm(center_final_loc, dim=1, keepdim=True)
    center_rest_loc_norm = torch.norm(center_rest_loc, dim=1, keepdim=True)

    # (B, 3, 1)
    axis = torch.cross(center_rest_loc, center_final_loc, dim=1)
    axis_norm = torch.norm(axis, dim=1, keepdim=True)

    # (B, 1, 1)
    cos = torch.sum(
        center_rest_loc * center_final_loc, dim=1,
        keepdim=True) / (center_rest_loc_norm * center_final_loc_norm + 1e-8)
    sin = axis_norm / (center_rest_loc_norm * center_final_loc_norm + 1e-8)

    assert torch.sum(torch.isnan(cos)) == 0, ('cos', cos)
    assert torch.sum(torch.isnan(sin)) == 0, ('sin', sin)
    # (B, 3, 1)
    axis = axis / (axis_norm + 1e-8)

    # Convert location revolve to rot_mat by rodrigues
    # (B, 1, 1)
    rx, ry, rz = torch.split(axis, 1, dim=1)
    zeros = torch.zeros((batch_size, 1, 1), dtype=dtype, device=device)

    K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \
        .view((batch_size, 3, 3))
    ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
    rot_mat_center = ident + sin * K + (1 - cos) * torch.bmm(K, K)

    # spine alignment first, then the rotation about the spine
    rot_mat = torch.matmul(rot_mat_center, rot_mat_spine)

    return rot_mat


def batch_get_3children_orient_svd(rel_pose_skeleton, rel_rest_pose,
                                   rot_mat_chain_parent, children_list, dtype):
    """Get the local rotation of a multi-child joint via SVD for batch data.

    Each child's target vector is first expressed in the parent's frame
    (multiplied by the transposed parent rotation); the rotation is then
    taken from the SVD of the template/target correlation matrix.

    Args:
        rel_pose_skeleton (List[torch.tensor] | torch.tensor):
            Target vectors of the children. Either a list with one (Bx3x1)
            tensor per entry of ``children_list`` (read by position), or a
            tensor that is indexed as ``[:, child]`` with the joint ids in
            ``children_list``.
        rel_rest_pose (List[torch.tensor] | torch.tensor):
            Template vectors of the children, laid out like
            ``rel_pose_skeleton``.
        rot_mat_chain_parent (torch.tensor):
            Parent's accumulated rotation matrix with shape (Bx3x3)
        children_list (List[int]): joint indexes of the children
        dtype (torch.dtype, optional):
            Unused.

    Returns:
        rot_mat (torch.tensor):
            Rotation matrix of the joint with shape (Bx3x3)
    """
    parent_rot_t = rot_mat_chain_parent.transpose(1, 2)
    given_as_list = isinstance(rel_pose_skeleton, list)

    template_cols = []
    target_cols = []
    for slot, joint_idx in enumerate(children_list):
        # lists are addressed by position, tensors by joint index
        if given_as_list:
            observed = rel_pose_skeleton[slot].clone()
            reference = rel_rest_pose[slot].clone()
        else:
            observed = rel_pose_skeleton[:, joint_idx].clone()
            reference = rel_rest_pose[:, joint_idx].clone()

        # express the target in the parent's frame
        target_cols.append(torch.matmul(parent_rot_t, observed))
        template_cols.append(reference)

    # (B, 3, n_children) each
    template_stack = torch.cat(template_cols, dim=2)
    target_stack = torch.cat(target_cols, dim=2)
    correlation = template_stack.bmm(target_stack.transpose(1, 2))

    left, _, right = torch.svd(correlation)

    rot_mat = torch.bmm(right, left.transpose(1, 2))
    assert torch.sum(torch.isnan(rot_mat)) == 0, ('3children rot_mat', rot_mat)
    return rot_mat