"""Prepare a Hugging Face model's tokenizer and ONNX files for local use.

Reads ``conversion_config.json`` from the current directory, saves the
tokenizer of ``model_id`` into ``./``, then downloads each ONNX file listed in
``precision_to_filename_map`` (keys ``fp32``, ``int8``, ``uint8``) into ``./``
and rewrites it in place with the configured IR version and default-domain
opset version.
"""
import os
import json
import shutil
from optimum.exporters.onnx import main_export
import onnx
from onnxconverter_common import float16
import onnxruntime as rt
from onnxruntime.tools.onnx_model_utils import *
from onnxruntime.quantization import quantize_dynamic, QuantType
from huggingface_hub import hf_hub_download
import transformers

with open('conversion_config.json', encoding="utf-8") as json_file:
    conversion_config = json.load(json_file)

model_id = conversion_config["model_id"]
number_of_generated_embeddings = conversion_config["number_of_generated_embeddings"]
precision_to_filename_map = conversion_config["precision_to_filename_map"]
opset = conversion_config["opset"]
IR = conversion_config["IR"]

# Default-domain opset descriptor (domain "", version = configured opset).
# Kept at module level so any later code referring to `op` still works.
op = onnx.OperatorSetIdProto()
op.version = opset

# Domain strings that denote the standard ONNX operator set.
_DEFAULT_DOMAINS = ("", "ai.onnx")


def _retarget_onnx_model(model, ir_version, opset_version):
    """Set ``model``'s IR version and default-domain opset version in place.

    Rebuilding the model with ``onnx.helper.make_model(model.graph, ...)``
    and only a default-domain opset import would discard opset imports for
    other domains (e.g. ``com.microsoft`` contrib ops a quantized model may
    use), plus the model's metadata and local functions. Editing the loaded
    ``ModelProto`` keeps all of those.

    Only the declared versions change; no operator-level conversion is done.

    Args:
        model: a loaded ``onnx.ModelProto``; mutated in place.
        ir_version: IR version to declare.
        opset_version: version to declare for the default ONNX domain.

    Returns:
        The same ``model`` object, for convenience.
    """
    model.ir_version = ir_version
    has_default_domain = False
    for opset_import in model.opset_import:
        if opset_import.domain in _DEFAULT_DOMAINS:
            opset_import.version = opset_version
            has_default_domain = True
    if not has_default_domain:
        # A graph with standard ops must declare the default domain.
        model.opset_import.add(domain="", version=opset_version)
    return model


print("Exporting tokenizer...")
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
tokenizer.save_pretrained("./")
print("Done\n\n")

os.makedirs("onnx", exist_ok=True)

# Same order as the original per-precision branches; missing keys are skipped.
# `filename`, `model` and `model_fixed` end up bound to the last processed file.
for precision in ("fp32", "int8", "uint8"):
    if precision not in precision_to_filename_map:
        continue
    print(f"Exporting the {precision} onnx file...")
    filename = precision_to_filename_map[precision]
    hf_hub_download(repo_id=model_id, filename=filename, local_dir="./")
    model = onnx.load(filename)
    # To be sure that we have compatible opset and IR version.
    model_fixed = _retarget_onnx_model(model, IR, opset)
    onnx.save(model_fixed, filename)
    print("Done\n\n")