martinhillebrandtd's picture
model
ff33345
raw
history blame
6.7 kB
import os
import json
import shutil
from optimum.exporters.onnx import main_export
import onnx
from onnxconverter_common import float16
import onnxruntime as rt
from onnxruntime.tools.onnx_model_utils import *
from onnxruntime.quantization import quantize_dynamic, QuantType
import huggingface_hub
def add_mean_pooling(input_model, output_model, op, IR, output_embeddings_number):
    """Append masked mean pooling and L2 normalisation to a transformer ONNX graph.

    Loads ``input_model``, whose graph must expose an ``attention_mask`` input
    and a ``last_hidden_state`` output, and appends nodes computing::

        sentence_embedding = L2Normalize(
            ReduceSum(last_hidden_state * mask, axis=1)
            / Clip(ReduceSum(mask, axis=1), min=1e-9)
        )

    The ``last_hidden_state`` output is removed and replaced by a
    ``sentence_embedding`` output declared as
    ``["batch_size", output_embeddings_number]``. The result is saved to
    ``output_model`` as a single file (no external data). ``input_model`` and
    ``output_model`` may be the same path, because the input is fully loaded
    before anything is written.

    Args:
        input_model: Path of the ONNX model to read.
        output_model: Path the pooled model is written to.
        op: ``onnx.OperatorSetIdProto`` for the default ONNX domain. The
            Unsqueeze/ReduceSum/Squeeze nodes pass their axes as inputs, which
            requires opset 13 or newer.
        IR: IR version stamped on the saved model.
        output_embeddings_number: Embedding width declared on the new output.
    """
    model = onnx.load(input_model)
    graph = model.graph

    # NOTE(review): last_hidden_state is assumed to be (batch, seq, hidden) and
    # attention_mask (batch, seq); pooling reduces over axis 1.
    graph.initializer.extend([
        onnx.helper.make_tensor("minus_one_axis", onnx.TensorProto.INT64, [1], [-1]),
        # Clip's ``min`` input must be a scalar (empty shape) per the ONNX spec.
        onnx.helper.make_tensor("mask_clip_lower_limit", onnx.TensorProto.FLOAT, [], [1e-9]),
        onnx.helper.make_tensor("sum_one_axis", onnx.TensorProto.INT64, [1], [1]),
    ])

    make_node = onnx.helper.make_node
    graph.node.extend([
        # Cast the mask to float so it can be multiplied with the hidden states.
        make_node("Cast", ["attention_mask"], ["attention_mask_fp32"],
                  to=onnx.TensorProto.FLOAT),
        # Add a trailing axis to the mask, then broadcast it to last_hidden_state's shape.
        make_node("Unsqueeze", ["attention_mask_fp32", "minus_one_axis"],
                  ["unsqueezed_attention_mask"]),
        make_node("Shape", ["last_hidden_state"], ["last_hidden_state_shape"]),
        make_node("Expand", ["unsqueezed_attention_mask", "last_hidden_state_shape"],
                  ["expanded_attention_mask"]),
        # Zero out masked positions, then sum embeddings and mask counts over axis 1.
        make_node("Mul", ["last_hidden_state", "expanded_attention_mask"],
                  ["last_hidden_state_x_expanded_attention_mask"]),
        make_node("ReduceSum", ["last_hidden_state_x_expanded_attention_mask", "sum_one_axis"],
                  ["sum_last_hidden_state_x_expanded_attention_mask"]),
        make_node("ReduceSum", ["expanded_attention_mask", "sum_one_axis"],
                  ["sum_expanded_attention_mask"]),
        # Floor the mask count so an all-zero mask cannot divide by zero.
        make_node("Clip", ["sum_expanded_attention_mask", "mask_clip_lower_limit"],
                  ["clipped_sum_expanded_attention_mask"]),
        make_node("Div", ["sum_last_hidden_state_x_expanded_attention_mask",
                          "clipped_sum_expanded_attention_mask"],
                  ["pooled_embeddings"]),
        # ReduceSum keeps the reduced axis by default (keepdims=1); drop it here.
        make_node("Squeeze", ["pooled_embeddings", "sum_one_axis"],
                  ["squeezed_pooled_embeddings"]),
        make_node("Normalizer", ["squeezed_pooled_embeddings"], ["sentence_embedding"],
                  domain="ai.onnx.ml", norm="L2"),
    ])

    graph.output.append(onnx.helper.make_tensor_value_info(
        "sentence_embedding",
        onnx.TensorProto.FLOAT,
        shape=["batch_size", output_embeddings_number],
    ))
    # Iterate over a snapshot so removing entries cannot skip any element.
    for graph_output in list(graph.output):
        if graph_output.name == "last_hidden_state":
            graph.output.remove(graph_output)

    # Declare every operator domain the graph uses: the caller's default-domain
    # opset, any other domains the input model already declared, and ai.onnx.ml
    # for the Normalizer node above.
    opset_imports = [op] + [
        imp for imp in model.opset_import
        if imp.domain not in ("", "ai.onnx", op.domain)
    ]
    if not any(imp.domain == "ai.onnx.ml" for imp in opset_imports):
        opset_imports.append(onnx.helper.make_opsetid("ai.onnx.ml", 1))

    # Build the final model once, honouring the requested IR version.
    pooled_model = onnx.helper.make_model(graph, ir_version=IR, opset_imports=opset_imports)
    onnx.save(pooled_model, output_model, save_as_external_data=False)
# Conversion settings: source model id, embedding width, output file per
# precision, and the ONNX opset / IR versions to stamp on the exported graphs.
with open("conversion_config.json", encoding="utf-8") as json_file:
    conversion_config = json.load(json_file)
model_id = conversion_config["model_id"]
number_of_generated_embeddings = conversion_config["number_of_generated_embeddings"]
precision_to_filename_map = conversion_config["precision_to_filename_map"]
opset = conversion_config["opset"]
IR = conversion_config["IR"]

# Default-domain opset declaration handed to add_mean_pooling.
op = onnx.OperatorSetIdProto()
op.version = opset

os.makedirs("onnx", exist_ok=True)

print("Exporting the main model version")
try:
    main_export(model_name_or_path=model_id, output="./", opset=opset,
                trust_remote_code=True, task="feature-extraction", dtype="fp32")
except Exception as export_error:
    # Fall back to a model.onnx file published in the model repository, but
    # report why the export failed instead of hiding it.
    print(f"Export failed ({export_error!r}); downloading model.onnx from the Hub instead")
    huggingface_hub.hf_hub_download(repo_id=model_id, filename="model.onnx", local_dir="./")

if "fp32" in precision_to_filename_map:
    print("Exporting the fp32 onnx file...")
    # add_mean_pooling writes the output file itself, so no prior copy is needed.
    add_mean_pooling("model.onnx", precision_to_filename_map["fp32"], op, IR,
                     number_of_generated_embeddings)
    print("Done\n\n")

# Dynamically quantized variants, each built from the fp32 model.onnx and then
# given the pooling head in place.
for precision, weight_type in (("int8", QuantType.QInt8), ("uint8", QuantType.QUInt8)):
    if precision not in precision_to_filename_map:
        continue
    print(f"Quantizing fp32 model to {precision}...")
    quantized_path = precision_to_filename_map[precision]
    quantize_dynamic("model.onnx", quantized_path, weight_type=weight_type)
    add_mean_pooling(quantized_path, quantized_path, op, IR, number_of_generated_embeddings)
    print("Done\n\n")

# Remove the intermediate export; only the per-precision pooled files are kept.
os.remove("model.onnx")