File size: 5,186 Bytes
d7e58f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# ------------------------------------------------------------------------------
# Adapted from https://github.com/nkolot/SPIN/blob/master/train/fits_dict.py
# Original licence please see docs/additional_licenses.md
# ------------------------------------------------------------------------------

import os

import cv2
import numpy as np
import torch

from detrsmpl.utils.transforms import aa_to_rotmat

# Datasets whose per-sample fits are tracked by ``FitsDict``.
train_datasets = ['h36m', 'mpi_inf_3dhp', 'lsp', 'lspet', 'mpii', 'coco']
# Directory the ``<dataset>_fits.npy`` files are loaded from.
static_fits_load_dir = 'data/static_fits'
# Directory the ``<dataset>_fits.npy`` files are written to by ``save()``.
save_dir = 'data/spin_fits'

# Permutation of SMPL joints when flipping the shape: entry ``j`` is the
# index of the joint whose parameters take the place of joint ``j``.
SMPL_JOINTS_FLIP_PERM = [
    0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13, 15, 17, 16, 19, 18, 21,
    20, 23, 22
]
# Same permutation expanded to the flat pose vector, where every joint
# occupies three consecutive entries (3 * j, 3 * j + 1, 3 * j + 2).
SMPL_POSE_FLIP_PERM = [
    3 * joint + axis for joint in SMPL_JOINTS_FLIP_PERM for axis in range(3)
]


class FitsDict():
    """Dictionary keeping track of the best fit per image in the training set.

    For every dataset in ``train_datasets`` a tensor is kept in
    ``self.fits_dict`` and indexed by sample. Each entry stores the 72 pose
    values followed by the shape (betas) values of that sample.

    Args:
        fits (str): Which fits to load, ``'static'`` or ``'final'``.
            ``'static'`` reads ``<static_fits_load_dir>/<dataset>_fits.npy``.
            ``'final'`` reads ``data/final_fits/<dataset>.npz`` (arrays
            ``pose`` and ``betas``), except for ``h36m`` which always uses
            the static file.

    Ref: https://github.com/nkolot/SPIN/blob/master/train/fits_dict.py
    """
    def __init__(self, fits='static') -> None:
        assert fits in ['static', 'final']
        self.fits = fits
        self.fits_dict = {}

        # array used to flip SMPL pose parameters
        self.flipped_parts = torch.tensor(SMPL_POSE_FLIP_PERM,
                                          dtype=torch.int64)
        # Load dictionary state
        for ds_name in train_datasets:
            # h36m is always read from the static fits directory, even
            # when the final fits are requested for the other datasets.
            if ds_name == 'h36m' or self.fits == 'static':
                dict_file = os.path.join(static_fits_load_dir,
                                         ds_name + '_fits.npy')
                self.fits_dict[ds_name] = torch.from_numpy(
                    np.load(dict_file))
            elif self.fits == 'final':
                dict_file = os.path.join('data/final_fits', ds_name + '.npz')
                # Only pull the two needed arrays out of the archive and
                # close it right away instead of waiting for garbage
                # collection to release the file handle.
                with np.load(dict_file) as content:
                    pose = torch.from_numpy(content['pose'])
                    betas = torch.from_numpy(content['betas'])
                self.fits_dict[ds_name] = torch.cat([pose, betas], dim=-1)

    def save(self):
        """Save dictionary state to disk.

        Writes one ``<dataset>_fits.npy`` file per training dataset into
        ``save_dir``, creating the directory first if it is missing.
        """
        os.makedirs(save_dir, exist_ok=True)
        for ds_name in train_datasets:
            dict_file = os.path.join(save_dir, ds_name + '_fits.npy')
            np.save(dict_file, self.fits_dict[ds_name].cpu().numpy())

    def __getitem__(self, x):
        """Retrieve dictionary entries.

        Args:
            x (tuple): ``(dataset_name, ind, rot, is_flipped)`` where
                ``dataset_name`` is a sequence of dataset names, ``ind``
                the matching sample indices, ``rot`` the rotation angles
                in degrees and ``is_flipped`` the per-sample flip flags.

        Returns:
            tuple: ``(pose, betas)`` of shapes ``(batch_size, 72)`` and
            ``(batch_size, 10)``, with flipping and rotation applied to
            the stored pose.
        """
        dataset_name, ind, rot, is_flipped = x
        batch_size = len(dataset_name)
        pose = torch.zeros((batch_size, 72))
        betas = torch.zeros((batch_size, 10))
        for n, (ds, i) in enumerate(zip(dataset_name, ind)):
            params = self.fits_dict[ds][i]
            pose[n, :] = params[:72]
            betas[n, :] = params[72:]

        # Apply flipping and rotation. Both helpers work on copies, so
        # the freshly assembled batch needs no extra clone.
        pose = self.rotate_pose(self.flip_pose(pose, is_flipped), rot)
        return pose, betas

    def __setitem__(self, x, val):
        """Update dictionary entries.

        Args:
            x (tuple): ``(dataset_name, ind, rot, is_flipped, update)``;
                the first four entries are as in ``__getitem__`` and
                ``update`` flags which samples should be overwritten.
            val (tuple): ``(pose, betas)`` for the batch.
        """
        dataset_name, ind, rot, is_flipped, update = x
        pose, betas = val

        # Undo flipping and rotation so the stored fit is expressed for
        # the un-augmented sample (inverse order of ``__getitem__``).
        pose = self.flip_pose(self.rotate_pose(pose, -rot), is_flipped)

        params = torch.cat((pose, betas), dim=-1).cpu()
        for n, (ds, i) in enumerate(zip(dataset_name, ind)):
            if update[n]:
                self.fits_dict[ds][i] = params[n]

    def flip_pose(self, pose, is_flipped):
        """Flip SMPL pose parameters.

        Args:
            pose (torch.Tensor): Pose parameters, one row per sample.
            is_flipped (torch.Tensor): Per-sample flags; rows whose flag
                is truthy are flipped, the others are returned unchanged.

        Returns:
            torch.Tensor: A new tensor; ``pose`` itself is not modified.
        """
        is_flipped = is_flipped.bool()
        pose_f = pose.clone()
        # Reorder the entries of the flipped rows with ``flipped_parts``.
        pose_f[is_flipped, :] = pose[is_flipped][:, self.flipped_parts]
        # we also negate the second and the third dimension of the
        # axis-angle representation
        pose_f[is_flipped, 1::3] *= -1
        pose_f[is_flipped, 2::3] *= -1
        return pose_f

    def rotate_pose(self, pose, rot):
        """Rotate SMPL pose parameters by rot degrees.

        Only the first three pose values (an axis-angle vector) are
        changed: their rotation matrix is left-multiplied by a rotation
        of ``-rot`` degrees about the z axis.

        Args:
            pose (torch.Tensor): Pose parameters, one row per sample.
            rot (torch.Tensor): Rotation angle in degrees per sample.

        Returns:
            torch.Tensor: A new tensor; ``pose`` itself is not modified.
        """
        pose = pose.clone()
        angle = -np.pi * rot / 180.
        cos = torch.cos(angle)
        sin = torch.sin(angle)
        # Build the batch of z-axis rotation matrices directly with the
        # dtype and device of the angles so no mixed-dtype concat is needed.
        R = torch.zeros(cos.shape[0],
                        3,
                        3,
                        dtype=cos.dtype,
                        device=cos.device)
        R[:, 0, 0] = cos
        R[:, 0, 1] = -sin
        R[:, 1, 0] = sin
        R[:, 1, 1] = cos
        R[:, 2, 2] = 1
        global_pose = pose[:, :3]
        global_pose_rotmat = R @ aa_to_rotmat(global_pose)
        # cv2 works on numpy arrays; detach so tensors that track
        # gradients can be converted as well.
        global_pose_rotmat = global_pose_rotmat.detach().cpu().numpy()
        global_pose_np = np.zeros((global_pose.shape[0], 3))
        for i in range(global_pose.shape[0]):
            # Rotation matrix -> axis-angle vector.
            aa, _ = cv2.Rodrigues(global_pose_rotmat[i])
            global_pose_np[i, :] = aa.squeeze()
        pose[:, :3] = torch.from_numpy(global_pose_np).to(pose.device)
        return pose