AiOS / detrsmpl /core /visualization /visualize_keypoints3d.py
ttxskk
update
d7e58f0
raw
history blame
9.66 kB
import warnings
from typing import Iterable, List, Optional, Tuple, Union
import numpy as np
import detrsmpl.core.conventions.keypoints_mapping as keypoints_mapping
from detrsmpl.core.renderer.matplotlib3d_renderer import Axes3dJointsRenderer
from detrsmpl.utils.demo_utils import get_different_colors
from detrsmpl.utils.keypoint_utils import search_limbs
from detrsmpl.utils.path_utils import prepare_output_path
def _norm_pose(pose_numpy: np.ndarray, min_value: Union[float, int],
max_value: Union[float, int], mask: Union[np.ndarray, list]):
"""Normalize the poses and make the center close to axis center."""
assert max_value > min_value
pose_np_normed = pose_numpy.copy()
if not mask:
mask = list(range(pose_numpy.shape[-2]))
axis_num = 3
axis_stat = np.zeros(shape=[axis_num, 4])
for axis_index in range(axis_num):
axis_data = pose_np_normed[..., mask, axis_index]
axis_min = np.min(axis_data)
axis_max = np.max(axis_data)
axis_mid = (axis_min + axis_max) / 2.0
axis_span = axis_max - axis_min
axis_stat[axis_index] = np.asarray(
(axis_min, axis_max, axis_mid, axis_span))
target_mid = (max_value + min_value) / 2.0
max_span = np.max(axis_stat[:, 3])
target_span = max_value - min_value
for axis_index in range(axis_num):
pose_np_normed[..., axis_index] = \
pose_np_normed[..., axis_index] - \
axis_stat[axis_index, 2]
pose_np_normed = pose_np_normed / max_span * target_span
pose_np_normed = pose_np_normed + target_mid
return pose_np_normed
def visualize_kp3d(
kp3d: np.ndarray,
output_path: Optional[str] = None,
limbs: Optional[Union[np.ndarray, List[int]]] = None,
palette: Optional[Iterable[int]] = None,
data_source: str = 'coco',
mask: Optional[Union[list, tuple, np.ndarray]] = None,
start: int = 0,
end: Optional[int] = None,
resolution: Union[list, Tuple[int, int]] = (1024, 1024),
fps: Union[float, int] = 30,
frame_names: Optional[Union[List[str], str]] = None,
orbit_speed: Union[float, int] = 0.5,
value_range: Union[Tuple[int, int], list] = (-100, 100),
pop_parts: Iterable[str] = (),
disable_limbs: bool = False,
return_array: Optional[bool] = None,
convention: str = 'opencv',
keypoints_factory: dict = keypoints_mapping.KEYPOINTS_FACTORY,
) -> Union[None, np.ndarray]:
"""Visualize 3d keypoints to a video with matplotlib. Support multi person
and specified limb connections.
Args:
kp3d (np.ndarray): shape could be (f * J * 4/3/2) or
(f * num_person * J * 4/3/2)
output_path (str): output video path image folder.
limbs (Optional[Union[np.ndarray, List[int]]], optional):
if not specified, the limbs will be searched by search_limbs,
this option is for free skeletons like BVH file.
Defaults to None.
palette (Iterable, optional): specified palette, three int represents
(B, G, R). Should be tuple or list.
Defaults to None.
data_source (str, optional): data source type. Defaults to 'coco'.
choose in ['coco', 'smplx', 'smpl', 'coco_wholebody',
'mpi_inf_3dhp', 'mpi_inf_3dhp_test', 'h36m', 'pw3d', 'mpii']
mask (Optional[Union[list, tuple, np.ndarray]], optional):
mask to mask out the incorrect points. Defaults to None.
start (int, optional): start frame index. Defaults to 0.
end (int, optional): end frame index.
Could be positive int or negative int or None.
None represents include all the frames.
Defaults to None.
resolution (Union[list, Tuple[int, int]], optional):
(width, height) of the output video
will be the same size as the original images if not specified.
Defaults to None.
fps (Union[float, int], optional): fps. Defaults to 30.
frame_names (Optional[Union[List[str], str]], optional): List(should be
the same as frame numbers) or single string or string format
(like 'frame%06d')for frame title, no title if None.
Defaults to None.
orbit_speed (Union[float, int], optional): orbit speed of camera.
Defaults to 0.5.
value_range (Union[Tuple[int, int], list], optional):
range of axis value. Defaults to (-100, 100).
pop_parts (Iterable[str], optional): The body part names you do not
want to visualize. Choose in ['left_eye','right_eye', 'nose',
'mouth', 'face', 'left_hand', 'right_hand']Defaults to [].
disable_limbs (bool, optional): whether need to disable drawing limbs.
Defaults to False.
return_array (bool, optional): Whether to return images as opencv array
.If None, an array will be returned when frame number is below 100.
Defaults to None.
keypoints_factory (dict, optional): Dict of all the conventions.
Defaults to KEYPOINTS_FACTORY.
Raises:
TypeError: check the type of input keypoints.
FileNotFoundError: check the output video path.
Returns:
Union[None, np.ndarray].
"""
# check input shape
if not isinstance(kp3d, np.ndarray):
raise TypeError(
f'Input type is {type(kp3d)}, which should be numpy.ndarray.')
kp3d = kp3d.copy()
if kp3d.shape[-1] == 2:
kp3d = np.concatenate([kp3d, np.zeros_like(kp3d)[..., 0:1]], axis=-1)
warnings.warn(
'The input array is 2-Dimensional coordinates, will concatenate ' +
f'zeros to the last axis. The new array shape: {kp3d.shape}')
elif kp3d.shape[-1] >= 4:
kp3d = kp3d[..., :3]
warnings.warn(
'The input array has more than 3-Dimensional coordinates, will ' +
'keep only the first 3-Dimensions of the last axis. The new ' +
f'array shape: {kp3d.shape}')
if kp3d.ndim == 3:
kp3d = np.expand_dims(kp3d, 1)
num_frames = kp3d.shape[0]
assert kp3d.ndim == 4
assert kp3d.shape[-1] == 3
if return_array is None:
if num_frames > 100:
return_array = False
else:
return_array = True
# check data_source & mask
if data_source not in keypoints_factory:
raise ValueError('Wrong data_source. Should choose in' +
f'{list(keypoints_factory.keys())}')
if mask is not None:
if not isinstance(mask, np.ndarray):
mask = np.array(mask).reshape(-1)
assert mask.shape == (
len(keypoints_factory[data_source]),
), f'mask length should fit with keypoints number \
{len(keypoints_factory[data_source])}'
# check the output path
if output_path is not None:
prepare_output_path(output_path,
path_type='auto',
tag='output video',
allowed_suffix=['.mp4', '.gif', ''])
# slice the frames
end = num_frames if end is None else end
kp3d = kp3d[start:end]
# norm the coordinates
if value_range is not None:
# norm pose location to value_range (70% value range)
mask_index = np.where(np.array(mask) > 0) if mask is not None else None
margin_width = abs(value_range[1] - value_range[1]) * 0.15
pose_np_normed = _norm_pose(kp3d, value_range[0] + margin_width,
value_range[1] - margin_width, mask_index)
input_pose_np = pose_np_normed
else:
input_pose_np = kp3d
# determine the limb connections and palettes
if limbs is not None:
limbs_target, limbs_palette = {
'body': limbs.tolist() if isinstance(limbs, np.ndarray) else limbs
}, get_different_colors(len(limbs))
else:
limbs_target, limbs_palette = search_limbs(data_source=data_source,
mask=mask)
if palette is not None:
limbs_palette = np.array(palette, dtype=np.uint8)[None]
# check and pop the pop_parts
assert set(pop_parts).issubset(
keypoints_mapping.human_data.HUMAN_DATA_PALETTE.keys(
)), f'wrong part_names in pop_parts, could only \
choose in{set(keypoints_mapping.human_data.HUMAN_DATA_PALETTE.keys())}'
for part_name in pop_parts:
if part_name in limbs_target:
limbs_target.pop(part_name)
# initialize renderer and start render
renderer = Axes3dJointsRenderer()
renderer.init_camera(cam_hori_speed=orbit_speed, cam_elev_speed=0.2)
renderer.set_connections(limbs_target, limbs_palette)
if isinstance(frame_names, str):
if '%' in frame_names:
frame_names = [
frame_names % index for index in range(input_pose_np.shape[0])
]
else:
frame_names = [frame_names] * input_pose_np.shape[0]
image_array = renderer.render_kp3d_to_video(input_pose_np,
output_path,
convention,
fps=fps,
resolution=resolution,
visual_range=value_range,
frame_names=frame_names,
disable_limbs=disable_limbs,
return_array=return_array)
return image_array