Spaces:
Runtime error
Runtime error
import argparse | |
import torch | |
from models.with_mobilenet import PoseEstimationWithMobileNet | |
from modules.load_state import load_state | |
def convert_to_onnx(net, output_name):
    """Export ``net`` to an ONNX file at ``output_name``.

    The network is traced with a random tensor of shape (1, 3, 256, 456),
    so the exported graph is built for that input size.  The graph input is
    named ``data``; the four names in ``output_names`` are passed to the
    exporter as given, heatmaps before PAFs for stage 0 and then stage 1.

    Args:
        net: model to export (handed straight to ``torch.onnx.export``).
        output_name: destination path of the ONNX file.
    """
    # Values are irrelevant; the tensor only drives the export trace.
    # Presumably laid out as (batch, channels, height, width) -- confirm
    # against the model definition.
    dummy_input = torch.randn(1, 3, 256, 456)

    torch.onnx.export(
        net,
        dummy_input,
        output_name,
        verbose=True,
        input_names=['data'],
        output_names=[
            'stage_0_output_1_heatmaps',
            'stage_0_output_0_pafs',
            'stage_1_output_1_heatmaps',
            'stage_1_output_0_pafs',
        ],
    )
if __name__ == '__main__':
    # Script entry point: load a trained checkpoint into a freshly built
    # PoseEstimationWithMobileNet and write it out as an ONNX model.
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint-path', type=str, required=True, help='path to the checkpoint')
    parser.add_argument('--output-name', type=str, default='human-pose-estimation.onnx',
                        help='name of output model in ONNX format')
    args = parser.parse_args()

    net = PoseEstimationWithMobileNet()
    # The model and the export dummy input both live on the CPU, so remap the
    # checkpoint's tensors to CPU while loading.  Without map_location a
    # checkpoint saved from CUDA tensors cannot be deserialized on a machine
    # that has no GPU.
    checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
    load_state(net, checkpoint)

    convert_to_onnx(net, args.output_name)