|
|
|
from argparse import ArgumentParser |
|
|
|
from mmcv.image import imread |
|
|
|
from mmpose.apis import inference_topdown, init_model |
|
from mmpose.registry import VISUALIZERS |
|
from mmpose.structures import merge_data_samples |
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the image demo.

    Args:
        argv (list[str], optional): Argument strings to parse, without the
            program name. Defaults to None, in which case ``argparse``
            falls back to ``sys.argv[1:]``, so existing ``parse_args()``
            callers behave exactly as before.

    Returns:
        argparse.Namespace: The parsed options. The positional fields are
        ``img``, ``config`` and ``checkpoint``; the optional fields are
        ``out_file``, ``device``, ``draw_heatmap``, ``show_kpt_idx``,
        ``skeleton_style``, ``kpt_thr``, ``radius``, ``thickness``,
        ``alpha`` and ``show``.
    """
    parser = ArgumentParser()

    # Required positional inputs.
    parser.add_argument('img', help='Image file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')

    # Output and runtime options.
    parser.add_argument('--out-file', default=None, help='Path to output file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')

    # Boolean switches: ``store_true`` already defaults to False, so no
    # explicit ``default`` is needed.
    parser.add_argument(
        '--draw-heatmap',
        action='store_true',
        help='Visualize the predicted heatmap')
    parser.add_argument(
        '--show-kpt-idx',
        action='store_true',
        help='Whether to show the index of keypoints')

    # Drawing options.
    parser.add_argument(
        '--skeleton-style',
        default='mmpose',
        type=str,
        choices=['mmpose', 'openpose'],
        help='Skeleton style selection')
    parser.add_argument(
        '--kpt-thr',
        type=float,
        default=0.3,
        help='Visualizing keypoint thresholds')
    parser.add_argument(
        '--radius',
        type=int,
        default=3,
        help='Keypoint radius for visualization')
    parser.add_argument(
        '--thickness',
        type=int,
        default=1,
        help='Link thickness for visualization')
    parser.add_argument(
        '--alpha', type=float, default=0.8, help='The transparency of bboxes')
    parser.add_argument(
        '--show', action='store_true', help='whether to show img')

    return parser.parse_args(argv)
|
|
|
|
|
def main():
    """Run the pose model on one image and draw the merged prediction.

    The steps are: parse the command-line options, build the model from the
    given config and checkpoint, build the visualizer from the model's own
    config, run ``inference_topdown`` on the image path, merge the returned
    data samples, and pass the merged sample to the visualizer together
    with the ``--show`` / ``--out-file`` options.
    """
    args = parse_args()

    # Heatmap output is only requested from the model when it will be drawn.
    heatmap_cfg = (
        dict(model=dict(test_cfg=dict(output_heatmaps=True)))
        if args.draw_heatmap else None)

    model = init_model(
        args.config,
        args.checkpoint,
        device=args.device,
        cfg_options=heatmap_cfg)

    # Copy the drawing options from the command line into the visualizer
    # config before the visualizer is built from it.
    drawing_overrides = (
        ('radius', args.radius),
        ('alpha', args.alpha),
        ('line_width', args.thickness),
    )
    for option_name, option_value in drawing_overrides:
        setattr(model.cfg.visualizer, option_name, option_value)

    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    visualizer.set_dataset_meta(
        model.dataset_meta, skeleton_style=args.skeleton_style)

    # Inference on the single image; no bounding boxes are passed here.
    merged_sample = merge_data_samples(inference_topdown(model, args.img))

    # Read the image in RGB channel order and hand it to the visualizer.
    rgb_image = imread(args.img, channel_order='rgb')
    visualizer.add_datasample(
        'result',
        rgb_image,
        data_sample=merged_sample,
        draw_gt=False,
        draw_bbox=True,
        kpt_thr=args.kpt_thr,
        draw_heatmap=args.draw_heatmap,
        show_kpt_idx=args.show_kpt_idx,
        skeleton_style=args.skeleton_style,
        show=args.show,
        out_file=args.out_file)
|
|
|
|
|
# Only run the demo when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
|
|