# Copyright (c) OpenMMLab. All rights reserved.
import mimetypes
import os
import time
from argparse import ArgumentParser

import cv2
import json_tricks as json
import mmcv
import mmengine
import numpy as np

from mmpose.apis import inference_bottomup, init_model
from mmpose.registry import VISUALIZERS
from mmpose.structures import split_instances
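
# Example invocation (illustrative only; the script name and the config/
# checkpoint paths are placeholders, not files shipped with this demo):
#
#   python bottomup_demo.py \
#       path/to/bottomup_config.py path/to/checkpoint.pth \
#       --input path/to/image_or_video --output-root vis_results \
#       --draw-heatmap --save-predictions
#
# Passing `--input webcam` reads frames from the default camera instead of a
# file.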

def process_one_image(args,
                      img,
                      pose_estimator,
                      visualizer=None,
                      show_interval=0):
    """Visualize predicted keypoints (and heatmaps) of one image."""

    # inference a single image
    batch_results = inference_bottomup(pose_estimator, img)
    results = batch_results[0]

    # show the results
    if isinstance(img, str):
        img = mmcv.imread(img, channel_order='rgb')
    elif isinstance(img, np.ndarray):
        img = mmcv.bgr2rgb(img)

    if visualizer is not None:
        visualizer.add_datasample(
            'result',
            img,
            data_sample=results,
            draw_gt=False,
            draw_bbox=False,
            draw_heatmap=args.draw_heatmap,
            show_kpt_idx=args.show_kpt_idx,
            show=args.show,
            wait_time=show_interval,
            kpt_thr=args.kpt_thr)

    return results.pred_instances
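
# Illustrative sketch (not part of the demo itself): `process_one_image` only
# reads a handful of visualization-related attributes from `args`, so it can
# also be driven programmatically with an `argparse.Namespace`. The config and
# checkpoint paths below are placeholders.
#
#   from argparse import Namespace
#   model = init_model('path/to/bottomup_config.py', 'path/to/checkpoint.pth',
#                      device='cpu')
#   opts = Namespace(draw_heatmap=False, show_kpt_idx=False, show=False,
#                    kpt_thr=0.3)
#   pred_instances = process_one_image(opts, 'path/to/image.jpg', model)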

def parse_args():
    parser = ArgumentParser()
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--input', type=str, default='', help='Image/Video file')
    parser.add_argument(
        '--show',
        action='store_true',
        default=False,
        help='whether to show the image')
    parser.add_argument(
        '--output-root',
        type=str,
        default='',
        help='root of the output image files. '
        'If not specified, the visualization images are not saved.')
    parser.add_argument(
        '--save-predictions',
        action='store_true',
        default=False,
        help='whether to save predicted results')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--draw-heatmap',
        action='store_true',
        help='Visualize the predicted heatmap')
    parser.add_argument(
        '--show-kpt-idx',
        action='store_true',
        default=False,
        help='Whether to show the index of keypoints')
    parser.add_argument(
        '--kpt-thr', type=float, default=0.3, help='Keypoint score threshold')
    parser.add_argument(
        '--radius',
        type=int,
        default=3,
        help='Keypoint radius for visualization')
    parser.add_argument(
        '--thickness',
        type=int,
        default=1,
        help='Link thickness for visualization')
    parser.add_argument(
        '--show-interval', type=int, default=0, help='Sleep seconds per frame')
    args = parser.parse_args()
    return args

def main():
    args = parse_args()
    assert args.show or (args.output_root != ''), \
        'Please specify at least one of `--show` and `--output-root`'
    assert args.input != '', 'Please specify the input with `--input`'

    output_file = None
    if args.output_root:
        mmengine.mkdir_or_exist(args.output_root)
        output_file = os.path.join(args.output_root,
                                   os.path.basename(args.input))
        if args.input == 'webcam':
            output_file += '.mp4'

    if args.save_predictions:
        assert args.output_root != '', \
            '`--output-root` is required when saving predictions'
        args.pred_save_path = f'{args.output_root}/results_' \
            f'{os.path.splitext(os.path.basename(args.input))[0]}.json'

    # build the model from a config file and a checkpoint file
    if args.draw_heatmap:
        cfg_options = dict(model=dict(test_cfg=dict(output_heatmaps=True)))
    else:
        cfg_options = None

    model = init_model(
        args.config,
        args.checkpoint,
        device=args.device,
        cfg_options=cfg_options)

    # build visualizer
    model.cfg.visualizer.radius = args.radius
    model.cfg.visualizer.line_width = args.thickness
    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    visualizer.set_dataset_meta(model.dataset_meta)

    if args.input == 'webcam':
        input_type = 'webcam'
    else:
        input_type = mimetypes.guess_type(args.input)[0].split('/')[0]

    if input_type == 'image':
        # inference on a single image
        pred_instances = process_one_image(
            args, args.input, model, visualizer, show_interval=0)

        if args.save_predictions:
            pred_instances_list = split_instances(pred_instances)

        if output_file:
            img_vis = visualizer.get_image()
            mmcv.imwrite(mmcv.rgb2bgr(img_vis), output_file)

    elif input_type in ['webcam', 'video']:

        if args.input == 'webcam':
            cap = cv2.VideoCapture(0)
        else:
            cap = cv2.VideoCapture(args.input)

        video_writer = None
        pred_instances_list = []
        frame_idx = 0

        while cap.isOpened():
            success, frame = cap.read()
            frame_idx += 1

            if not success:
                break

            pred_instances = process_one_image(args, frame, model, visualizer,
                                               0.001)

            if args.save_predictions:
                # save prediction results
                pred_instances_list.append(
                    dict(
                        frame_id=frame_idx,
                        instances=split_instances(pred_instances)))

            # output videos
            if output_file:
                frame_vis = visualizer.get_image()

                if video_writer is None:
                    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                    # the size of the image with visualization may vary
                    # depending on the presence of heatmaps
                    video_writer = cv2.VideoWriter(
                        output_file,
                        fourcc,
                        25,  # saved fps
                        (frame_vis.shape[1], frame_vis.shape[0]))

                video_writer.write(mmcv.rgb2bgr(frame_vis))

            # press ESC to exit
            if cv2.waitKey(5) & 0xFF == 27:
                break

            time.sleep(args.show_interval)

        if video_writer:
            video_writer.release()

        cap.release()

    else:
        args.save_predictions = False
        raise ValueError(
            f'file {os.path.basename(args.input)} has invalid format.')

    if args.save_predictions:
        with open(args.pred_save_path, 'w') as f:
            json.dump(
                dict(
                    meta_info=model.dataset_meta,
                    instance_info=pred_instances_list),
                f,
                indent='\t')
        print(f'predictions have been saved at {args.pred_save_path}')

if __name__ == '__main__':
    main()
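
# For reference, an approximate sketch of the JSON written by
# `--save-predictions` (the exact fields come from `split_instances` and
# `model.dataset_meta`, so they may differ between MMPose versions):
#
#   {
#       "meta_info": {...dataset metadata such as keypoint names/skeleton...},
#       "instance_info": [
#           {"keypoints": [[x, y], ...], "keypoint_scores": [...], ...}
#       ]
#   }
#
# For video/webcam input, each "instance_info" entry is instead a dict with
# "frame_id" and a per-frame "instances" list.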