File size: 3,088 Bytes
f4a41d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser
from mmcv.image import imread
from mmpose.apis import inference_topdown, init_model
from mmpose.registry import VISUALIZERS
from mmpose.structures import merge_data_samples
def parse_args():
    """Define and parse the command-line options for the image demo.

    Three positional arguments are required: the input image, the model
    config file and the checkpoint file. The optional flags select the
    inference device, the output path, and how the prediction is drawn.

    Returns:
        argparse.Namespace: The parsed command-line arguments.
    """
    parser = ArgumentParser()

    # Required inputs, registered in the order they must appear on the
    # command line.
    positional_specs = (
        ('img', 'Image file'),
        ('config', 'Config file'),
        ('checkpoint', 'Checkpoint file'),
    )
    for arg_name, arg_help in positional_specs:
        parser.add_argument(arg_name, help=arg_help)

    # Output location and inference device.
    parser.add_argument(
        '--out-file', default=None, help='Path to output file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')

    # Drawing options.
    parser.add_argument(
        '--draw-heatmap',
        action='store_true',
        help='Visualize the predicted heatmap')
    parser.add_argument(
        '--show-kpt-idx',
        action='store_true',
        default=False,
        help='Whether to show the index of keypoints')
    parser.add_argument(
        '--skeleton-style',
        default='mmpose',
        type=str,
        choices=['mmpose', 'openpose'],
        help='Skeleton style selection')
    parser.add_argument(
        '--kpt-thr',
        type=float,
        default=0.3,
        help='Visualizing keypoint thresholds')

    # Integer-valued drawing sizes share the same shape of declaration.
    int_specs = (
        ('--radius', 3, 'Keypoint radius for visualization'),
        ('--thickness', 1, 'Link thickness for visualization'),
    )
    for flag, default_value, flag_help in int_specs:
        parser.add_argument(
            flag, type=int, default=default_value, help=flag_help)

    parser.add_argument(
        '--alpha', type=float, default=0.8, help='The transparency of bboxes')
    parser.add_argument(
        '--show',
        action='store_true',
        default=False,
        help='whether to show img')

    return parser.parse_args()
def main():
    """Run top-down pose inference on one image and visualize the prediction.

    Steps:
        1. Parse the command line via ``parse_args``.
        2. Build the model with ``init_model`` from the config and checkpoint
           on ``--device``. When ``--draw-heatmap`` is set, the model's
           ``test_cfg`` is overridden with ``output_heatmaps=True``.
        3. Copy ``--radius``, ``--alpha`` and ``--thickness`` onto the model's
           visualizer config (``thickness`` becomes ``line_width``). Build the
           visualizer from the ``VISUALIZERS`` registry and give it the
           model's dataset metadata and the chosen skeleton style.
        4. Run ``inference_topdown`` on the image path and merge the returned
           data samples into a single one with ``merge_data_samples``.
        5. Load the image in RGB channel order and pass it to
           ``add_datasample`` with ``draw_gt=False`` and ``draw_bbox=True``.
           The ``--show`` and ``--out-file`` values are forwarded unchanged.
    """
    args = parse_args()

    # Heatmap outputs are only requested from the model when they will be drawn.
    cfg_options = (dict(model=dict(test_cfg=dict(output_heatmaps=True)))
                   if args.draw_heatmap else None)
    model = init_model(
        args.config,
        args.checkpoint,
        device=args.device,
        cfg_options=cfg_options)

    # Apply the CLI drawing overrides to the visualizer config, then build it.
    vis_cfg = model.cfg.visualizer
    vis_cfg.radius = args.radius
    vis_cfg.alpha = args.alpha
    vis_cfg.line_width = args.thickness
    visualizer = VISUALIZERS.build(vis_cfg)
    visualizer.set_dataset_meta(
        model.dataset_meta, skeleton_style=args.skeleton_style)

    # Predict on the image, then collapse the per-sample outputs into one.
    pose_sample = merge_data_samples(inference_topdown(model, args.img))

    # Draw the prediction (no ground truth) over the RGB image.
    rgb_image = imread(args.img, channel_order='rgb')
    visualizer.add_datasample(
        'result',
        rgb_image,
        data_sample=pose_sample,
        draw_gt=False,
        draw_bbox=True,
        kpt_thr=args.kpt_thr,
        draw_heatmap=args.draw_heatmap,
        show_kpt_idx=args.show_kpt_idx,
        skeleton_style=args.skeleton_style,
        show=args.show,
        out_file=args.out_file)
if __name__ == '__main__':
    # Run the demo only when this file is executed as a script, not when it
    # is imported as a module.
    main()
|