import json | |
import os | |
import os.path | |
from abc import ABCMeta | |
from collections import OrderedDict | |
from typing import Any, List, Optional, Union | |
import mmcv | |
import numpy as np | |
import torch | |
import torch.distributed as dist | |
from mmcv.runner import get_dist_info | |
from detrsmpl.core.conventions.keypoints_mapping import ( | |
convert_kps, | |
get_keypoint_num, | |
get_mapping, | |
) | |
from detrsmpl.core.evaluation import ( | |
keypoint_3d_auc, | |
keypoint_3d_pck, | |
keypoint_mpjpe, | |
vertice_pve, | |
) | |
from detrsmpl.data.data_structures.multi_human_data import MultiHumanData | |
from detrsmpl.models.body_models.builder import build_body_model | |
from .base_dataset import BaseDataset | |
from .builder import DATASETS | |
class MultiHumanImageDataset(BaseDataset, metaclass=ABCMeta):
    """Image dataset with multiple human instances per frame.

    Loads a :class:`MultiHumanData` annotation file, converts all keypoint
    fields to the requested convention, and serves per-frame samples for
    training and evaluation.
    """

    def __init__(self,
                 data_prefix: str,
                 pipeline: list,
                 body_model: Optional[Union[dict, None]] = None,
                 ann_file: Optional[Union[str, None]] = None,
                 convention: Optional[str] = 'human_data',
                 test_mode: Optional[bool] = False,
                 dataset_name: Optional[Union[str, None]] = None):
        """Initialize the dataset.

        Args:
            data_prefix (str): Root directory of the dataset.
            pipeline (list): Transform configs applied to each sample.
            body_model (dict | None): Config used to build a body model for
                evaluation; ``None`` disables it.
            ann_file (str | None): Annotation file name, later resolved
                under ``<data_prefix>/preprocessed_datasets``.
            convention (str): Target keypoint convention.
            test_mode (bool): Whether the dataset is used for testing.
            dataset_name (str | None): Canonical name of the dataset.
        """
        # Resolve the keypoint count before the base class constructor,
        # which may trigger annotation loading that relies on it.
        self.num_keypoints = get_keypoint_num(convention)
        self.convention = convention
        super().__init__(data_prefix, pipeline, ann_file, test_mode,
                         dataset_name)
        # Optional body model used by the evaluation helpers.
        self.body_model = build_body_model(
            body_model) if body_model is not None else None
def get_annotation_file(self): | |
"""Get path of the annotation file.""" | |
ann_prefix = os.path.join(self.data_prefix, 'preprocessed_datasets') | |
self.ann_file = os.path.join(ann_prefix, self.ann_file) | |
def load_annotations(self): | |
"""Load annotations.""" | |
self.get_annotation_file() | |
self.human_data = MultiHumanData() | |
self.human_data.load(self.ann_file) | |
self.instance_num = self.human_data.instance_num | |
self.image_path = self.human_data['image_path'] | |
self.num_data = self.human_data.data_len | |
try: | |
self.frame_range = self.human_data['frame_range'] | |
except KeyError: | |
self.frame_range = \ | |
np.array([[i, i + 1] for i in range(self.num_data)]) | |
self.num_data = self.frame_range.shape[0] | |
if self.human_data.check_keypoints_compressed(): | |
self.human_data.decompress_keypoints() | |
# change keypoint from 'human_data' to the given convention | |
if 'keypoints3d_ori' in self.human_data: | |
keypoints3d_ori = self.human_data['keypoints3d_ori'] | |
assert 'keypoints3d_ori_mask' in self.human_data | |
keypoints3d_ori_mask = self.human_data['keypoints3d_ori_mask'] | |
keypoints3d_ori, keypoints3d_ori_mask = \ | |
convert_kps( | |
keypoints3d_ori, | |
src='human_data', | |
dst=self.convention, | |
mask=keypoints3d_ori_mask) | |
self.human_data.__setitem__('keypoints3d_ori', keypoints3d_ori) | |
self.human_data.__setitem__('keypoints3d_ori_convention', | |
self.convention) | |
self.human_data.__setitem__('keypoints3d_ori_mask', | |
keypoints3d_ori_mask) | |
elif 'keypoints3d' in self.human_data: | |
keypoints3d_ori = self.human_data['keypoints3d'] | |
assert 'keypoints3d_mask' in self.human_data | |
keypoints3d_ori_mask = self.human_data['keypoints3d_mask'] | |
keypoints3d_ori, keypoints3d_ori_mask = \ | |
convert_kps( | |
keypoints3d_ori, | |
src='human_data', | |
dst=self.convention, | |
mask=keypoints3d_ori_mask) | |
self.human_data.__setitem__('keypoints3d_ori', keypoints3d_ori) | |
self.human_data.__setitem__('keypoints3d_ori_convention', | |
self.convention) | |
self.human_data.__setitem__('keypoints3d_ori_mask', | |
keypoints3d_ori_mask) | |
if 'keypoints2d_ori' in self.human_data: | |
keypoints2d_ori = self.human_data['keypoints2d_ori'] | |
assert 'keypoints2d_ori_mask' in self.human_data | |
keypoints2d_ori_mask = self.human_data['keypoints2d_ori_mask'] | |
keypoints2d_ori, keypoints2d_ori_mask = \ | |
convert_kps( | |
keypoints2d_ori, | |
src='human_data', | |
dst=self.convention, | |
mask=keypoints2d_ori_mask) | |
self.human_data.__setitem__('keypoints2d_ori', keypoints2d_ori) | |
self.human_data.__setitem__('keypoints2d_ori_convention', | |
self.convention) | |
self.human_data.__setitem__('keypoints2d_ori_mask', | |
keypoints2d_ori_mask) | |
ori_mask = keypoints2d_ori[:, :, 2] | |
elif 'keypoints2d' in self.human_data: | |
keypoints2d_ori = self.human_data['keypoints2d'] | |
assert 'keypoints2d_mask' in self.human_data | |
keypoints2d_ori_mask = self.human_data['keypoints2d_mask'] | |
keypoints2d_ori, keypoints2d_ori_mask = \ | |
convert_kps( | |
keypoints2d_ori, | |
src='human_data', | |
dst=self.convention, | |
mask=keypoints2d_ori_mask) | |
self.human_data.__setitem__('keypoints2d_ori', keypoints2d_ori) | |
self.human_data.__setitem__('keypoints2d_ori_convention', | |
self.convention) | |
self.human_data.__setitem__('keypoints2d_ori_mask', | |
keypoints2d_ori_mask) | |
# if 'has_smpl' in self.human_data: | |
# index = ori_mask.sum(-1)>=8 | |
# self.human_data['has_smpl']=self.human_data['has_smpl'][:147270]*index | |
# change keypoint from 'human_data' to the given convention | |
if 'keypoints3d_smpl' in self.human_data: | |
keypoints3d_smpl = self.human_data['keypoints3d_smpl'] | |
assert 'keypoints3d_smpl_mask' in self.human_data | |
keypoints3d_smpl_mask = self.human_data['keypoints3d_smpl_mask'] | |
keypoints3d_smpl, keypoints3d_smpl_mask = \ | |
convert_kps( | |
keypoints3d_smpl, | |
src='human_data', | |
dst=self.convention, | |
mask=keypoints3d_smpl_mask) | |
# index = ori_mask.sum(-1)<8 | |
# index = ori_mask.sum(-1)<8 | |
# keypoints3d_smpl[index]=np.concatenate( | |
# [keypoints3d_smpl[index][:,:,:3], | |
# keypoints2d_ori[index][:,:,[2]]], | |
# -1) | |
self.human_data.__setitem__('keypoints3d_smpl', keypoints3d_smpl) | |
self.human_data.__setitem__('keypoints3d_smpl_convention', | |
self.convention) | |
self.human_data.__setitem__('keypoints3d_smpl_mask', | |
keypoints3d_smpl_mask) | |
if 'keypoints2d_smpl' in self.human_data: | |
keypoints2d_smpl = self.human_data['keypoints2d_smpl'] | |
assert 'keypoints2d_smpl_mask' in self.human_data | |
keypoints2d_smpl_mask = self.human_data['keypoints2d_smpl_mask'] | |
keypoints2d_smpl, keypoints2d_smpl_mask = \ | |
convert_kps( | |
keypoints2d_smpl, | |
src='human_data', | |
dst=self.convention, | |
mask=keypoints2d_smpl_mask) | |
# index = ori_mask.sum(-1)<8 | |
# keypoints2d_smpl[index]=np.concatenate( | |
# [keypoints2d_smpl[index][:,:,:2], | |
# keypoints2d_ori[index][:,:,[2]]], | |
# -1) | |
# keypoints2d_smpl[index][:,:,2]=keypoints2d_ori[index][:, :,2] | |
self.human_data.__setitem__('keypoints2d_smpl', keypoints2d_smpl) | |
self.human_data.__setitem__('keypoints2d_smpl_convention', | |
self.convention) | |
self.human_data.__setitem__('keypoints2d_smpl_mask', | |
keypoints2d_smpl_mask) | |
self.human_data.compress_keypoints_by_mask() | |
def prepare_raw_data(self, idx: int): | |
"""Get item from self.human_data.""" | |
sample_idx = idx | |
frame_start, frame_end = self.frame_range[idx] | |
frame_num = frame_end - frame_start | |
# TODO: Support cache_reader? | |
info = {} | |
info['img_prefix'] = None | |
image_path = self.human_data['image_path'][frame_start] | |
info['image_path'] = os.path.join(self.data_prefix, 'datasets', | |
self.dataset_name, image_path) | |
# TODO: Support smc? | |
info['dataset_name'] = self.dataset_name | |
info['sample_idx'] = sample_idx | |
if 'bbox_xywh' in self.human_data: | |
info['bbox_xywh'] = self.human_data['bbox_xywh'][ | |
frame_start:frame_end] | |
center, scale = [], [] | |
for bbox in info['bbox_xywh']: | |
x, y, w, h, s = bbox | |
cx = x + w / 2 | |
cy = y + h / 2 | |
# TODO: verify if we should keep w = h = max(w, h) for multi human data | |
w = h = max(w, h) | |
center.append([cx, cy]) | |
scale.append([w, h]) | |
info['center'] = np.array(center) | |
info['scale'] = np.array(scale) | |
else: | |
info['bbox_xywh'] = np.zeros((frame_num, 5)) | |
info['center'] = np.zeros((frame_num, 2)) | |
info['scale'] = np.zeros((frame_num, 2)) | |
if 'keypoints2d_ori' in self.human_data: | |
info['keypoints2d_ori'] = self.human_data['keypoints2d_ori'][ | |
frame_start:frame_end] | |
conf = info['keypoints2d_ori'][..., -1].sum(-1) > 0 | |
info['has_keypoints2d_ori'] = np.ones( | |
(frame_num, 1)) * conf[..., None] | |
else: | |
info['keypoints2d_ori'] = np.zeros( | |
(frame_num, self.num_keypoints, 3)) | |
info['has_keypoints2d_ori'] = np.zeros((frame_num, 1)) | |
if 'keypoints3d_ori' in self.human_data: | |
info['keypoints3d_ori'] = self.human_data['keypoints3d_ori'][ | |
frame_start:frame_end] | |
conf = info['keypoints3d_ori'][..., -1].sum(-1) > 0 | |
info['has_keypoints3d_ori'] = np.ones( | |
(frame_num, 1)) * conf[..., None] | |
else: | |
info['keypoints3d_ori'] = np.zeros( | |
(frame_num, self.num_keypoints, 4)) | |
info['has_keypoints3d_ori'] = np.zeros((frame_num, 1)) | |
if 'keypoints2d_smpl' in self.human_data: | |
info['keypoints2d_smpl'] = self.human_data['keypoints2d_smpl'][ | |
frame_start:frame_end] | |
conf = info['keypoints2d_smpl'][..., -1].sum(-1) > 0 | |
info['has_keypoints2d_smpl'] = np.ones( | |
(frame_num, 1)) * conf[..., None] | |
else: | |
info['keypoints2d_smpl'] = np.zeros( | |
(frame_num, self.num_keypoints, 3)) | |
info['has_keypoints2d_smpl'] = np.zeros((frame_num, 1)) | |
if 'keypoints3d_smpl' in self.human_data: | |
info['keypoints3d_smpl'] = self.human_data['keypoints3d_smpl'][ | |
frame_start:frame_end] | |
conf = info['keypoints3d_smpl'][..., -1].sum(-1) > 0 | |
info['has_keypoints3d_smpl'] = np.ones( | |
(frame_num, 1)) * conf[..., None] | |
else: | |
info['keypoints3d_smpl'] = np.zeros( | |
(frame_num, self.num_keypoints, 4)) | |
info['has_keypoints3d_smpl'] = np.zeros((frame_num, 1)) | |
if 'smpl' in self.human_data: | |
if 'has_smpl' in self.human_data: | |
info['has_smpl'] = \ | |
self.human_data['has_smpl'][frame_start:frame_end] | |
else: | |
info['has_smpl'] = np.ones((frame_num, 1)) | |
smpl_dict = self.human_data['smpl'] | |
else: | |
info['has_smpl'] = np.zeros((frame_num, 1)) | |
smpl_dict = {} | |
if 'body_pose' in smpl_dict: | |
info['smpl_body_pose'] = smpl_dict['body_pose'][ | |
frame_start:frame_end] | |
else: | |
info['smpl_body_pose'] = np.zeros((frame_num, 23, 3)) | |
if 'global_orient' in smpl_dict: | |
info['smpl_global_orient'] = smpl_dict['global_orient'][ | |
frame_start:frame_end] | |
else: | |
info['smpl_global_orient'] = np.zeros((frame_num, 3)) | |
if 'betas' in smpl_dict: | |
info['smpl_betas'] = smpl_dict['betas'][frame_start:frame_end] | |
else: | |
info['smpl_betas'] = np.zeros((frame_num, 10)) | |
if 'transl' in smpl_dict: | |
info['smpl_transl'] = smpl_dict['transl'][frame_start:frame_end] | |
else: | |
info['smpl_transl'] = np.zeros((frame_num, 3)) | |
if 'area' in self.human_data: | |
info['area'] = self.human_data['area'][frame_start:frame_end] | |
else: | |
info['area'] = np.zeros((frame_num, 0)) | |
return info | |
def prepare_data(self, idx: int): | |
"""Generate and transform data.""" | |
info = self.prepare_raw_data(idx) | |
return self.pipeline(info) | |
def evaluate(self, | |
outputs: list, | |
res_folder: str, | |
metric: Optional[Union[str, List[str]]] = 'pa-mpjpe', | |
**kwargs: dict): | |
"""Evaluate 3D keypoint results. | |
Args: | |
outputs (list): results from model inference. | |
res_folder (str): path to store results. | |
metric (Optional[Union[str, List(str)]]): | |
the type of metric. Default: 'pa-mpjpe' | |
kwargs (dict): other arguments. | |
Returns: | |
dict: | |
A dict of all evaluation results. | |
""" | |
metrics = metric if isinstance(metric, list) else [metric] | |
for metric in metrics: | |
if metric not in self.ALLOWED_METRICS: | |
raise KeyError(f'metric {metric} is not supported') | |
res_file = os.path.join(res_folder, 'result_keypoints.json') | |
# for keeping correctness during multi-gpu test, we sort all results | |
res_dict = {} | |
# 'scores', 'labels', 'boxes', 'keypoints', 'pred_smpl_pose', | |
# 'pred_smpl_beta', 'pred_smpl_cam', 'pred_smpl_kp3d', | |
# 'gt_smpl_pose', 'gt_smpl_beta', 'gt_smpl_kp3d', 'gt_boxes', | |
# 'gt_keypoints', 'image_idx' | |
for out in outputs: | |
target_id = out['image_idx'] | |
batch_size = len(out['pred_smpl_kp3d']) | |
for i in range(batch_size): | |
res_dict[int(target_id[i])] = dict( | |
keypoints=out['pred_smpl_kp3d'][i], | |
gt_poses=out['gt_smpl_pose'][i], | |
gt_betas=out['gt_smpl_beta'][i], | |
pred_poses=out['pred_smpl_pose'][i], | |
pred_betas=out['pred_smpl_beta'][i]) | |
keypoints, gt_poses, gt_betas, pred_poses, pred_betas = \ | |
[], [], [], [], [] | |
# print(self.num_data) | |
for i in range(self.num_data): | |
keypoints.append(res_dict[i]['keypoints']) | |
gt_poses.append(res_dict[i]['gt_poses']) | |
gt_betas.append(res_dict[i]['gt_betas']) | |
pred_poses.append(res_dict[i]['pred_poses']) | |
pred_betas.append(res_dict[i]['pred_betas']) | |
res = dict(keypoints=keypoints, | |
gt_poses=gt_poses, | |
gt_betas=gt_betas, | |
pred_poses=pred_poses, | |
pred_betas=pred_betas) | |
# mmcv.dump(res, res_file) | |
name_value_tuples = [] | |
for _metric in metrics: | |
if _metric == 'mpjpe': | |
_nv_tuples = self._report_mpjpe(res) | |
elif _metric == 'pa-mpjpe': | |
_nv_tuples = self._report_mpjpe(res, metric='pa-mpjpe') | |
print(_nv_tuples) | |
elif _metric == '3dpck': | |
_nv_tuples = self._report_3d_pck(res) | |
elif _metric == 'pa-3dpck': | |
_nv_tuples = self._report_3d_pck(res, metric='pa-3dpck') | |
elif _metric == '3dauc': | |
_nv_tuples = self._report_3d_auc(res) | |
elif _metric == 'pa-3dauc': | |
_nv_tuples = self._report_3d_auc(res, metric='pa-3dauc') | |
elif _metric == 'pve': | |
_nv_tuples = self._report_pve(res) | |
elif _metric == 'ihmr': | |
_nv_tuples = self._report_ihmr(res) | |
else: | |
raise NotImplementedError | |
name_value_tuples.extend(_nv_tuples) | |
name_value = OrderedDict(name_value_tuples) | |
return name_value | |
def _write_keypoint_results(keypoints: Any, res_file: str): | |
"""Write results into a json file.""" | |
with open(res_file, 'w') as f: | |
json.dump(keypoints, f, sort_keys=True, indent=4) | |
    def _parse_result(self, res, mode='keypoint', body_part=None):
        """Parse collected results into aligned prediction/GT arrays.

        Args:
            res (dict): Results assembled by ``evaluate`` with keys
                'keypoints', 'gt_poses', 'gt_betas', 'pred_poses' and
                'pred_betas'.
            mode (str): 'keypoint' to compare 3D joints, 'vertice' to
                compare body-model mesh vertices.
            body_part: Accepted for signature compatibility; not used in
                this implementation.

        Returns:
            tuple: ``(pred, gt, gt_mask)`` arrays, in millimeters.
        """
        if mode == 'vertice':
            # gt: rebuild ground-truth vertices by running the body model
            # on the SMPL parameters stored in the annotation file.
            gt_beta, gt_pose, gt_global_orient, gender = [], [], [], []
            gt_smpl_dict = self.human_data['smpl']
            for idx in range(self.num_data):
                gt_beta.append(gt_smpl_dict['betas'][idx])
                gt_pose.append(gt_smpl_dict['body_pose'][idx])
                gt_global_orient.append(gt_smpl_dict['global_orient'][idx])
                # Gender encoding: 0 for 'm', 1 for anything else.
                if self.human_data['meta']['gender'][idx] == 'm':
                    gender.append(0)
                else:
                    gender.append(1)
            gt_beta = torch.FloatTensor(gt_beta)
            gt_pose = torch.FloatTensor(gt_pose).view(-1, 69)
            gt_global_orient = torch.FloatTensor(gt_global_orient)
            gender = torch.Tensor(gender)
            gt_output = self.body_model(betas=gt_beta,
                                        body_pose=gt_pose,
                                        global_orient=gt_global_orient,
                                        gender=gender)
            # Convert meters -> millimeters.
            gt_vertices = gt_output['vertices'].detach().cpu().numpy() * 1000.
            gt_mask = np.ones(gt_vertices.shape[:-1])
            # pred: only index 0 along the second axis is evaluated
            # (presumably the top-ranked instance per sample — TODO confirm).
            pred_pose = torch.FloatTensor(res['pred_poses'])
            pred_beta = torch.FloatTensor(res['pred_betas'])
            pred_output = self.body_model(
                betas=pred_beta[:, 0],
                body_pose=pred_pose[:, 0, 1:],
                global_orient=pred_pose[:, 0, 0].unsqueeze(1),
                pose2rot=False)
            pred_vertices = pred_output['vertices'].detach().cpu().numpy(
            ) * 1000.
            assert len(pred_vertices) == self.num_data
            return pred_vertices, gt_vertices, gt_mask
        elif mode == 'keypoint':
            pred_keypoints3d = res['keypoints']
            assert len(pred_keypoints3d) == self.num_data
            # (B, 17, 3)
            pred_keypoints3d = np.array(pred_keypoints3d).reshape(
                len(pred_keypoints3d), -1, 3)
            # pred_keypoints3d,_ = convert_kps(
            #     pred_keypoints3d,
            #     src='smpl_54',
            #     dst='h36m',
            # )
            # Split pose into body pose (joints 1..) and global orient
            # (joint 0).
            gt_smpl_pose = np.array(res['gt_poses'])
            gt_body_pose = gt_smpl_pose[..., 1:, :]
            gt_global_orient = gt_smpl_pose[..., 0, :]
            gt_betas = np.array(res['gt_betas'])
            gender = np.zeros([gt_betas.shape[0], gt_betas.shape[1]])
            # GT joints are dataset-specific.
            if self.dataset_name == 'pw3d':
                # betas = []
                # body_pose = []
                # global_orient = []
                # gender = []
                # smpl_dict = self.human_data['smpl']
                # for idx in range(self.num_data):
                #     betas.append(smpl_dict['betas'][idx])
                #     body_pose.append(smpl_dict['body_pose'][idx])
                #     global_orient.append(smpl_dict['global_orient'][idx])
                #     if self.human_data['meta']['gender'][idx] == 'm':
                #         gender.append(0)
                #     else:
                #         gender.append(1)
                # NOTE(review): `gender` is all zeros here, so every sample
                # is treated identically by a gendered body model — confirm
                # this is intended.
                betas = torch.FloatTensor(gt_betas).view(-1, 10)
                body_pose = torch.FloatTensor(gt_body_pose).view(-1, 69)
                global_orient = torch.FloatTensor(gt_global_orient).view(-1, 3)
                gender = torch.Tensor(gender).view(-1)
                gt_output = self.body_model(betas=betas,
                                            body_pose=body_pose,
                                            global_orient=global_orient,
                                            gender=gender)
                gt_keypoints3d = gt_output['joints'].detach().cpu().numpy()
                # gt_keypoints3d,_ = convert_kps(
                #     gt_keypoints3d,
                #     src='smpl_54',
                #     dst='h36m')
                gt_keypoints3d_mask = np.ones((len(pred_keypoints3d), 17))
            elif self.dataset_name == 'h36m':
                # Select H36M joints directly from the stored 3D keypoints.
                _, h36m_idxs, _ = get_mapping('human_data', 'h36m')
                gt_keypoints3d = \
                    self.human_data['keypoints3d'][:, h36m_idxs, :3]
                gt_keypoints3d_mask = np.ones((len(pred_keypoints3d), 17))
            elif self.dataset_name == 'humman':
                # Rebuild GT joints from stored SMPL parameters (no gender).
                betas = []
                body_pose = []
                global_orient = []
                smpl_dict = self.human_data['smpl']
                for idx in range(self.num_data):
                    betas.append(smpl_dict['betas'][idx])
                    body_pose.append(smpl_dict['body_pose'][idx])
                    global_orient.append(smpl_dict['global_orient'][idx])
                betas = torch.FloatTensor(betas)
                body_pose = torch.FloatTensor(body_pose).view(-1, 69)
                global_orient = torch.FloatTensor(global_orient)
                gt_output = self.body_model(betas=betas,
                                            body_pose=body_pose,
                                            global_orient=global_orient)
                gt_keypoints3d = gt_output['joints'].detach().cpu().numpy()
                gt_keypoints3d_mask = np.ones((len(pred_keypoints3d), 24))
            else:
                raise NotImplementedError()

            # Dispatch on the GT joint layout to select the 14/17 evaluated
            # joints and the pelvis used for root alignment.
            # SMPL_49 only!
            if gt_keypoints3d.shape[1] == 49:
                assert pred_keypoints3d.shape[1] == 49
                gt_keypoints3d = gt_keypoints3d[:, 25:, :]
                pred_keypoints3d = pred_keypoints3d[:, 25:, :]
                joint_mapper = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18]
                gt_keypoints3d = gt_keypoints3d[:, joint_mapper, :]
                pred_keypoints3d = pred_keypoints3d[:, joint_mapper, :]
                # we only evaluate on 14 lsp joints
                # Pelvis is the midpoint of the two hip joints.
                pred_pelvis = (pred_keypoints3d[:, 2] +
                               pred_keypoints3d[:, 3]) / 2
                gt_pelvis = (gt_keypoints3d[:, 2] + gt_keypoints3d[:, 3]) / 2
            # H36M for testing!
            elif gt_keypoints3d.shape[1] == 17:
                assert pred_keypoints3d.shape[-2] == 17
                H36M_TO_J17 = [
                    6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9
                ]
                H36M_TO_J14 = H36M_TO_J17[:14]
                joint_mapper = H36M_TO_J14
                # Pelvis is joint 0 in the H36M layout.
                pred_pelvis = pred_keypoints3d[:, 0]
                gt_pelvis = gt_keypoints3d[:, 0]
                gt_keypoints3d = gt_keypoints3d[:, joint_mapper, :]
                pred_keypoints3d = pred_keypoints3d[:, joint_mapper, :]
            # keypoint 24
            elif gt_keypoints3d.shape[1] == 24:
                assert pred_keypoints3d.shape[1] == 24
                joint_mapper = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18]
                gt_keypoints3d = gt_keypoints3d[:, joint_mapper, :]
                pred_keypoints3d = pred_keypoints3d[:, joint_mapper, :]
                # we only evaluate on 14 lsp joints
                pred_pelvis = (pred_keypoints3d[:, 2] +
                               pred_keypoints3d[:, 3]) / 2
                gt_pelvis = (gt_keypoints3d[:, 2] + gt_keypoints3d[:, 3]) / 2
            else:
                # NOTE(review): if no layout above matches, joint_mapper,
                # pred_pelvis and gt_pelvis are left unbound and the lines
                # below raise NameError — confirm this branch is unreachable.
                pass

            # Root-align on the pelvis and convert to millimeters.
            pred_keypoints3d = (pred_keypoints3d -
                                pred_pelvis[:, None, :]) * 1000
            gt_keypoints3d = (gt_keypoints3d - gt_pelvis[:, None, :]) * 1000

            gt_keypoints3d_mask = gt_keypoints3d_mask[:, joint_mapper] > 0
            return pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask
def _report_mpjpe(self, res_file, metric='mpjpe', body_part=''): | |
"""Cauculate mean per joint position error (MPJPE) or its variants PA- | |
MPJPE. | |
Report mean per joint position error (MPJPE) and mean per joint | |
position error after rigid alignment (PA-MPJPE) | |
""" | |
pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask = \ | |
self._parse_result(res_file, mode='keypoint', body_part=body_part) | |
err_name = metric.upper() | |
if body_part != '': | |
err_name = body_part.upper() + ' ' + err_name | |
if metric == 'mpjpe': | |
alignment = 'none' | |
elif metric == 'pa-mpjpe': | |
alignment = 'procrustes' | |
else: | |
raise ValueError(f'Invalid metric: {metric}') | |
error = keypoint_mpjpe(pred_keypoints3d, gt_keypoints3d, | |
gt_keypoints3d_mask, alignment) | |
info_str = [(err_name, error)] | |
return info_str | |
def _report_3d_pck(self, res_file, metric='3dpck'): | |
"""Cauculate Percentage of Correct Keypoints (3DPCK) w. or w/o | |
Procrustes alignment. | |
Args: | |
keypoint_results (list): Keypoint predictions. See | |
'Body3DMpiInf3dhpDataset.evaluate' for details. | |
metric (str): Specify mpjpe variants. Supported options are: | |
- ``'3dpck'``: Standard 3DPCK. | |
- ``'pa-3dpck'``: | |
3DPCK after aligning prediction to groundtruth | |
via a rigid transformation (scale, rotation and | |
translation). | |
""" | |
pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask = \ | |
self._parse_result(res_file) | |
err_name = metric.upper() | |
if metric == '3dpck': | |
alignment = 'none' | |
elif metric == 'pa-3dpck': | |
alignment = 'procrustes' | |
else: | |
raise ValueError(f'Invalid metric: {metric}') | |
error = keypoint_3d_pck(pred_keypoints3d, gt_keypoints3d, | |
gt_keypoints3d_mask, alignment) | |
name_value_tuples = [(err_name, error)] | |
return name_value_tuples | |
def _report_3d_auc(self, res_file, metric='3dauc'): | |
"""Cauculate the Area Under the Curve (AUC) computed for a range of | |
3DPCK thresholds. | |
Args: | |
keypoint_results (list): Keypoint predictions. See | |
'Body3DMpiInf3dhpDataset.evaluate' for details. | |
metric (str): Specify mpjpe variants. Supported options are: | |
- ``'3dauc'``: Standard 3DAUC. | |
- ``'pa-3dauc'``: 3DAUC after aligning prediction to | |
groundtruth via a rigid transformation (scale, rotation and | |
translation). | |
""" | |
pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask = \ | |
self._parse_result(res_file) | |
err_name = metric.upper() | |
if metric == '3dauc': | |
alignment = 'none' | |
elif metric == 'pa-3dauc': | |
alignment = 'procrustes' | |
else: | |
raise ValueError(f'Invalid metric: {metric}') | |
error = keypoint_3d_auc(pred_keypoints3d, gt_keypoints3d, | |
gt_keypoints3d_mask, alignment) | |
name_value_tuples = [(err_name, error)] | |
return name_value_tuples | |
def _report_pve(self, res_file, metric='pve', body_part=''): | |
"""Cauculate per vertex error.""" | |
pred_verts, gt_verts, _ = \ | |
self._parse_result(res_file, mode='vertice', body_part=body_part) | |
err_name = metric.upper() | |
if body_part != '': | |
err_name = body_part.upper() + ' ' + err_name | |
if metric == 'pve': | |
alignment = 'none' | |
elif metric == 'pa-pve': | |
alignment = 'procrustes' | |
else: | |
raise ValueError(f'Invalid metric: {metric}') | |
error = vertice_pve(pred_verts, gt_verts, alignment) | |
return [(err_name, error)] | |
    def _report_ihmr(self, res_file):
        """Calculate IHMR metric.

        https://arxiv.org/abs/2203.16427

        Ranks samples by the distance between the ground-truth mesh and
        the mesh of the SMPL mean parameters, then reports vertex/joint
        errors binned over the whole distribution and over the hardest
        tails.
        """
        pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask = \
            self._parse_result(res_file, mode='keypoint')
        pred_verts, gt_verts, _ = \
            self._parse_result(res_file, mode='vertice')

        # Local import: only this metric needs the geometry helper.
        from detrsmpl.utils.geometry import rot6d_to_rotmat

        # Build the reference mesh from the SMPL mean parameters
        # (pose stored as 6D rotations; converted to rotation matrices).
        mean_param_path = 'data/body_models/smpl_mean_params.npz'
        mean_params = np.load(mean_param_path)
        mean_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
        mean_shape = torch.from_numpy(
            mean_params['shape'][:].astype('float32')).unsqueeze(0)
        mean_pose = rot6d_to_rotmat(mean_pose).view(1, 24, 3, 3)
        mean_output = self.body_model(betas=mean_shape,
                                      body_pose=mean_pose[:, 1:],
                                      global_orient=mean_pose[:, :1],
                                      pose2rot=False)
        mean_verts = mean_output['vertices'].detach().cpu().numpy() * 1000.

        # Per-sample mean vertex distance between GT mesh and mean mesh —
        # the "difficulty" used for ranking.
        dis = (gt_verts - mean_verts) * (gt_verts - mean_verts)
        dis = np.sqrt(dis.sum(axis=-1)).mean(axis=-1)
        # from the most remote one to the nearest one
        idx_order = np.argsort(dis)[::-1]
        num_data = idx_order.shape[0]

        def report_ihmr_idx(idx):
            # Errors restricted to the selected sample indices.
            mpvpe = vertice_pve(pred_verts[idx], gt_verts[idx])
            mpjpe = keypoint_mpjpe(pred_keypoints3d[idx], gt_keypoints3d[idx],
                                   gt_keypoints3d_mask[idx], 'none')
            pampjpe = keypoint_mpjpe(pred_keypoints3d[idx],
                                     gt_keypoints3d[idx],
                                     gt_keypoints3d_mask[idx], 'procrustes')
            return (mpvpe, mpjpe, pampjpe)

        def report_ihmr_tail(percentage):
            # Errors over the hardest `percentage`% of samples.
            cur_data = int(num_data * percentage / 100.0)
            idx = idx_order[:cur_data]
            mpvpe, mpjpe, pampjpe = report_ihmr_idx(idx)
            res_mpvpe = ('bMPVPE Tail ' + str(percentage) + '%', mpvpe)
            res_mpjpe = ('bMPJPE Tail ' + str(percentage) + '%', mpjpe)
            res_pampjpe = ('bPA-MPJPE Tail ' + str(percentage) + '%', pampjpe)
            return [res_mpvpe, res_mpjpe, res_pampjpe]

        def report_ihmr_all(num_bin):
            # Histogram samples into `num_bin` difficulty bins and average
            # the per-bin mean errors over non-empty bins.
            num_per_bin = np.array([0 for _ in range(num_bin)
                                    ]).astype(np.float32)
            sum_mpvpe = np.array([0
                                  for _ in range(num_bin)]).astype(np.float32)
            sum_mpjpe = np.array([0
                                  for _ in range(num_bin)]).astype(np.float32)
            sum_pampjpe = np.array([0 for _ in range(num_bin)
                                    ]).astype(np.float32)
            max_dis = dis[idx_order[0]]
            min_dis = dis[idx_order[-1]]
            # NOTE(review): delta is 0 when all distances are equal, which
            # would divide by zero below — confirm inputs preclude this.
            delta = (max_dis - min_dis) / num_bin
            for i in range(num_data):
                # The -0.001 keeps the maximum distance inside the last bin.
                idx = int((dis[i] - min_dis) / delta - 0.001)
                res_mpvpe, res_mpjpe, res_pampjpe = report_ihmr_idx([i])
                num_per_bin[idx] += 1
                sum_mpvpe[idx] += res_mpvpe
                sum_mpjpe[idx] += res_mpjpe
                sum_pampjpe[idx] += res_pampjpe
            for i in range(num_bin):
                if num_per_bin[i] > 0:
                    sum_mpvpe[i] = sum_mpvpe[i] / num_per_bin[i]
                    sum_mpjpe[i] = sum_mpjpe[i] / num_per_bin[i]
                    sum_pampjpe[i] = sum_pampjpe[i] / num_per_bin[i]
            valid_idx = np.where(num_per_bin > 0)
            res_mpvpe = ('bMPVPE All', sum_mpvpe[valid_idx].mean())
            res_mpjpe = ('bMPJPE All', sum_mpjpe[valid_idx].mean())
            res_pampjpe = ('bPA-MPJPE All', sum_pampjpe[valid_idx].mean())
            return [res_mpvpe, res_mpjpe, res_pampjpe]

        metrics = []
        metrics.extend(report_ihmr_all(num_bin=100))
        metrics.extend(report_ihmr_tail(percentage=10))
        metrics.extend(report_ihmr_tail(percentage=5))
        return metrics