import os
import json

import onnx
import huggingface_hub
from optimum.exporters.onnx import main_export
from onnxruntime.quantization import quantize_dynamic, QuantType


def add_mean_pooling(input_model, output_model, op, IR, output_embeddings_number):
    """Append mean pooling and L2 normalization to an exported transformer graph.

    Implements sentence-transformers-style mean pooling:
        sentence_embedding = L2_normalize(
            sum(last_hidden_state * attention_mask) / clamp(sum(attention_mask), min=1e-9)
        )
    """
    model = onnx.load(input_model)
    # Rebuild the model with an explicit opset and IR version so the appended
    # nodes stay compatible with the rest of the graph.
    model_ir8 = onnx.helper.make_model(model.graph, ir_version=IR, opset_imports=[op])

    # Constant initializers used by the pooling subgraph.
    minus_one_axis = onnx.helper.make_tensor(
        name="minus_one_axis",
        data_type=onnx.TensorProto.INT64,
        dims=[1],
        vals=[-1])
    model_ir8.graph.initializer.append(minus_one_axis)

    mask_clip_lower_limit = onnx.helper.make_tensor(
        name="mask_clip_lower_limit",
        data_type=onnx.TensorProto.FLOAT,
        dims=[1],
        vals=[1e-9])
    model_ir8.graph.initializer.append(mask_clip_lower_limit)

    sum_one_axis = onnx.helper.make_tensor(
        name="sum_one_axis",
        data_type=onnx.TensorProto.INT64,
        dims=[1],
        vals=[1])
    model_ir8.graph.initializer.append(sum_one_axis)

    # Cast the int64 attention mask to fp32 so it can be multiplied with the
    # hidden states.
    attention_mask_cast_op = onnx.helper.make_node(
        "Cast",
        inputs=["attention_mask"],
        outputs=["attention_mask_fp32"],
        to=onnx.TensorProto.FLOAT,
    )
    model_ir8.graph.node.append(attention_mask_cast_op)

    # Expand the mask from [batch, seq] to [batch, seq, 1] ...
    expand_dims_op = onnx.helper.make_node(
        "Unsqueeze",
        inputs=["attention_mask_fp32", "minus_one_axis"],
        outputs=["unsqueezed_attention_mask"],
    )
    model_ir8.graph.node.append(expand_dims_op)

    shape_op = onnx.helper.make_node(
        "Shape",
        inputs=["last_hidden_state"],
        outputs=["last_hidden_state_shape"],
    )
    model_ir8.graph.node.append(shape_op)

    # ... and broadcast it to the full [batch, seq, hidden] shape.
    broadcast_to_op = onnx.helper.make_node(
        "Expand",
        inputs=["unsqueezed_attention_mask", "last_hidden_state_shape"],
        outputs=["expanded_attention_mask"],
    )
    model_ir8.graph.node.append(broadcast_to_op)

    # Zero out hidden states at padded positions.
    multiply_op = onnx.helper.make_node(
        "Mul",
        inputs=["last_hidden_state", "expanded_attention_mask"],
        outputs=["last_hidden_state_x_expanded_attention_mask"],
    )
    model_ir8.graph.node.append(multiply_op)

    # Sum the masked hidden states over the sequence axis
    # (keepdims defaults to 1, so the result is [batch, 1, hidden]).
    sum_embeddings_op = onnx.helper.make_node(
        "ReduceSum",
        inputs=["last_hidden_state_x_expanded_attention_mask", "sum_one_axis"],
        outputs=["sum_last_hidden_state_x_expanded_attention_mask"],
    )
    model_ir8.graph.node.append(sum_embeddings_op)

    # Count the non-padded tokens.
    sum_mask_op = onnx.helper.make_node(
        "ReduceSum",
        inputs=["expanded_attention_mask", "sum_one_axis"],
        outputs=["sum_expanded_attention_mask"],
    )
    model_ir8.graph.node.append(sum_mask_op)

    # Clip the token counts away from zero to avoid division by zero.
    clip_mask_op = onnx.helper.make_node(
        "Clip",
        inputs=["sum_expanded_attention_mask", "mask_clip_lower_limit"],
        outputs=["clipped_sum_expanded_attention_mask"],
    )
    model_ir8.graph.node.append(clip_mask_op)

    # Divide to get the mean over the non-padded tokens.
    pooled_embeddings_op = onnx.helper.make_node(
        "Div",
        inputs=[
            "sum_last_hidden_state_x_expanded_attention_mask",
            "clipped_sum_expanded_attention_mask",
        ],
        outputs=["pooled_embeddings"],
    )
    model_ir8.graph.node.append(pooled_embeddings_op)

    # Drop the kept sequence axis: [batch, 1, hidden] -> [batch, hidden].
    squeeze_pooled_embeddings_op = onnx.helper.make_node(
        "Squeeze",
        inputs=["pooled_embeddings", "sum_one_axis"],
        outputs=["squeezed_pooled_embeddings"],
    )
    model_ir8.graph.node.append(squeeze_pooled_embeddings_op)

    # L2-normalize the pooled embeddings.
    normalized_pooled_embeddings_op = onnx.helper.make_node(
        "Normalizer",
        domain="ai.onnx.ml",
        inputs=["squeezed_pooled_embeddings"],
        outputs=["sentence_embedding"],
        norm="L2",
    )
    model_ir8.graph.node.append(normalized_pooled_embeddings_op)

    # Register the new graph output.
    sentence_embeddings_output = onnx.helper.make_tensor_value_info(
        "sentence_embedding",
        onnx.TensorProto.FLOAT,
        shape=["batch_size", output_embeddings_number],
    )
    model_ir8.graph.output.append(sentence_embeddings_output)

    # Remove the raw last_hidden_state output; sentence_embedding replaces it.
    for node in model_ir8.graph.output:
        if node.name == "last_hidden_state":
            model_ir8.graph.output.remove(node)

    # Rebuild once more to be sure the opset and IR version stay consistent.
    model_ir8 = onnx.helper.make_model(model_ir8.graph, ir_version=IR, opset_imports=[op])
    onnx.save(model_ir8, output_model, save_as_external_data=False)


with open("conversion_config.json") as json_file:
    conversion_config = json.load(json_file)

model_id = conversion_config["model_id"]
number_of_generated_embeddings = conversion_config["number_of_generated_embeddings"]
precision_to_filename_map = conversion_config["precision_to_filename_map"]
opset = conversion_config["opset"]
IR = conversion_config["IR"]

op = onnx.OperatorSetIdProto()
op.version = opset

if not os.path.exists("onnx"):
    os.makedirs("onnx")

print("Exporting the main model version")

try:
    main_export(
        model_name_or_path=model_id,
        output="./",
        opset=opset,
        trust_remote_code=True,
        task="feature-extraction",
        dtype="fp32",
    )
except Exception:
    # Fall back to a pre-converted ONNX file hosted in the model repository.
    huggingface_hub.hf_hub_download(repo_id=model_id, filename="model.onnx", local_dir="./")

if "fp32" in precision_to_filename_map:
    print("Exporting the fp32 onnx file...")
    add_mean_pooling(
        "model.onnx",
        precision_to_filename_map["fp32"],
        op,
        IR,
        number_of_generated_embeddings,
    )
    print("Done\n\n")

if "int8" in precision_to_filename_map:
    print("Quantizing fp32 model to int8...")
    quantize_dynamic("model.onnx", precision_to_filename_map["int8"], weight_type=QuantType.QInt8)
    add_mean_pooling(
        precision_to_filename_map["int8"],
        precision_to_filename_map["int8"],
        op,
        IR,
        number_of_generated_embeddings,
    )
    print("Done\n\n")

if "uint8" in precision_to_filename_map:
    print("Quantizing fp32 model to uint8...")
    quantize_dynamic("model.onnx", precision_to_filename_map["uint8"], weight_type=QuantType.QUInt8)
    add_mean_pooling(
        precision_to_filename_map["uint8"],
        precision_to_filename_map["uint8"],
        op,
        IR,
        number_of_generated_embeddings,
    )
    print("Done\n\n")

os.remove("model.onnx")
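# For reference, a minimal conversion_config.json consumed by this script
# might look like the following. The keys match what is read above; the
# concrete values (model id, embedding size, opset/IR numbers, file names)
# are illustrative assumptions, not taken from any particular repository:
#
# {
#     "model_id": "sentence-transformers/all-MiniLM-L6-v2",
#     "number_of_generated_embeddings": 384,
#     "precision_to_filename_map": {
#         "fp32": "onnx/model.onnx",
#         "int8": "onnx/model_int8.onnx",
#         "uint8": "onnx/model_uint8.onnx"
#     },
#     "opset": 17,
#     "IR": 8
# }
#
# A quick smoke test of an exported file (a sketch, assuming a tokenizer is
# available for model_id; not part of the conversion flow itself):
#
#   import onnxruntime as rt
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained(model_id)
#   session = rt.InferenceSession(precision_to_filename_map["fp32"])
#   encoded = tokenizer(["hello world"], return_tensors="np")
#   (embedding,) = session.run(
#       ["sentence_embedding"],
#       {i.name: encoded[i.name] for i in session.get_inputs()},
#   )
#   print(embedding.shape)  # (1, number_of_generated_embeddings)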