# Web-page header captured with this file (not part of the script): "martinhillebrandtd's picture", "model", "cea8feb".
import os
import json
import shutil
from optimum.exporters.onnx import main_export
import onnx
from onnxconverter_common import float16
import onnxruntime as rt
from onnxruntime.tools.onnx_model_utils import *
from onnxruntime.quantization import quantize_dynamic, QuantType
from huggingface_hub import hf_hub_download
# Script body: read conversion_config.json, then for each configured precision
# (fp32, int8, uint8) download the ONNX file from the Hugging Face Hub and
# re-save it in place with the configured IR version and opset.

# The config file is looked up relative to the current working directory.
with open('conversion_config.json', encoding='utf-8') as json_file:
    conversion_config = json.load(json_file)

model_id = conversion_config["model_id"]
# Not used below in this section; kept so a missing key still fails early and
# so any later code relying on the name keeps working.
number_of_generated_embeddings = conversion_config["number_of_generated_embeddings"]
precision_to_filename_map = conversion_config["precision_to_filename_map"]
opset = conversion_config["opset"]
IR = conversion_config["IR"]

# Single opset import stamped onto every re-saved model; only its version is set.
op = onnx.OperatorSetIdProto()
op.version = opset

# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs("onnx", exist_ok=True)

# Same order as the original hand-written sections: fp32, then int8, then uint8.
for precision in ("fp32", "int8", "uint8"):
    if precision not in precision_to_filename_map:
        continue

    print(f"Exporting the {precision} onnx file...")
    filename = precision_to_filename_map[precision]
    # Downloads into the current directory; the file is then loaded by the
    # same relative name.
    hf_hub_download(repo_id=model_id, filename=filename, local_dir="./")
    model = onnx.load(filename)
    # Rebuild the model to be sure that we have compatible opset and IR version.
    # NOTE(review): only model.graph is passed on, so anything else stored on
    # the loaded model (e.g. additional opset imports or metadata) is not
    # carried over -- confirm this is intended for the quantized variants.
    model_fixed = onnx.helper.make_model(model.graph, ir_version=IR, opset_imports=[op])
    onnx.save(model_fixed, filename)
    print("Done\n\n")