Spaces:
Running
on
L40S
Running
on
L40S
# ------------------------------------------------------------------------------ | |
# Adapted from https://github.com/nkolot/SPIN/blob/master/train/fits_dict.py | |
# Original licence please see docs/additional_licenses.md | |
# ------------------------------------------------------------------------------ | |
import os | |
import cv2 | |
import numpy as np | |
import torch | |
from detrsmpl.utils.transforms import aa_to_rotmat | |
train_datasets = ['h36m', 'mpi_inf_3dhp', 'lsp', 'lspet', 'mpii', 'coco'] | |
static_fits_load_dir = 'data/static_fits' | |
save_dir = 'data/spin_fits' | |
# Permutation of SMPL pose parameters when flipping the shape | |
SMPL_JOINTS_FLIP_PERM = [ | |
0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13, 15, 17, 16, 19, 18, 21, | |
20, 23, 22 | |
] | |
SMPL_POSE_FLIP_PERM = [] | |
for i in SMPL_JOINTS_FLIP_PERM: | |
SMPL_POSE_FLIP_PERM.append(3 * i) | |
SMPL_POSE_FLIP_PERM.append(3 * i + 1) | |
SMPL_POSE_FLIP_PERM.append(3 * i + 2) | |
class FitsDict(): | |
"""Dictionary keeping track of the best fit per image in the training set. | |
Ref: https://github.com/nkolot/SPIN/blob/master/train/fits_dict.py | |
""" | |
def __init__(self, fits='static') -> None: | |
assert fits in ['static', 'final'] | |
self.fits = fits | |
self.fits_dict = {} | |
# array used to flip SMPL pose parameters | |
self.flipped_parts = torch.tensor(SMPL_POSE_FLIP_PERM, | |
dtype=torch.int64) | |
# Load dictionary state | |
# for ds_name, ds in train_dataset.dataset_dict.items(): | |
for ds_name in train_datasets: | |
# h36m has gt so no static fits | |
if ds_name == 'h36m' or self.fits == 'static': | |
dict_file = os.path.join(static_fits_load_dir, | |
ds_name + '_fits.npy') | |
content = np.load(dict_file) | |
self.fits_dict[ds_name] = torch.from_numpy(content) | |
del content | |
elif self.fits == 'final': | |
dict_file = os.path.join('data/final_fits', ds_name + '.npz') | |
# load like this to save mem | |
content = np.load(dict_file) | |
pose = torch.from_numpy(content['pose']) | |
betas = torch.from_numpy(content['betas']) | |
del content | |
params = torch.cat([pose, betas], dim=-1) | |
self.fits_dict[ds_name] = params | |
def save(self): | |
"""Save dictionary state to disk.""" | |
for ds_name in train_datasets: | |
dict_file = os.path.join(save_dir, ds_name + '_fits.npy') | |
np.save(dict_file, self.fits_dict[ds_name].cpu().numpy()) | |
def __getitem__(self, x): | |
"""Retrieve dictionary entries.""" | |
dataset_name, ind, rot, is_flipped = x | |
batch_size = len(dataset_name) | |
pose = torch.zeros((batch_size, 72)) | |
betas = torch.zeros((batch_size, 10)) | |
for ds, i, n in zip(dataset_name, ind, range(batch_size)): | |
params = self.fits_dict[ds][i] | |
pose[n, :] = params[:72] | |
betas[n, :] = params[72:] | |
pose = pose.clone() | |
# Apply flipping and rotation | |
pose = self.rotate_pose(self.flip_pose(pose, is_flipped), rot) | |
betas = betas.clone() | |
return pose, betas | |
def __setitem__(self, x, val): | |
"""Update dictionary entries.""" | |
dataset_name, ind, rot, is_flipped, update = x | |
pose, betas = val | |
batch_size = len(dataset_name) | |
# Undo flipping and rotation | |
pose = self.flip_pose(self.rotate_pose(pose, -rot), is_flipped) | |
params = torch.cat((pose, betas), dim=-1).cpu() | |
for ds, i, n in zip(dataset_name, ind, range(batch_size)): | |
if update[n]: | |
self.fits_dict[ds][i] = params[n] | |
def flip_pose(self, pose, is_flipped): | |
"""flip SMPL pose parameters.""" | |
is_flipped = is_flipped.bool() | |
pose_f = pose.clone() | |
pose_f[is_flipped, :] = pose[is_flipped][:, self.flipped_parts] | |
# we also negate the second and the third dimension of the | |
# axis-angle representation | |
pose_f[is_flipped, 1::3] *= -1 | |
pose_f[is_flipped, 2::3] *= -1 | |
return pose_f | |
def rotate_pose(self, pose, rot): | |
"""Rotate SMPL pose parameters by rot degrees.""" | |
pose = pose.clone() | |
cos = torch.cos(-np.pi * rot / 180.) | |
sin = torch.sin(-np.pi * rot / 180.) | |
zeros = torch.zeros_like(cos) | |
r3 = torch.zeros(cos.shape[0], 1, 3, device=cos.device) | |
r3[:, 0, -1] = 1 | |
R = torch.cat([ | |
torch.stack([cos, -sin, zeros], dim=-1).unsqueeze(1), | |
torch.stack([sin, cos, zeros], dim=-1).unsqueeze(1), r3 | |
], | |
dim=1) | |
global_pose = pose[:, :3] | |
global_pose_rotmat = R @ aa_to_rotmat(global_pose) | |
global_pose_rotmat = global_pose_rotmat.cpu().numpy() | |
global_pose_np = np.zeros((global_pose.shape[0], 3)) | |
for i in range(global_pose.shape[0]): | |
aa, _ = cv2.Rodrigues(global_pose_rotmat[i]) | |
global_pose_np[i, :] = aa.squeeze() | |
pose[:, :3] = torch.from_numpy(global_pose_np).to(pose.device) | |
return pose | |