AiOS / detrsmpl /utils /keypoint_utils.py
ttxskk
update
d7e58f0
raw
history blame
2.55 kB
from typing import Optional, Tuple, Union
import numpy as np
from detrsmpl.core.conventions.keypoints_mapping import KEYPOINTS_FACTORY
from detrsmpl.core.conventions.keypoints_mapping.human_data import (
HUMAN_DATA_LIMBS_INDEX,
HUMAN_DATA_PALETTE,
)
def search_limbs(
data_source: str,
mask: Optional[Union[np.ndarray, tuple, list]] = None,
keypoints_factory: dict = KEYPOINTS_FACTORY) -> Tuple[dict, dict]:
"""Search the corresponding limbs following the basis human_data limbs. The
mask could mask out the incorrect keypoints.
Args:
data_source (str): data source type.
mask (Optional[Union[np.ndarray, tuple, list]], optional):
refer to keypoints_mapping. Defaults to None.
keypoints_factory (dict, optional): Dict of all the conventions.
Defaults to KEYPOINTS_FACTORY.
Returns:
Tuple[dict, dict]: (limbs_target, limbs_palette).
"""
limbs_source = HUMAN_DATA_LIMBS_INDEX
limbs_palette = HUMAN_DATA_PALETTE
keypoints_source = keypoints_factory['human_data']
keypoints_target = keypoints_factory[data_source]
limbs_target = {}
for k, part_limbs in limbs_source.items():
limbs_target[k] = []
for limb in part_limbs:
flag = False
if (keypoints_source[limb[0]]
in keypoints_target) and (keypoints_source[limb[1]]
in keypoints_target):
if mask is not None:
if mask[keypoints_target.index(keypoints_source[
limb[0]])] != 0 and mask[keypoints_target.index(
keypoints_source[limb[1]])] != 0:
flag = True
else:
flag = True
if flag:
limbs_target.setdefault(k, []).append([
keypoints_target.index(keypoints_source[limb[0]]),
keypoints_target.index(keypoints_source[limb[1]])
])
if k in limbs_target:
if k == 'body':
np.random.seed(0)
limbs_palette[k] = np.random.randint(0,
high=255,
size=(len(
limbs_target[k]), 3))
else:
limbs_palette[k] = np.array(limbs_palette[k])
return limbs_target, limbs_palette