AiOS / detrsmpl /data /datasets /human_image_smplx_dataset.py
ttxskk
update
d7e58f0
raw
history blame
17.8 kB
import os
import os.path
import pickle
from collections import OrderedDict
from typing import List, Optional, Union
import numpy as np
import torch
from detrsmpl.core.conventions.keypoints_mapping import (
get_keypoint_idx,
get_keypoint_idxs_by_part,
)
from detrsmpl.core.evaluation import fg_vertices_to_mesh_distance
from detrsmpl.utils.transforms import aa_to_rotmat
from .builder import DATASETS
from .human_image_dataset import HumanImageDataset
@DATASETS.register_module()
class HumanImageSMPLXDataset(HumanImageDataset):
    """Human image dataset with SMPL-X parameter annotations.

    On top of ``HumanImageDataset`` this class attaches per-sample SMPL-X
    parameters (plus ``has_smplx_*`` availability flags) to every item and
    adds part-wise keypoint / vertex evaluation and the 3DRMSE metric.

    Args:
        data_prefix (str): Forwarded to ``HumanImageDataset``.
        pipeline (list): Forwarded to ``HumanImageDataset``.
        dataset_name (str): Dataset name. ``'pw3d'``/``'3DPW'``/``'3dpw'``,
            ``'EHF'`` and ``'stirling'`` get dataset-specific handling
            during evaluation.
        body_model (dict, optional): Forwarded to ``HumanImageDataset``.
        ann_file (str, optional): Forwarded to ``HumanImageDataset``.
        convention (str, optional): Keypoint convention, also used to look
            up part keypoint indices. Default: ``'human_data'``.
        cache_data_path (str, optional): Forwarded to ``HumanImageDataset``.
        test_mode (bool, optional): Forwarded to ``HumanImageDataset``.
        num_betas (int, optional): Length of the zero ``smplx_betas`` filler
            used when a sample has no betas. Default: 10.
        num_expression (int, optional): Length of the zero
            ``smplx_expression`` filler. Default: 10.
        face_vertex_ids_path (str, optional): ``.npy`` file with face vertex
            indices; required for ``body_part='face'`` vertex evaluation.
        hand_vertex_ids_path (str, optional): Pickle file with
            ``'left_hand'`` / ``'right_hand'`` vertex indices; required for
            hand vertex evaluation.
    """
    # Metric names accepted by ``evaluate``.
    ALLOWED_METRICS = {
        'mpjpe', 'pa-mpjpe', 'pve', '3dpck', 'pa-3dpck', '3dauc', 'pa-3dauc',
        '3DRMSE', 'pa-pve'
    }

    # ``body_part`` name -> attribute holding that part's mesh vertex ids.
    # These attributes are only set in ``__init__`` when the file exists.
    _PART_VERTEX_ID_ATTRS = {
        'right_hand': 'right_hand_vertex_ids',
        'left_hand': 'left_hand_vertex_ids',
        'face': 'face_vertex_ids',
    }

    def __init__(
        self,
        data_prefix: str,
        pipeline: list,
        dataset_name: str,
        body_model: Optional[Union[dict, None]] = None,
        ann_file: Optional[Union[str, None]] = None,
        convention: Optional[str] = 'human_data',
        cache_data_path: Optional[Union[str, None]] = None,
        test_mode: Optional[bool] = False,
        num_betas: Optional[int] = 10,
        num_expression: Optional[int] = 10,
        face_vertex_ids_path: Optional[str] = None,
        hand_vertex_ids_path: Optional[str] = None,
    ):
        super().__init__(data_prefix, pipeline, dataset_name, body_model,
                         ann_file, convention, cache_data_path, test_mode)
        self.num_betas = num_betas
        self.num_expression = num_expression
        # Vertex id files are optional. A missing file is skipped here and
        # only reported if part-wise vertex evaluation needs it
        # (see ``_part_vertex_ids``).
        if face_vertex_ids_path is not None and os.path.exists(
                face_vertex_ids_path):
            self.face_vertex_ids = np.load(face_vertex_ids_path).astype(
                np.int32)
        if hand_vertex_ids_path is not None and os.path.exists(
                hand_vertex_ids_path):
            # NOTE: pickle can execute arbitrary code on load; only point
            # this option at trusted files.
            with open(hand_vertex_ids_path, 'rb') as f:
                vertex_idxs_data = pickle.load(f, encoding='latin1')
            self.left_hand_vertex_ids = vertex_idxs_data['left_hand']
            self.right_hand_vertex_ids = vertex_idxs_data['right_hand']

    def prepare_raw_data(self, idx: int):
        """Get item from self.human_data, augmented with SMPL-X parameters.

        ``info['has_smplx']`` is 1 when ``human_data`` has an ``'smplx'``
        entry, else 0. For each parameter name in ``param_shapes`` below,
        ``info['smplx_<name>']`` holds this sample's annotation when present.
        Otherwise it holds a float32 zero array of the listed shape.
        ``info['has_smplx_<name>']`` is set to 1 or 0 accordingly.
        """
        info = super().prepare_raw_data(idx)
        if self.cache_reader is not None:
            # Presumably the cache reader returns only the slice containing
            # ``idx``, so the index is made relative to that slice.
            self.human_data = self.cache_reader.get_item(idx)
            idx = idx % self.cache_reader.slice_size

        if 'smplx' in self.human_data:
            smplx_dict = self.human_data['smplx']
            info['has_smplx'] = 1
        else:
            smplx_dict = {}
            info['has_smplx'] = 0

        # (parameter name, shape of the zero filler when it is missing)
        param_shapes = (
            ('global_orient', (3, )),
            ('body_pose', (21, 3)),
            ('right_hand_pose', (15, 3)),
            ('left_hand_pose', (15, 3)),
            ('jaw_pose', (3, )),
            ('betas', (self.num_betas, )),
            ('expression', (self.num_expression, )),
        )
        for name, zero_shape in param_shapes:
            if name in smplx_dict:
                info[f'smplx_{name}'] = smplx_dict[name][idx]
                info[f'has_smplx_{name}'] = 1
            else:
                info[f'smplx_{name}'] = np.zeros(zero_shape, dtype=np.float32)
                info[f'has_smplx_{name}'] = 0
        return info

    def _part_vertex_ids(self, body_part):
        """Return the vertex ids for ``body_part``, or None for all vertices.

        Raises:
            AttributeError: If ``body_part`` needs vertex ids that were not
                loaded in ``__init__``.
        """
        attr = self._PART_VERTEX_ID_ATTRS.get(body_part)
        if attr is None:
            return None
        if not hasattr(self, attr):
            raise AttributeError(
                f'body_part={body_part!r} requires `{attr}`, which was not '
                'loaded; check face_vertex_ids_path / hand_vertex_ids_path.')
        return getattr(self, attr)

    @staticmethod
    def _to_body_model_param(key, value):
        """Convert one SMPL-X annotation to a float tensor for the body model.

        Pose entries (keys containing ``'pose'`` and ``'global_orient'``)
        that are not already rotation matrices are converted with
        ``aa_to_rotmat``.
        """
        tensor = torch.FloatTensor(value)
        is_pose = 'pose' in key or key == 'global_orient'
        # A rotation matrix ends in (3, 3) after a batch dimension. The rank
        # check keeps an (N=3, 3) axis-angle array from being mistaken for a
        # (3, 3) matrix.
        already_rotmat = (tensor.dim() >= 3
                          and tuple(tensor.shape[-2:]) == (3, 3))
        if is_pose and not already_rotmat:
            tensor = aa_to_rotmat(tensor)
        return tensor

    @staticmethod
    def _regress_joints(vertices, regressor):
        """Apply a (J, V) joint regressor to (B, V, K) numpy vertices."""
        return torch.einsum('bik,ji->bjk', [
            torch.from_numpy(vertices).float(),
            regressor,
        ]).numpy()

    @staticmethod
    def _center(keypoints3d, root_idxs):
        """Subtract the mean of the ``root_idxs`` joints from all joints."""
        root = keypoints3d[:, root_idxs, :].mean(axis=1, keepdims=True)
        return keypoints3d - root

    def _parse_result(self, res, mode='keypoint', body_part=''):
        """Build aligned prediction / ground-truth arrays for evaluation.

        Args:
            res (dict): ``{'keypoints': ..., 'vertices': ...}`` stacked over
                all samples, as assembled in ``evaluate``.
            mode (str): ``'vertice'`` for mesh vertices or ``'keypoint'``
                for 3D joints.
            body_part (str): Optional part selector. In ``'vertice'`` mode:
                ``'right_hand'``, ``'left_hand'`` or ``'face'``. In
                ``'keypoint'`` mode (EHF only): ``'J14'``, ``'right_hand'``,
                ``'left_hand'`` or ``'body'``. Anything else keeps all
                points.

        Returns:
            tuple: ``(pred, gt, mask)``. Predictions are always multiplied
            by 1000 (presumably metres to millimetres). Ground truth is
            multiplied by 1000 too, except for stored ``vertices`` of
            non-EHF datasets (vertex mode) and for ``'stirling'``
            (keypoint mode).

        Raises:
            ValueError: If ``mode`` is neither ``'vertice'`` nor
                ``'keypoint'``.
        """
        if mode == 'vertice':
            return self._parse_smplx_vertices(res, body_part)
        if mode == 'keypoint':
            return self._parse_smplx_keypoints(res, body_part)
        raise ValueError(f'Unsupported mode {mode!r}; expected '
                         "'vertice' or 'keypoint'.")

    def _parse_smplx_vertices(self, res, body_part):
        """Vertex branch of ``_parse_result``."""
        pred_vertices = res['vertices'] * 1000.
        if 'vertices' in self.human_data:  # stirling or ehf
            gt_vertices = self.human_data['vertices'].copy()
            if self.dataset_name == 'EHF':
                gt_vertices = gt_vertices * 1000.
        else:
            # Recover GT vertices by running the body model on the stored
            # SMPL-X parameters.
            gt_param_dict = {
                key: self._to_body_model_param(key, value)
                for key, value in self.human_data['smplx'].items()
            }
            gt_output = self.body_model(**gt_param_dict)
            gt_vertices = gt_output['vertices'].detach().cpu().numpy() * 1000.

        part_ids = self._part_vertex_ids(body_part)
        if part_ids is not None:
            pred_vertices = pred_vertices[:, part_ids]
            gt_vertices = gt_vertices[:, part_ids]

        gt_mask = np.ones(gt_vertices.shape[:-1])
        assert len(pred_vertices) == self.num_data
        return pred_vertices, gt_vertices, gt_mask

    def _pw3d_gt_keypoints3d(self):
        """Regress 3DPW ground-truth joints from its SMPL annotations."""
        smpl_dict = self.human_data['smpl']
        sample_ids = range(self.num_data)
        # Stack into arrays first; building tensors from a list of ndarrays
        # is slow.
        betas = np.stack([smpl_dict['betas'][i] for i in sample_ids])
        body_pose = np.stack([smpl_dict['body_pose'][i] for i in sample_ids])
        global_orient = np.stack(
            [smpl_dict['global_orient'][i] for i in sample_ids])
        # Gender codes: 'm' -> 0, anything else -> 1.
        genders = self.human_data['meta']['gender']
        gender = [0 if genders[i] == 'm' else 1 for i in sample_ids]

        betas_t = torch.as_tensor(betas, dtype=torch.float32)
        body_pose_t = torch.as_tensor(body_pose,
                                      dtype=torch.float32).view(-1, 69)
        global_orient_t = torch.as_tensor(global_orient, dtype=torch.float32)
        gt_output = self.body_model(betas=betas_t,
                                    body_pose=body_pose_t,
                                    global_orient=global_orient_t,
                                    gender=torch.Tensor(gender))
        return gt_output['joints'].detach().cpu().numpy()

    def _ehf_keypoints3d(self, pred_vertices, body_part):
        """Regress EHF predicted and GT joints from mesh vertices.

        Returns:
            tuple: ``(pred_keypoints3d, gt_keypoints3d)``.
        """
        gt_vertices = self.human_data['vertices'].copy()
        if body_part == 'J14':
            regressor = self.body_model.joints_regressor
            return (self._regress_joints(pred_vertices, regressor),
                    self._regress_joints(gt_vertices, regressor))

        # Vertex ids whose positions are appended to the J_regressor joints.
        extra_joints_idxs = np.array([
            9120, 9929, 9448, 616, 6, 5770, 5780, 8846, 8463, 8474, 8635,
            5361, 4933, 5058, 5169, 5286, 8079, 7669, 7794, 7905, 8022
        ])
        regressor = self.body_model.J_regressor
        gt_keypoints3d = np.concatenate(
            (self._regress_joints(gt_vertices, regressor),
             gt_vertices[:, extra_joints_idxs]),
            axis=1)
        pred_keypoints3d = np.concatenate(
            (self._regress_joints(pred_vertices, regressor),
             pred_vertices[:, extra_joints_idxs]),
            axis=1)

        if body_part in ('right_hand', 'left_hand'):
            # Copy so the appended wrist never mutates a shared list.
            idxs = list(get_keypoint_idxs_by_part(body_part, self.convention))
            idxs.append(
                get_keypoint_idx(body_part.replace('hand', 'wrist'),
                                 self.convention))
        elif body_part == 'body':
            idxs = get_keypoint_idxs_by_part('body', self.convention)
        else:
            idxs = list(range(gt_keypoints3d.shape[1]))
        return pred_keypoints3d[:, idxs], gt_keypoints3d[:, idxs]

    def _parse_smplx_keypoints(self, res, body_part):
        """Keypoint branch of ``_parse_result``."""
        pred_keypoints3d = res['keypoints']
        assert len(pred_keypoints3d) == self.num_data
        if self.dataset_name in {'pw3d', '3DPW', '3dpw'}:
            gt_keypoints3d = self._pw3d_gt_keypoints3d()
        elif self.dataset_name == 'EHF':
            pred_keypoints3d, gt_keypoints3d = self._ehf_keypoints3d(
                res['vertices'], body_part)
        else:
            gt_keypoints3d = self.human_data['keypoints3d'][:, :, :3]
        gt_keypoints3d_mask = np.ones(
            (len(pred_keypoints3d), gt_keypoints3d.shape[1]))

        num_gt_joints = gt_keypoints3d.shape[1]
        if num_gt_joints == 17:
            # SMPLX_to_J14: predictions are expected in J14; map the 17 GT
            # joints to the same 14 and centre both on joints 2 and 3.
            assert pred_keypoints3d.shape[1] == 14
            H36M_TO_J17 = [
                6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9
            ]
            joint_mapper = H36M_TO_J17[:14]
            gt_keypoints3d = gt_keypoints3d[:, joint_mapper, :]
            gt_keypoints3d_mask = gt_keypoints3d_mask[:, joint_mapper]
            pred_keypoints3d = self._center(pred_keypoints3d, [2, 3])
            gt_keypoints3d = self._center(gt_keypoints3d, [2, 3])
        elif num_gt_joints == 14:
            assert pred_keypoints3d.shape[1] == 14
            pred_keypoints3d = self._center(pred_keypoints3d, [2, 3])
            gt_keypoints3d = self._center(gt_keypoints3d, [2, 3])
        elif num_gt_joints == 21:
            pred_keypoints3d = self._center(pred_keypoints3d, [0])
            gt_keypoints3d = self._center(gt_keypoints3d, [0])

        pred_keypoints3d = pred_keypoints3d * 1000
        if self.dataset_name != 'stirling':
            gt_keypoints3d = gt_keypoints3d * 1000
        gt_keypoints3d_mask = gt_keypoints3d_mask > 0
        return pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask

    def _report_3d_rmse(self, res_file):
        """Compute the 3DRMSE between predicted 3D face shapes and the 3D
        ground-truth scans.

        For each sample, ``fg_vertices_to_mesh_distance`` is called with the
        GT vertices/keypoints and the predicted vertices, the body model's
        faces and the predicted keypoints. The reported value is the mean
        over all samples.

        Returns:
            list[tuple]: ``[('3DRMSE', mean_error)]``.
        """
        pred_vertices, gt_vertices, _ = self._parse_result(res_file,
                                                           mode='vertice')
        pred_keypoints3d, gt_keypoints3d, _ = self._parse_result(
            res_file, mode='keypoint')
        errors = [
            fg_vertices_to_mesh_distance(gt_vertice, gt_points, pred_vertice,
                                         self.body_model.faces, pred_points)
            for pred_vertice, gt_vertice, pred_points, gt_points in zip(
                pred_vertices, gt_vertices, pred_keypoints3d, gt_keypoints3d)
        ]
        return [('3DRMSE', np.array(errors).mean())]

    def evaluate(self,
                 outputs: list,
                 res_folder: str,
                 metric: Optional[Union[str, List[str]]] = 'pa-mpjpe',
                 **kwargs: dict):
        """Evaluate 3D keypoint results.

        Args:
            outputs (list): results from model inference. Each entry is read
                for ``'image_idx'``, ``'keypoints_3d'`` and ``'vertices'``.
            res_folder (str): path to store results (not used here).
            metric (Optional[Union[str, List(str)]]):
                the type of metric. Default: 'pa-mpjpe'
            kwargs (dict): other arguments. If ``body_part`` is given,
                ``kwargs['body_part'][i]`` lists the body parts evaluated
                for the i-th metric. Only 'pa-mpjpe' and 'pa-pve' support
                this.

        Returns:
            dict:
                A dict of all evaluation results.

        Raises:
            KeyError: If a metric is not in ``ALLOWED_METRICS``.
            NotImplementedError: If a metric has no reporter on the
                selected path.
        """
        metrics = metric if isinstance(metric, list) else [metric]
        for _metric in metrics:
            if _metric not in self.ALLOWED_METRICS:
                raise KeyError(f'metric {_metric} is not supported')

        # For correctness during multi-GPU testing, key every result by its
        # image index and then read them back in dataset order.
        res_dict = {}
        for out in outputs:
            target_id = out['image_idx']
            for i in range(len(out['keypoints_3d'])):
                res_dict[int(target_id[i])] = dict(
                    keypoints=out['keypoints_3d'][i],
                    vertices=out['vertices'][i],
                )
        ordered = [res_dict[i] for i in range(self.num_data)]
        res = dict(
            keypoints=np.stack([r['keypoints'] for r in ordered]),
            vertices=np.stack([r['vertices'] for r in ordered]),
        )

        name_value_tuples = []
        for index, _metric in enumerate(metrics):
            if 'body_part' in kwargs:
                for body_part in kwargs['body_part'][index]:
                    if _metric == 'pa-mpjpe':
                        _nv_tuples = self._report_mpjpe(res,
                                                        metric='pa-mpjpe',
                                                        body_part=body_part)
                    elif _metric == 'pa-pve':
                        _nv_tuples = self._report_pve(res,
                                                      metric='pa-pve',
                                                      body_part=body_part)
                    else:
                        raise NotImplementedError
                    name_value_tuples.extend(_nv_tuples)
            else:
                if _metric == 'mpjpe':
                    _nv_tuples = self._report_mpjpe(res)
                elif _metric == 'pa-mpjpe':
                    _nv_tuples = self._report_mpjpe(res, metric='pa-mpjpe')
                elif _metric == '3dpck':
                    _nv_tuples = self._report_3d_pck(res)
                elif _metric == 'pa-3dpck':
                    _nv_tuples = self._report_3d_pck(res, metric='pa-3dpck')
                elif _metric == '3dauc':
                    _nv_tuples = self._report_3d_auc(res)
                elif _metric == 'pa-3dauc':
                    _nv_tuples = self._report_3d_auc(res, metric='pa-3dauc')
                elif _metric == 'pve':
                    _nv_tuples = self._report_pve(res)
                elif _metric == 'pa-pve':
                    _nv_tuples = self._report_pve(res, metric='pa-pve')
                elif _metric == '3DRMSE':
                    _nv_tuples = self._report_3d_rmse(res)
                else:
                    raise NotImplementedError
                name_value_tuples.extend(_nv_tuples)
        return OrderedDict(name_value_tuples)