# Copyright (c) OpenMMLab. All rights reserved.
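"""Two-stage 3D human pose estimation demo: an mmdet detector and a
top-down 2D pose estimator produce per-frame 2D keypoints, which a
pose lifter then lifts to 3D. Supports image, video, and webcam input.

Example invocation (the config/checkpoint paths below are illustrative
placeholders, not shipped files):

    python body3d_pose_lifter_demo.py \
        det_config.py det_checkpoint.pth \
        pose2d_config.py pose2d_checkpoint.pth \
        pose_lifter_config.py pose_lifter_checkpoint.pth \
        --input demo.mp4 --output-root vis_results
"""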
import mimetypes
import os
import time
from argparse import ArgumentParser
from functools import partial
import cv2
import json_tricks as json
import mmcv
import mmengine
import numpy as np
from mmengine.structures import InstanceData
from mmpose.apis import (_track_by_iou, _track_by_oks, collect_multi_frames,
convert_keypoint_definition, extract_pose_sequence,
inference_pose_lifter_model, inference_topdown,
init_model)
from mmpose.models.pose_estimators import PoseLifter
from mmpose.models.pose_estimators.topdown import TopdownPoseEstimator
from mmpose.registry import VISUALIZERS
from mmpose.structures import (PoseDataSample, merge_data_samples,
split_instances)
from mmpose.utils import adapt_mmdet_pipeline

try:
from mmdet.apis import inference_detector, init_detector
has_mmdet = True
except (ImportError, ModuleNotFoundError):
has_mmdet = False


def parse_args():
parser = ArgumentParser()
parser.add_argument('det_config', help='Config file for detection')
parser.add_argument('det_checkpoint', help='Checkpoint file for detection')
parser.add_argument(
'pose_estimator_config',
type=str,
default=None,
help='Config file for the 1st stage 2D pose estimator')
parser.add_argument(
'pose_estimator_checkpoint',
type=str,
default=None,
help='Checkpoint file for the 1st stage 2D pose estimator')
parser.add_argument(
'pose_lifter_config',
help='Config file for the 2nd stage pose lifter model')
parser.add_argument(
'pose_lifter_checkpoint',
help='Checkpoint file for the 2nd stage pose lifter model')
    parser.add_argument(
        '--input',
        type=str,
        default='',
        help="Path to the input video or image, or 'webcam'")
parser.add_argument(
'--show',
action='store_true',
default=False,
help='Whether to show visualizations')
parser.add_argument(
'--rebase-keypoint-height',
action='store_true',
help='Rebase the predicted 3D pose so its lowest keypoint has a '
'height of 0 (landing on the ground). This is useful for '
        'visualization when the model does not predict the global position '
'of the 3D pose.')
parser.add_argument(
'--norm-pose-2d',
action='store_true',
help='Scale the bbox (along with the 2D pose) to the average bbox '
'scale of the dataset, and move the bbox (along with the 2D pose) to '
'the average bbox center of the dataset. This is useful when bbox '
'is small, especially in multi-person scenarios.')
parser.add_argument(
'--num-instances',
type=int,
default=-1,
help='The number of 3D poses to be visualized in every frame. If '
'less than 0, it will be set to the number of pose results in the '
'first frame.')
parser.add_argument(
'--output-root',
type=str,
default='',
help='Root of the output video file. '
'Default not saving the visualization video.')
parser.add_argument(
'--save-predictions',
action='store_true',
default=False,
        help='Whether to save predicted results')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
'--det-cat-id',
type=int,
default=0,
help='Category id for bounding box detection model')
parser.add_argument(
'--bbox-thr',
type=float,
default=0.9,
help='Bounding box score threshold')
    parser.add_argument(
        '--kpt-thr', type=float, default=0.3, help='Keypoint threshold')
parser.add_argument(
        '--use-oks-tracking', action='store_true', help='Use OKS tracking')
parser.add_argument(
'--tracking-thr', type=float, default=0.3, help='Tracking threshold')
parser.add_argument(
'--show-interval', type=int, default=0, help='Sleep seconds per frame')
parser.add_argument(
'--thickness',
type=int,
default=1,
help='Link thickness for visualization')
parser.add_argument(
'--radius',
type=int,
default=3,
help='Keypoint radius for visualization')
    parser.add_argument(
        '--use-multi-frames',
        action='store_true',
        default=False,
        help='Whether to use multi frames for inference in the 2D pose '
        'detection stage. Default: False.')
    parser.add_argument(
        '--online',
        action='store_true',
        default=False,
        help='Inference mode. If set to True, future frame information '
        'cannot be used when using multi frames for inference in the 2D '
        'pose detection stage. Default: False.')
args = parser.parse_args()
return args


def get_area(results):
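    """Attach an ``areas`` field to each data sample for IoU/OKS tracking.

    Uses detector bboxes when available; otherwise derives a bbox (and its
    area) from the keypoints with positive coordinates.
    """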
for i, data_sample in enumerate(results):
pred_instance = data_sample.pred_instances.cpu().numpy()
if 'bboxes' in pred_instance:
bboxes = pred_instance.bboxes
results[i].pred_instances.set_field(
np.array([(bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
for bbox in bboxes]), 'areas')
else:
keypoints = pred_instance.keypoints
areas, bboxes = [], []
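            # derive a bbox from keypoints, ignoring non-positive
            # coordinates (treated as invisible); ``initial`` guards
            # against an empty selection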
for keypoint in keypoints:
xmin = np.min(keypoint[:, 0][keypoint[:, 0] > 0], initial=1e10)
xmax = np.max(keypoint[:, 0])
ymin = np.min(keypoint[:, 1][keypoint[:, 1] > 0], initial=1e10)
ymax = np.max(keypoint[:, 1])
areas.append((xmax - xmin) * (ymax - ymin))
bboxes.append([xmin, ymin, xmax, ymax])
results[i].pred_instances.areas = np.array(areas)
results[i].pred_instances.bboxes = np.array(bboxes)
return results


def get_pose_est_results(args, pose_estimator, frame, bboxes,
pose_est_results_last, next_id, pose_lift_dataset):
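    """Run top-down 2D pose estimation on one frame and track instances.

    Assigns a ``track_id`` to each detected person (a new id for an
    unmatched instance with enough visible keypoints, -1 otherwise), then
    converts the keypoints from the 2D detector's dataset convention to
    the pose lifter's dataset convention.

    Returns the raw 2D results, the converted results, and the next
    unused track id.
    """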
pose_det_dataset = pose_estimator.cfg.test_dataloader.dataset
# make person results for current image
pose_est_results = inference_topdown(pose_estimator, frame, bboxes)
pose_est_results = get_area(pose_est_results)
if args.use_oks_tracking:
_track = partial(_track_by_oks)
else:
_track = _track_by_iou
for i, result in enumerate(pose_est_results):
track_id, pose_est_results_last, match_result = _track(
result, pose_est_results_last, args.tracking_thr)
if track_id == -1:
pred_instances = result.pred_instances.cpu().numpy()
keypoints = pred_instances.keypoints
if np.count_nonzero(keypoints[:, :, 1]) >= 3:
pose_est_results[i].set_field(next_id, 'track_id')
next_id += 1
else:
# If the number of keypoints detected is small,
# delete that person instance.
keypoints[:, :, 1] = -10
pose_est_results[i].pred_instances.set_field(
keypoints, 'keypoints')
bboxes = pred_instances.bboxes * 0
pose_est_results[i].pred_instances.set_field(bboxes, 'bboxes')
pose_est_results[i].set_field(-1, 'track_id')
pose_est_results[i].set_field(pred_instances, 'pred_instances')
else:
pose_est_results[i].set_field(track_id, 'track_id')
del match_result
pose_est_results_converted = []
for pose_est_result in pose_est_results:
pose_est_result_converted = PoseDataSample()
gt_instances = InstanceData()
pred_instances = InstanceData()
for k in pose_est_result.gt_instances.keys():
gt_instances.set_field(pose_est_result.gt_instances[k], k)
for k in pose_est_result.pred_instances.keys():
pred_instances.set_field(pose_est_result.pred_instances[k], k)
pose_est_result_converted.gt_instances = gt_instances
pose_est_result_converted.pred_instances = pred_instances
pose_est_result_converted.track_id = pose_est_result.track_id
keypoints = convert_keypoint_definition(pred_instances.keypoints,
pose_det_dataset['type'],
pose_lift_dataset['type'])
pose_est_result_converted.pred_instances.keypoints = keypoints
pose_est_results_converted.append(pose_est_result_converted)
return pose_est_results, pose_est_results_converted, next_id


def get_pose_lift_results(args, visualizer, pose_lifter, pose_est_results_list,
frame, frame_idx, pose_est_results):
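    """Lift the accumulated 2D pose sequences to 3D and visualize them.

    Extracts a (possibly causal) 2D pose sequence per instance, runs the
    pose lifter, remaps the predicted coordinates to a z-up convention,
    optionally rebases the lowest keypoint to height 0, and renders the
    result. Returns the merged 3D prediction instances (or None).
    """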
pose_lift_dataset = pose_lifter.cfg.test_dataloader.dataset
# extract and pad input pose2d sequence
pose_seq_2d = extract_pose_sequence(
pose_est_results_list,
frame_idx=frame_idx,
causal=pose_lift_dataset.get('causal', False),
seq_len=pose_lift_dataset.get('seq_len', 1),
step=pose_lift_dataset.get('seq_step', 1))
# 2D-to-3D pose lifting
    # frame.shape[:2] is (height, width)
    height, width = frame.shape[:2]
pose_lift_results = inference_pose_lifter_model(
pose_lifter,
pose_seq_2d,
image_size=(width, height),
norm_pose_2d=args.norm_pose_2d)
# Pose processing
for idx, pose_lift_res in enumerate(pose_lift_results):
pose_lift_res.track_id = pose_est_results[idx].get('track_id', 1e4)
pred_instances = pose_lift_res.pred_instances
keypoints = pred_instances.keypoints
keypoint_scores = pred_instances.keypoint_scores
if keypoint_scores.ndim == 3:
keypoint_scores = np.squeeze(keypoint_scores, axis=1)
pose_lift_results[
idx].pred_instances.keypoint_scores = keypoint_scores
if keypoints.ndim == 4:
keypoints = np.squeeze(keypoints, axis=1)
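        # swap the y- and z-axes and flip x and z so that z points up
        # in the visualizer's world coordinate convention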
keypoints = keypoints[..., [0, 2, 1]]
keypoints[..., 0] = -keypoints[..., 0]
keypoints[..., 2] = -keypoints[..., 2]
# rebase height (z-axis)
if args.rebase_keypoint_height:
keypoints[..., 2] -= np.min(
keypoints[..., 2], axis=-1, keepdims=True)
pose_lift_results[idx].pred_instances.keypoints = keypoints
pose_lift_results = sorted(
pose_lift_results, key=lambda x: x.get('track_id', 1e4))
pred_3d_data_samples = merge_data_samples(pose_lift_results)
det_data_sample = merge_data_samples(pose_est_results)
if args.num_instances < 0:
args.num_instances = len(pose_lift_results)
# Visualization
if visualizer is not None:
visualizer.add_datasample(
'result',
frame,
data_sample=pred_3d_data_samples,
det_data_sample=det_data_sample,
draw_gt=False,
show=args.show,
draw_bbox=True,
kpt_thr=args.kpt_thr,
num_instances=args.num_instances,
wait_time=args.show_interval)
return pred_3d_data_samples.get('pred_instances', None)


def get_bbox(args, detector, frame):
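    """Detect humans in a frame; keep bboxes matching ``det_cat_id``
    with scores above ``bbox_thr``."""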
det_result = inference_detector(detector, frame)
pred_instance = det_result.pred_instances.cpu().numpy()
bboxes = pred_instance.bboxes
bboxes = bboxes[np.logical_and(pred_instance.labels == args.det_cat_id,
pred_instance.scores > args.bbox_thr)]
return bboxes


def main():
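    """Build the detector, 2D pose estimator, pose lifter and visualizer,
    then run the two-stage pipeline on the selected input."""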
assert has_mmdet, 'Please install mmdet to run the demo.'
args = parse_args()
assert args.show or (args.output_root != '')
assert args.input != ''
assert args.det_config is not None
assert args.det_checkpoint is not None
detector = init_detector(
args.det_config, args.det_checkpoint, device=args.device.lower())
detector.cfg = adapt_mmdet_pipeline(detector.cfg)
pose_estimator = init_model(
args.pose_estimator_config,
args.pose_estimator_checkpoint,
device=args.device.lower())
    assert isinstance(pose_estimator, TopdownPoseEstimator), \
        'Only "TopDown" model is supported for the 1st stage ' \
        '(2D pose detection)'
det_kpt_color = pose_estimator.dataset_meta.get('keypoint_colors', None)
det_dataset_skeleton = pose_estimator.dataset_meta.get(
'skeleton_links', None)
det_dataset_link_color = pose_estimator.dataset_meta.get(
'skeleton_link_colors', None)
# frame index offsets for inference, used in multi-frame inference setting
if args.use_multi_frames:
assert 'frame_indices' in pose_estimator.cfg.test_dataloader.dataset
        # the key must match the assert above
        indices = pose_estimator.cfg.test_dataloader.dataset['frame_indices']
pose_lifter = init_model(
args.pose_lifter_config,
args.pose_lifter_checkpoint,
device=args.device.lower())
assert isinstance(pose_lifter, PoseLifter), \
'Only "PoseLifter" model is supported for the 2nd stage ' \
'(2D-to-3D lifting)'
pose_lift_dataset = pose_lifter.cfg.test_dataloader.dataset
pose_lifter.cfg.visualizer.radius = args.radius
pose_lifter.cfg.visualizer.line_width = args.thickness
pose_lifter.cfg.visualizer.det_kpt_color = det_kpt_color
pose_lifter.cfg.visualizer.det_dataset_skeleton = det_dataset_skeleton
pose_lifter.cfg.visualizer.det_dataset_link_color = det_dataset_link_color
visualizer = VISUALIZERS.build(pose_lifter.cfg.visualizer)
# the dataset_meta is loaded from the checkpoint
visualizer.set_dataset_meta(pose_lifter.dataset_meta)
if args.input == 'webcam':
input_type = 'webcam'
else:
input_type = mimetypes.guess_type(args.input)[0].split('/')[0]
if args.output_root == '':
save_output = False
else:
mmengine.mkdir_or_exist(args.output_root)
output_file = os.path.join(args.output_root,
os.path.basename(args.input))
if args.input == 'webcam':
output_file += '.mp4'
save_output = True
if args.save_predictions:
assert args.output_root != ''
args.pred_save_path = f'{args.output_root}/results_' \
f'{os.path.splitext(os.path.basename(args.input))[0]}.json'
if save_output:
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
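    # accumulates converted per-frame 2D results; the pose lifter reads a
    # sliding window over this list via extract_pose_sequence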
pose_est_results_list = []
pred_instances_list = []
if input_type == 'image':
frame = mmcv.imread(args.input, channel_order='rgb')
# First stage: 2D pose detection
bboxes = get_bbox(args, detector, frame)
pose_est_results, pose_est_results_converted, _ = get_pose_est_results(
args, pose_estimator, frame, bboxes, [], 0, pose_lift_dataset)
pose_est_results_list.append(pose_est_results_converted.copy())
pred_3d_pred = get_pose_lift_results(args, visualizer, pose_lifter,
pose_est_results_list, frame, 0,
pose_est_results)
if args.save_predictions:
# save prediction results
pred_instances_list = split_instances(pred_3d_pred)
if save_output:
frame_vis = visualizer.get_image()
mmcv.imwrite(mmcv.rgb2bgr(frame_vis), output_file)
elif input_type in ['webcam', 'video']:
next_id = 0
pose_est_results_converted = []
if args.input == 'webcam':
video = cv2.VideoCapture(0)
else:
video = cv2.VideoCapture(args.input)
(major_ver, minor_ver, subminor_ver) = (cv2.__version__).split('.')
if int(major_ver) < 3:
fps = video.get(cv2.cv.CV_CAP_PROP_FPS)
else:
fps = video.get(cv2.CAP_PROP_FPS)
video_writer = None
frame_idx = 0
while video.isOpened():
success, frame = video.read()
frame_idx += 1
if not success:
break
pose_est_results_last = pose_est_results_converted
# First stage: 2D pose detection
if args.use_multi_frames:
frames = collect_multi_frames(video, frame_idx, indices,
args.online)
# make person results for current image
bboxes = get_bbox(args, detector, frame)
pose_est_results, pose_est_results_converted, next_id = get_pose_est_results( # noqa: E501
args, pose_estimator,
frames if args.use_multi_frames else frame, bboxes,
pose_est_results_last, next_id, pose_lift_dataset)
pose_est_results_list.append(pose_est_results_converted.copy())
# Second stage: Pose lifting
pred_3d_pred = get_pose_lift_results(args, visualizer, pose_lifter,
pose_est_results_list,
mmcv.bgr2rgb(frame),
frame_idx, pose_est_results)
if args.save_predictions:
# save prediction results
pred_instances_list.append(
dict(
frame_id=frame_idx,
instances=split_instances(pred_3d_pred)))
if save_output:
frame_vis = visualizer.get_image()
if video_writer is None:
# the size of the image with visualization may vary
# depending on the presence of heatmaps
video_writer = cv2.VideoWriter(output_file, fourcc, fps,
(frame_vis.shape[1],
frame_vis.shape[0]))
video_writer.write(mmcv.rgb2bgr(frame_vis))
            if args.show:
                # press ESC to exit
                if cv2.waitKey(5) & 0xFF == 27:
                    break
                time.sleep(args.show_interval)
video.release()
if video_writer:
video_writer.release()
else:
args.save_predictions = False
raise ValueError(
f'file {os.path.basename(args.input)} has invalid format.')
if args.save_predictions:
with open(args.pred_save_path, 'w') as f:
json.dump(
dict(
meta_info=pose_lifter.dataset_meta,
instance_info=pred_instances_list),
f,
indent='\t')
print(f'predictions have been saved at {args.pred_save_path}')


if __name__ == '__main__':
main()