# detrsmpl/models/utils/fits_dict.py
# ------------------------------------------------------------------------------
# Adapted from https://github.com/nkolot/SPIN/blob/master/train/fits_dict.py
# Original license: please see docs/additional_licenses.md
# ------------------------------------------------------------------------------
import os

import cv2
import numpy as np
import torch

from detrsmpl.utils.transforms import aa_to_rotmat

train_datasets = ['h36m', 'mpi_inf_3dhp', 'lsp', 'lspet', 'mpii', 'coco']
static_fits_load_dir = 'data/static_fits'
save_dir = 'data/spin_fits'

# Permutation of SMPL pose parameters when flipping the shape
SMPL_JOINTS_FLIP_PERM = [
    0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13, 15, 17, 16, 19, 18, 21,
    20, 23, 22
]
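# Expand the 24-entry joint permutation above to all 72 pose parameters
# (3 axis-angle components per joint). Mirroring also requires negating the
# y and z components of each axis-angle vector, which FitsDict.flip_pose
# handles separately.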
SMPL_POSE_FLIP_PERM = []
for i in SMPL_JOINTS_FLIP_PERM:
    SMPL_POSE_FLIP_PERM.append(3 * i)
    SMPL_POSE_FLIP_PERM.append(3 * i + 1)
    SMPL_POSE_FLIP_PERM.append(3 * i + 2)


class FitsDict():
    """Dictionary keeping track of the best fit per image in the training set.

    Ref: https://github.com/nkolot/SPIN/blob/master/train/fits_dict.py
    A minimal usage sketch is given at the bottom of this file.
    """

    def __init__(self, fits='static') -> None:
        assert fits in ['static', 'final']
        self.fits = fits
        self.fits_dict = {}
        # array used to flip SMPL pose parameters
        self.flipped_parts = torch.tensor(SMPL_POSE_FLIP_PERM,
                                          dtype=torch.int64)
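        # Each fits_dict entry stores one row per image of the dataset,
        # laid out as [pose (72) | betas (10)], i.e. an (N, 82) tensor.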
        # Load dictionary state
        # for ds_name, ds in train_dataset.dataset_dict.items():
        for ds_name in train_datasets:
            # h36m has gt so no static fits
            if ds_name == 'h36m' or self.fits == 'static':
                dict_file = os.path.join(static_fits_load_dir,
                                         ds_name + '_fits.npy')
                content = np.load(dict_file)
                self.fits_dict[ds_name] = torch.from_numpy(content)
                del content
            elif self.fits == 'final':
                dict_file = os.path.join('data/final_fits', ds_name + '.npz')
                # load like this to save mem
                content = np.load(dict_file)
                pose = torch.from_numpy(content['pose'])
                betas = torch.from_numpy(content['betas'])
                del content
                params = torch.cat([pose, betas], dim=-1)
                self.fits_dict[ds_name] = params

    def save(self):
        """Save dictionary state to disk."""
        for ds_name in train_datasets:
            dict_file = os.path.join(save_dir, ds_name + '_fits.npy')
            np.save(dict_file, self.fits_dict[ds_name].cpu().numpy())

    def __getitem__(self, x):
        """Retrieve dictionary entries."""
        dataset_name, ind, rot, is_flipped = x
        batch_size = len(dataset_name)
        pose = torch.zeros((batch_size, 72))
        betas = torch.zeros((batch_size, 10))
        for ds, i, n in zip(dataset_name, ind, range(batch_size)):
            params = self.fits_dict[ds][i]
            pose[n, :] = params[:72]
            betas[n, :] = params[72:]
        pose = pose.clone()
        # Apply flipping and rotation
        pose = self.rotate_pose(self.flip_pose(pose, is_flipped), rot)
        betas = betas.clone()
        return pose, betas

    def __setitem__(self, x, val):
        """Update dictionary entries."""
        dataset_name, ind, rot, is_flipped, update = x
        pose, betas = val
        batch_size = len(dataset_name)
        # Undo flipping and rotation
        pose = self.flip_pose(self.rotate_pose(pose, -rot), is_flipped)
        params = torch.cat((pose, betas), dim=-1).cpu()
        for ds, i, n in zip(dataset_name, ind, range(batch_size)):
            if update[n]:
                self.fits_dict[ds][i] = params[n]

    def flip_pose(self, pose, is_flipped):
        """Flip SMPL pose parameters."""
        is_flipped = is_flipped.bool()
        pose_f = pose.clone()
        pose_f[is_flipped, :] = pose[is_flipped][:, self.flipped_parts]
        # we also negate the second and the third dimension of the
        # axis-angle representation
        pose_f[is_flipped, 1::3] *= -1
        pose_f[is_flipped, 2::3] *= -1
        return pose_f

    def rotate_pose(self, pose, rot):
        """Rotate the SMPL global orientation (pose[:, :3]) by rot degrees.

        The body pose (the remaining 69 parameters) is left unchanged.
        """
        pose = pose.clone()
        cos = torch.cos(-np.pi * rot / 180.)
        sin = torch.sin(-np.pi * rot / 180.)
        zeros = torch.zeros_like(cos)
        r3 = torch.zeros(cos.shape[0], 1, 3, device=cos.device)
        r3[:, 0, -1] = 1
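        # Assemble a batch of 3x3 rotation matrices about the camera z-axis
        # (angle -rot degrees, from the sign flip in cos/sin above), built
        # row by row: [cos, -sin, 0], [sin, cos, 0], [0, 0, 1].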
        R = torch.cat([
            torch.stack([cos, -sin, zeros], dim=-1).unsqueeze(1),
            torch.stack([sin, cos, zeros], dim=-1).unsqueeze(1),
            r3,
        ], dim=1)
        global_pose = pose[:, :3]
        global_pose_rotmat = R @ aa_to_rotmat(global_pose)
        global_pose_rotmat = global_pose_rotmat.cpu().numpy()
        global_pose_np = np.zeros((global_pose.shape[0], 3))
        for i in range(global_pose.shape[0]):
            aa, _ = cv2.Rodrigues(global_pose_rotmat[i])
            global_pose_np[i, :] = aa.squeeze()
        pose[:, :3] = torch.from_numpy(global_pose_np).to(pose.device)
        return pose
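

if __name__ == '__main__':
    # Minimal usage sketch (illustration only): the dataset key, indices and
    # augmentation values below are made up, and running it assumes that the
    # per-dataset fit files exist (e.g. data/static_fits/coco_fits.npy) and
    # that `save_dir` already exists on disk.
    fits_dict = FitsDict(fits='static')

    # A toy batch of two 'coco' samples: dataset keys, dataset indices,
    # in-plane rotation in degrees and a horizontal-flip flag per sample.
    dataset_name = ['coco', 'coco']
    ind = torch.tensor([0, 1])
    rot = torch.tensor([0., 30.])
    is_flipped = torch.tensor([0, 1])

    # Fetch the current best fits, flipped/rotated to match the augmentation.
    pose, betas = fits_dict[dataset_name, ind, rot, is_flipped]

    # After in-the-loop optimization, write back only the samples whose fit
    # improved (the update mask), then persist everything under `save_dir`.
    update = torch.tensor([True, False])
    fits_dict[dataset_name, ind, rot, is_flipped, update] = (pose, betas)
    fits_dict.save()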