# AiOS / detrsmpl/data/datasets/multi_human_image_dataset.py
# Author: ttxskk, commit d7e58f0 ("update")
import json
import os
import os.path
from abc import ABCMeta
from collections import OrderedDict
from typing import Any, List, Optional, Union
import mmcv
import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info
from detrsmpl.core.conventions.keypoints_mapping import (
convert_kps,
get_keypoint_num,
get_mapping,
)
from detrsmpl.core.evaluation import (
keypoint_3d_auc,
keypoint_3d_pck,
keypoint_mpjpe,
vertice_pve,
)
from detrsmpl.data.data_structures.multi_human_data import MultiHumanData
from detrsmpl.models.body_models.builder import build_body_model
from .base_dataset import BaseDataset
from .builder import DATASETS
@DATASETS.register_module()
class MultiHumanImageDataset(BaseDataset, metaclass=ABCMeta):
    """Image dataset whose samples may contain several annotated instances.

    Annotations are read from a ``MultiHumanData`` file. Each sample index
    maps to a ``[frame_start, frame_end)`` row range of that file (see
    ``frame_range``); presumably one row per person instance. Keypoints are
    converted from the ``'human_data'`` convention to ``convention`` when the
    annotations are loaded.

    Args:
        data_prefix (str): Root data folder. Annotation files are looked up
            under ``<data_prefix>/preprocessed_datasets`` and images under
            ``<data_prefix>/datasets/<dataset_name>``.
        pipeline (list): Processing pipeline, forwarded to ``BaseDataset``.
        body_model (dict, optional): Config passed to ``build_body_model``;
            required by the evaluation paths that run a body model.
            Default: None.
        ann_file (str, optional): Annotation file name relative to
            ``<data_prefix>/preprocessed_datasets``. Default: None.
        convention (str, optional): Keypoint convention the annotations are
            converted to. Default: 'human_data'.
        test_mode (bool, optional): Forwarded to ``BaseDataset``.
            Default: False.
        dataset_name (str, optional): Dataset name; also selects the
            ground-truth source in evaluation ('pw3d', 'h36m', 'humman').
            Default: None.
    """

    # Metrics understood by ``evaluate``; keep in sync with its dispatch.
    ALLOWED_METRICS = {
        'mpjpe', 'pa-mpjpe', '3dpck', 'pa-3dpck', '3dauc', 'pa-3dauc', 'pve',
        'ihmr'
    }

    def __init__(self,
                 data_prefix: str,
                 pipeline: list,
                 body_model: Optional[Union[dict, None]] = None,
                 ann_file: Optional[Union[str, None]] = None,
                 convention: Optional[str] = 'human_data',
                 test_mode: Optional[bool] = False,
                 dataset_name: Optional[Union[str, None]] = None):
        # Set before ``super().__init__`` because ``load_annotations`` reads
        # ``self.convention``. NOTE(review): assumes BaseDataset calls
        # ``load_annotations`` from its constructor - confirm.
        self.num_keypoints = get_keypoint_num(convention)
        self.convention = convention
        super(MultiHumanImageDataset,
              self).__init__(data_prefix, pipeline, ann_file, test_mode,
                             dataset_name)
        self.body_model = (build_body_model(body_model)
                           if body_model is not None else None)

    def get_annotation_file(self):
        """Resolve ``self.ann_file`` in place to a full path.

        The configured name is joined onto
        ``<data_prefix>/preprocessed_datasets``.
        """
        ann_prefix = os.path.join(self.data_prefix, 'preprocessed_datasets')
        self.ann_file = os.path.join(ann_prefix, self.ann_file)

    def _convert_keypoints(self, src_key: str, dst_key: str) -> None:
        """Convert ``human_data[src_key]`` to ``self.convention``.

        The converted keypoints, the convention name and the converted mask
        are stored under ``dst_key``, ``dst_key + '_convention'`` and
        ``dst_key + '_mask'``. ``src_key + '_mask'`` must be present.
        """
        mask_key = f'{src_key}_mask'
        keypoints = self.human_data[src_key]
        assert mask_key in self.human_data, \
            f'{mask_key} is required to convert {src_key}'
        keypoints, mask = convert_kps(keypoints,
                                      src='human_data',
                                      dst=self.convention,
                                      mask=self.human_data[mask_key])
        self.human_data[dst_key] = keypoints
        self.human_data[f'{dst_key}_convention'] = self.convention
        self.human_data[f'{dst_key}_mask'] = mask

    def load_annotations(self):
        """Load the annotation file and convert its keypoints.

        Sets ``human_data``, ``instance_num``, ``image_path``,
        ``frame_range`` and ``num_data``. Keypoints are decompressed if
        needed, converted to ``self.convention``, then compressed by mask.
        """
        self.get_annotation_file()
        self.human_data = MultiHumanData()
        self.human_data.load(self.ann_file)
        self.instance_num = self.human_data.instance_num
        self.image_path = self.human_data['image_path']
        self.num_data = self.human_data.data_len
        try:
            self.frame_range = self.human_data['frame_range']
        except KeyError:
            # Without an explicit frame range every row is its own sample.
            self.frame_range = \
                np.array([[i, i + 1] for i in range(self.num_data)])
        self.num_data = self.frame_range.shape[0]
        if self.human_data.check_keypoints_compressed():
            self.human_data.decompress_keypoints()

        # ``*_ori`` keypoints: prefer an explicit ``*_ori`` entry and fall
        # back to the plain ``keypoints2d``/``keypoints3d`` entry.
        for dim in ('3d', '2d'):
            ori_key = f'keypoints{dim}_ori'
            plain_key = f'keypoints{dim}'
            if ori_key in self.human_data:
                self._convert_keypoints(ori_key, ori_key)
            elif plain_key in self.human_data:
                self._convert_keypoints(plain_key, ori_key)
        # SMPL-derived keypoints are converted in place when present.
        for smpl_key in ('keypoints3d_smpl', 'keypoints2d_smpl'):
            if smpl_key in self.human_data:
                self._convert_keypoints(smpl_key, smpl_key)
        self.human_data.compress_keypoints_by_mask()

    def _slice_keypoints(self, info: dict, key: str, frame_start: int,
                         frame_end: int, num_channels: int) -> None:
        """Copy rows ``[frame_start, frame_end)`` of ``key`` into ``info``.

        ``info['has_' + key]`` gets shape ``(frame_num, 1)``. A row's flag is
        1 when the last channel, summed over the second-to-last axis, is
        positive; otherwise it is 0. When ``key`` is absent, ``info[key]`` is
        zeros of shape ``(frame_num, num_keypoints, num_channels)`` and every
        flag is 0.
        """
        frame_num = frame_end - frame_start
        if key in self.human_data:
            keypoints = self.human_data[key][frame_start:frame_end]
            info[key] = keypoints
            visible = keypoints[..., -1].sum(-1) > 0
            info[f'has_{key}'] = np.ones((frame_num, 1)) * visible[..., None]
        else:
            info[key] = np.zeros(
                (frame_num, self.num_keypoints, num_channels))
            info[f'has_{key}'] = np.zeros((frame_num, 1))

    def prepare_raw_data(self, idx: int):
        """Collect the raw annotations of sample ``idx`` into a dict.

        The dict holds:
            - the image path;
            - the boxes and the derived ``center``/``scale``;
            - keypoints with their ``has_*`` flags;
            - SMPL parameters and ``area``.

        Missing annotations are filled with zeros so every key is present.
        """
        frame_start, frame_end = self.frame_range[idx]
        frame_num = frame_end - frame_start
        # TODO: Support cache_reader?
        info = {}
        info['img_prefix'] = None
        image_path = self.human_data['image_path'][frame_start]
        info['image_path'] = os.path.join(self.data_prefix, 'datasets',
                                          self.dataset_name, image_path)
        # TODO: Support smc?
        info['dataset_name'] = self.dataset_name
        info['sample_idx'] = idx

        if 'bbox_xywh' in self.human_data:
            info['bbox_xywh'] = self.human_data['bbox_xywh'][
                frame_start:frame_end]
            center, scale = [], []
            # Rows are (x, y, w, h, <5th value>); the 5th value (presumably a
            # confidence score) is not used here.
            for x, y, w, h, _ in info['bbox_xywh']:
                # TODO: verify if we should keep w = h = max(w, h) for multi
                # human data
                side = max(w, h)
                center.append([x + w / 2, y + h / 2])
                scale.append([side, side])
            info['center'] = np.array(center)
            info['scale'] = np.array(scale)
        else:
            info['bbox_xywh'] = np.zeros((frame_num, 5))
            info['center'] = np.zeros((frame_num, 2))
            info['scale'] = np.zeros((frame_num, 2))

        self._slice_keypoints(info, 'keypoints2d_ori', frame_start,
                              frame_end, 3)
        self._slice_keypoints(info, 'keypoints3d_ori', frame_start,
                              frame_end, 4)
        self._slice_keypoints(info, 'keypoints2d_smpl', frame_start,
                              frame_end, 3)
        self._slice_keypoints(info, 'keypoints3d_smpl', frame_start,
                              frame_end, 4)

        if 'smpl' in self.human_data:
            if 'has_smpl' in self.human_data:
                info['has_smpl'] = \
                    self.human_data['has_smpl'][frame_start:frame_end]
            else:
                info['has_smpl'] = np.ones((frame_num, 1))
            smpl_dict = self.human_data['smpl']
        else:
            info['has_smpl'] = np.zeros((frame_num, 1))
            smpl_dict = {}
        # (smpl key, info key, zero-fill shape when the key is missing)
        smpl_fields = (
            ('body_pose', 'smpl_body_pose', (frame_num, 23, 3)),
            ('global_orient', 'smpl_global_orient', (frame_num, 3)),
            ('betas', 'smpl_betas', (frame_num, 10)),
            ('transl', 'smpl_transl', (frame_num, 3)),
        )
        for smpl_key, info_key, default_shape in smpl_fields:
            if smpl_key in smpl_dict:
                info[info_key] = smpl_dict[smpl_key][frame_start:frame_end]
            else:
                info[info_key] = np.zeros(default_shape)

        if 'area' in self.human_data:
            info['area'] = self.human_data['area'][frame_start:frame_end]
        else:
            # NOTE(review): this fallback has shape (frame_num, 0), i.e. it
            # holds no values; confirm that is what the pipeline expects.
            info['area'] = np.zeros((frame_num, 0))
        return info

    def prepare_data(self, idx: int):
        """Generate and transform data."""
        info = self.prepare_raw_data(idx)
        return self.pipeline(info)

    def evaluate(self,
                 outputs: list,
                 res_folder: str,
                 metric: Optional[Union[str, List[str]]] = 'pa-mpjpe',
                 **kwargs: dict):
        """Evaluate 3D keypoint results.

        Args:
            outputs (list): results from model inference. Each element is a
                dict with 'image_idx', 'pred_smpl_kp3d', 'pred_smpl_pose',
                'pred_smpl_beta', 'gt_smpl_pose' and 'gt_smpl_beta', each
                indexable per batch element.
            res_folder (str): path to store results. Currently unused; the
                results are not written to disk.
            metric (Optional[Union[str, List(str)]]):
                a metric name, or a list/tuple of names, each in
                ``ALLOWED_METRICS``. Default: 'pa-mpjpe'
            kwargs (dict): other arguments (unused).

        Returns:
            dict:
                An ``OrderedDict`` of metric name -> value, in request order.

        Raises:
            KeyError: if a metric is not supported, or if no result was
                collected for some sample index in ``range(num_data)``.
        """
        if isinstance(metric, (list, tuple)):
            metrics = list(metric)
        else:
            metrics = [metric]
        for name in metrics:
            if name not in self.ALLOWED_METRICS:
                raise KeyError(f'metric {name} is not supported')

        # Index results by sample id so that ordering (and duplicates from
        # multi-GPU gathering) does not affect the evaluation.
        res_dict = {}
        for out in outputs:
            target_id = out['image_idx']
            for i in range(len(out['pred_smpl_kp3d'])):
                res_dict[int(target_id[i])] = dict(
                    keypoints=out['pred_smpl_kp3d'][i],
                    gt_poses=out['gt_smpl_pose'][i],
                    gt_betas=out['gt_smpl_beta'][i],
                    pred_poses=out['pred_smpl_pose'][i],
                    pred_betas=out['pred_smpl_beta'][i])
        fields = ('keypoints', 'gt_poses', 'gt_betas', 'pred_poses',
                  'pred_betas')
        res = {
            field: [res_dict[i][field] for i in range(self.num_data)]
            for field in fields
        }

        reporters = {
            'mpjpe': lambda: self._report_mpjpe(res),
            'pa-mpjpe': lambda: self._report_mpjpe(res, metric='pa-mpjpe'),
            '3dpck': lambda: self._report_3d_pck(res),
            'pa-3dpck': lambda: self._report_3d_pck(res, metric='pa-3dpck'),
            '3dauc': lambda: self._report_3d_auc(res),
            'pa-3dauc': lambda: self._report_3d_auc(res, metric='pa-3dauc'),
            'pve': lambda: self._report_pve(res),
            'ihmr': lambda: self._report_ihmr(res),
        }
        name_value_tuples = []
        for name in metrics:
            if name not in reporters:
                raise NotImplementedError(f'metric {name} is not implemented')
            name_value_tuples.extend(reporters[name]())
        return OrderedDict(name_value_tuples)

    @staticmethod
    def _write_keypoint_results(keypoints: Any, res_file: str):
        """Write ``keypoints`` as indented, key-sorted JSON to ``res_file``."""
        with open(res_file, 'w') as f:
            json.dump(keypoints, f, sort_keys=True, indent=4)

    def _parse_result(self, res, mode='keypoint', body_part=None):
        """Build prediction / ground-truth arrays for evaluation.

        Args:
            res (dict): collected results, as built by ``evaluate``.
            mode (str): 'keypoint' or 'vertice'.
            body_part: unused; kept for interface compatibility.

        Returns:
            tuple: ``(pred, gt, gt_mask)`` from ``_parse_keypoint_result``
            or ``_parse_vertex_result``.

        Raises:
            ValueError: if ``mode`` is not supported.
        """
        if mode == 'vertice':
            return self._parse_vertex_result(res)
        if mode == 'keypoint':
            return self._parse_keypoint_result(res)
        raise ValueError(f'Invalid mode: {mode}')

    def _parse_vertex_result(self, res):
        """Return ``(pred_vertices, gt_vertices, gt_mask)`` scaled by 1000.

        Ground-truth meshes are regenerated with ``self.body_model`` from
        ``human_data['smpl']``. Gender comes from
        ``human_data['meta']['gender']`` ('m' -> 0, anything else -> 1).
        Predictions use index 0 along axis 1 of the predicted poses/betas
        (presumably the first instance) with ``pose2rot=False``.
        """
        gt_beta, gt_pose, gt_global_orient, gender = [], [], [], []
        gt_smpl_dict = self.human_data['smpl']
        for idx in range(self.num_data):
            gt_beta.append(gt_smpl_dict['betas'][idx])
            gt_pose.append(gt_smpl_dict['body_pose'][idx])
            gt_global_orient.append(gt_smpl_dict['global_orient'][idx])
            gender.append(
                0 if self.human_data['meta']['gender'][idx] == 'm' else 1)
        gt_output = self.body_model(
            betas=torch.FloatTensor(gt_beta),
            body_pose=torch.FloatTensor(gt_pose).view(-1, 69),
            global_orient=torch.FloatTensor(gt_global_orient),
            gender=torch.Tensor(gender))
        gt_vertices = gt_output['vertices'].detach().cpu().numpy() * 1000.
        gt_mask = np.ones(gt_vertices.shape[:-1])

        pred_pose = torch.FloatTensor(res['pred_poses'])
        pred_beta = torch.FloatTensor(res['pred_betas'])
        pred_output = self.body_model(
            betas=pred_beta[:, 0],
            body_pose=pred_pose[:, 0, 1:],
            global_orient=pred_pose[:, 0, 0].unsqueeze(1),
            pose2rot=False)
        pred_vertices = \
            pred_output['vertices'].detach().cpu().numpy() * 1000.
        assert len(pred_vertices) == self.num_data
        return pred_vertices, gt_vertices, gt_mask

    def _parse_keypoint_result(self, res):
        """Return pelvis-aligned ``(pred, gt, gt_mask)`` 3D keypoints.

        Ground truth depends on ``dataset_name``:
            - 'pw3d': regenerated from the collected GT SMPL parameters;
            - 'h36m': taken from ``human_data['keypoints3d']``;
            - 'humman': regenerated from ``human_data['smpl']``.

        Both sets are reduced to 14 joints, made relative to a pelvis
        estimate and multiplied by 1000. The mask is boolean.

        Raises:
            NotImplementedError: for an unsupported ``dataset_name``.
            ValueError: for an unsupported ground-truth joint count.
        """
        pred_keypoints3d = res['keypoints']
        assert len(pred_keypoints3d) == self.num_data
        pred_keypoints3d = np.array(pred_keypoints3d).reshape(
            len(pred_keypoints3d), -1, 3)
        num_pred = len(pred_keypoints3d)

        if self.dataset_name == 'pw3d':
            gt_smpl_pose = np.array(res['gt_poses'])
            gt_body_pose = gt_smpl_pose[..., 1:, :]
            gt_global_orient = gt_smpl_pose[..., 0, :]
            gt_betas = np.array(res['gt_betas'])
            # Gender is fixed to 0 for every instance.
            gender = np.zeros([gt_betas.shape[0], gt_betas.shape[1]])
            gt_output = self.body_model(
                betas=torch.FloatTensor(gt_betas).view(-1, 10),
                body_pose=torch.FloatTensor(gt_body_pose).view(-1, 69),
                global_orient=torch.FloatTensor(gt_global_orient).view(-1, 3),
                gender=torch.Tensor(gender).view(-1))
            gt_keypoints3d = gt_output['joints'].detach().cpu().numpy()
            # NOTE(review): the mask has 17 columns; this only matches the
            # 17-joint branch below - confirm the body model's joint count.
            gt_keypoints3d_mask = np.ones((num_pred, 17))
        elif self.dataset_name == 'h36m':
            _, h36m_idxs, _ = get_mapping('human_data', 'h36m')
            gt_keypoints3d = \
                self.human_data['keypoints3d'][:, h36m_idxs, :3]
            gt_keypoints3d_mask = np.ones((num_pred, 17))
        elif self.dataset_name == 'humman':
            betas, body_pose, global_orient = [], [], []
            smpl_dict = self.human_data['smpl']
            for idx in range(self.num_data):
                betas.append(smpl_dict['betas'][idx])
                body_pose.append(smpl_dict['body_pose'][idx])
                global_orient.append(smpl_dict['global_orient'][idx])
            gt_output = self.body_model(
                betas=torch.FloatTensor(betas),
                body_pose=torch.FloatTensor(body_pose).view(-1, 69),
                global_orient=torch.FloatTensor(global_orient))
            gt_keypoints3d = gt_output['joints'].detach().cpu().numpy()
            gt_keypoints3d_mask = np.ones((num_pred, 24))
        else:
            raise NotImplementedError()

        num_joints = gt_keypoints3d.shape[1]
        if num_joints in (49, 24):
            assert pred_keypoints3d.shape[1] == num_joints
            if num_joints == 49:
                # SMPL_49: the last 24 joints follow the SMPL-24 layout.
                gt_keypoints3d = gt_keypoints3d[:, 25:, :]
                pred_keypoints3d = pred_keypoints3d[:, 25:, :]
            # we only evaluate on 14 lsp joints
            joint_mapper = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18]
            gt_keypoints3d = gt_keypoints3d[:, joint_mapper, :]
            pred_keypoints3d = pred_keypoints3d[:, joint_mapper, :]
            # Pelvis: midpoint of mapped joints 2 and 3.
            pred_pelvis = (pred_keypoints3d[:, 2] +
                           pred_keypoints3d[:, 3]) / 2
            gt_pelvis = (gt_keypoints3d[:, 2] + gt_keypoints3d[:, 3]) / 2
        elif num_joints == 17:
            # H36M for testing!
            assert pred_keypoints3d.shape[-2] == 17
            H36M_TO_J17 = [
                6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9
            ]
            joint_mapper = H36M_TO_J17[:14]
            # Pelvis is joint 0, taken before the 14-joint remapping.
            pred_pelvis = pred_keypoints3d[:, 0]
            gt_pelvis = gt_keypoints3d[:, 0]
            gt_keypoints3d = gt_keypoints3d[:, joint_mapper, :]
            pred_keypoints3d = pred_keypoints3d[:, joint_mapper, :]
        else:
            raise ValueError(
                f'Unsupported number of ground-truth joints: {num_joints}')

        pred_keypoints3d = (pred_keypoints3d -
                            pred_pelvis[:, None, :]) * 1000
        gt_keypoints3d = (gt_keypoints3d - gt_pelvis[:, None, :]) * 1000
        gt_keypoints3d_mask = gt_keypoints3d_mask[:, joint_mapper] > 0
        return pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask

    def _report_mpjpe(self, res_file, metric='mpjpe', body_part=''):
        """Calculate mean per joint position error (MPJPE) or PA-MPJPE.

        Args:
            res_file (dict): collected results, as built by ``evaluate``.
            metric (str): 'mpjpe' (no alignment) or 'pa-mpjpe' (Procrustes
                alignment).
            body_part (str): optional prefix for the reported name.

        Returns:
            list: a single ``(name, error)`` tuple.
        """
        pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask = \
            self._parse_result(res_file, mode='keypoint', body_part=body_part)
        err_name = metric.upper()
        if body_part != '':
            err_name = body_part.upper() + ' ' + err_name
        if metric == 'mpjpe':
            alignment = 'none'
        elif metric == 'pa-mpjpe':
            alignment = 'procrustes'
        else:
            raise ValueError(f'Invalid metric: {metric}')
        error = keypoint_mpjpe(pred_keypoints3d, gt_keypoints3d,
                               gt_keypoints3d_mask, alignment)
        return [(err_name, error)]

    def _report_3d_pck(self, res_file, metric='3dpck'):
        """Calculate 3D Percentage of Correct Keypoints (3DPCK).

        Args:
            res_file (dict): collected results, as built by ``evaluate``.
            metric (str): Supported options are:
                - ``'3dpck'``: Standard 3DPCK.
                - ``'pa-3dpck'``: 3DPCK after aligning the prediction to the
                  ground truth via a rigid transformation (scale, rotation
                  and translation).

        Returns:
            list: a single ``(name, value)`` tuple.
        """
        pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask = \
            self._parse_result(res_file)
        err_name = metric.upper()
        if metric == '3dpck':
            alignment = 'none'
        elif metric == 'pa-3dpck':
            alignment = 'procrustes'
        else:
            raise ValueError(f'Invalid metric: {metric}')
        error = keypoint_3d_pck(pred_keypoints3d, gt_keypoints3d,
                                gt_keypoints3d_mask, alignment)
        return [(err_name, error)]

    def _report_3d_auc(self, res_file, metric='3dauc'):
        """Calculate the Area Under the Curve (AUC) over 3DPCK thresholds.

        Args:
            res_file (dict): collected results, as built by ``evaluate``.
            metric (str): Supported options are:
                - ``'3dauc'``: Standard 3DAUC.
                - ``'pa-3dauc'``: 3DAUC after aligning the prediction to the
                  ground truth via a rigid transformation (scale, rotation
                  and translation).

        Returns:
            list: a single ``(name, value)`` tuple.
        """
        pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask = \
            self._parse_result(res_file)
        err_name = metric.upper()
        if metric == '3dauc':
            alignment = 'none'
        elif metric == 'pa-3dauc':
            alignment = 'procrustes'
        else:
            raise ValueError(f'Invalid metric: {metric}')
        error = keypoint_3d_auc(pred_keypoints3d, gt_keypoints3d,
                                gt_keypoints3d_mask, alignment)
        return [(err_name, error)]

    def _report_pve(self, res_file, metric='pve', body_part=''):
        """Calculate per vertex error ('pve', or Procrustes 'pa-pve').

        Returns:
            list: a single ``(name, error)`` tuple.
        """
        pred_verts, gt_verts, _ = \
            self._parse_result(res_file, mode='vertice', body_part=body_part)
        err_name = metric.upper()
        if body_part != '':
            err_name = body_part.upper() + ' ' + err_name
        if metric == 'pve':
            alignment = 'none'
        elif metric == 'pa-pve':
            alignment = 'procrustes'
        else:
            raise ValueError(f'Invalid metric: {metric}')
        error = vertice_pve(pred_verts, gt_verts, alignment)
        return [(err_name, error)]

    def _report_ihmr(self, res_file):
        """Calculate the IHMR metrics.

        https://arxiv.org/abs/2203.16427

        Samples are ranked by the mean per-vertex distance between their
        ground-truth mesh and the mesh of the SMPL mean parameters. Two kinds
        of values are reported:
            - 'All': bin-balanced MPVPE / MPJPE / PA-MPJPE over 100 distance
              bins;
            - 'Tail': the same errors on the 10% and 5% most distant samples.
        """
        pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask = \
            self._parse_result(res_file, mode='keypoint')
        pred_verts, gt_verts, _ = \
            self._parse_result(res_file, mode='vertice')

        from detrsmpl.utils.geometry import rot6d_to_rotmat
        mean_param_path = 'data/body_models/smpl_mean_params.npz'
        # Context manager so the .npz file handle is closed after reading.
        with np.load(mean_param_path) as mean_params:
            mean_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
            mean_shape = torch.from_numpy(
                mean_params['shape'][:].astype('float32')).unsqueeze(0)
        mean_pose = rot6d_to_rotmat(mean_pose).view(1, 24, 3, 3)
        mean_output = self.body_model(betas=mean_shape,
                                      body_pose=mean_pose[:, 1:],
                                      global_orient=mean_pose[:, :1],
                                      pose2rot=False)
        mean_verts = mean_output['vertices'].detach().cpu().numpy() * 1000.

        diff = gt_verts - mean_verts
        dis = np.sqrt((diff * diff).sum(axis=-1)).mean(axis=-1)
        # from the most remote one to the nearest one
        idx_order = np.argsort(dis)[::-1]
        num_data = idx_order.shape[0]

        def report_ihmr_idx(idx):
            """Return (MPVPE, MPJPE, PA-MPJPE) over the samples ``idx``."""
            mpvpe = vertice_pve(pred_verts[idx], gt_verts[idx])
            mpjpe = keypoint_mpjpe(pred_keypoints3d[idx], gt_keypoints3d[idx],
                                   gt_keypoints3d_mask[idx], 'none')
            pampjpe = keypoint_mpjpe(pred_keypoints3d[idx],
                                     gt_keypoints3d[idx],
                                     gt_keypoints3d_mask[idx], 'procrustes')
            return (mpvpe, mpjpe, pampjpe)

        def report_ihmr_tail(percentage):
            """Errors on the ``percentage``% most distant samples."""
            cur_data = int(num_data * percentage / 100.0)
            mpvpe, mpjpe, pampjpe = report_ihmr_idx(idx_order[:cur_data])
            suffix = ' Tail ' + str(percentage) + '%'
            return [('bMPVPE' + suffix, mpvpe), ('bMPJPE' + suffix, mpjpe),
                    ('bPA-MPJPE' + suffix, pampjpe)]

        def report_ihmr_all(num_bin):
            """Errors averaged per distance bin, then over non-empty bins."""
            num_per_bin = np.zeros(num_bin, dtype=np.float32)
            sum_mpvpe = np.zeros(num_bin, dtype=np.float32)
            sum_mpjpe = np.zeros(num_bin, dtype=np.float32)
            sum_pampjpe = np.zeros(num_bin, dtype=np.float32)
            max_dis = dis[idx_order[0]]
            min_dis = dis[idx_order[-1]]
            delta = (max_dis - min_dis) / num_bin
            for i in range(num_data):
                if delta > 0:
                    # The small offset keeps the maximum in the last bin.
                    bin_idx = int((dis[i] - min_dis) / delta - 0.001)
                    bin_idx = min(max(bin_idx, 0), num_bin - 1)
                else:
                    # All distances are equal: one bin holds every sample.
                    bin_idx = 0
                res_mpvpe, res_mpjpe, res_pampjpe = report_ihmr_idx([i])
                num_per_bin[bin_idx] += 1
                sum_mpvpe[bin_idx] += res_mpvpe
                sum_mpjpe[bin_idx] += res_mpjpe
                sum_pampjpe[bin_idx] += res_pampjpe
            valid = num_per_bin > 0
            sum_mpvpe[valid] /= num_per_bin[valid]
            sum_mpjpe[valid] /= num_per_bin[valid]
            sum_pampjpe[valid] /= num_per_bin[valid]
            return [('bMPVPE All', sum_mpvpe[valid].mean()),
                    ('bMPJPE All', sum_mpjpe[valid].mean()),
                    ('bPA-MPJPE All', sum_pampjpe[valid].mean())]

        metrics = []
        metrics.extend(report_ihmr_all(num_bin=100))
        metrics.extend(report_ihmr_tail(percentage=10))
        metrics.extend(report_ihmr_tail(percentage=5))
        return metrics