|
import os |
|
import json |
|
import shutil |
|
|
|
from optimum.exporters.onnx import main_export |
|
import onnx |
|
from onnxconverter_common import float16 |
|
import onnxruntime as rt |
|
from onnxruntime.tools.onnx_model_utils import * |
|
from onnxruntime.quantization import quantize_dynamic, QuantType |
|
import huggingface_hub |
|
|
|
def _pooling_opset_imports(model, op):
    """Return the opset imports to stamp on the pooled model.

    The default ONNX domain is pinned to ``op``. Imports for any other domain
    already declared by ``model`` (e.g. custom-op domains added by a
    quantizer) are carried over so those nodes stay resolvable, and
    ``ai.onnx.ml`` is added when missing because the ``Normalizer`` node
    appended by ``add_mean_pooling`` lives in that domain.
    """
    imports = [op]
    # "" and "ai.onnx" both name the default domain, which `op` already covers.
    seen = {"", "ai.onnx", op.domain}
    for existing in model.opset_import:
        if existing.domain not in seen:
            imports.append(existing)
            seen.add(existing.domain)
    if "ai.onnx.ml" not in seen:
        imports.append(onnx.helper.make_opsetid("ai.onnx.ml", 1))
    return imports


def add_mean_pooling(input_model, output_model, op, IR, output_embeddings_number):
    """Append masked mean pooling + L2 normalisation to an exported encoder.

    Loads the ONNX model at ``input_model``, adds nodes that turn its
    ``last_hidden_state`` output into a ``sentence_embedding`` output, removes
    ``last_hidden_state`` from the graph outputs and saves the result to
    ``output_model`` (which may be the same path as ``input_model``).

    The added subgraph computes::

        mask = expand(unsqueeze(float(attention_mask), -1), shape(last_hidden_state))
        mean = sum(last_hidden_state * mask, axis=1) / max(sum(mask, axis=1), 1e-9)
        sentence_embedding = L2_normalize(squeeze(mean, axis=1))

    Args:
        input_model: Path of the ONNX model to read. Its graph must expose an
            ``attention_mask`` input and a ``last_hidden_state`` output
            (presumably shaped ``(batch, seq, hidden)`` - not checked here).
        output_model: Path the rewritten model is written to.
        op: ``onnx.OperatorSetIdProto`` for the default ONNX domain. Its
            ``version`` must be >= 13 because Unsqueeze, ReduceSum and Squeeze
            receive their axes as input tensors here.
        IR: IR version stamped on the saved model.
        output_embeddings_number: Embedding width, used only to declare the
            static second dimension of the ``sentence_embedding`` output.

    Raises:
        ValueError: If ``op.version`` is below 13.
    """
    if op.version < 13:
        raise ValueError(
            f"add_mean_pooling needs default-domain opset >= 13, got {op.version}"
        )

    model = onnx.load(input_model)
    graph = model.graph

    # Constant tensors consumed by the pooling nodes below.
    graph.initializer.extend([
        onnx.helper.make_tensor(
            name="minus_one_axis",
            data_type=onnx.TensorProto.INT64,
            dims=[1],
            vals=[-1],
        ),
        # Clip's `min` input must be a scalar (empty shape) per the ONNX spec.
        onnx.helper.make_tensor(
            name="mask_clip_lower_limit",
            data_type=onnx.TensorProto.FLOAT,
            dims=[],
            vals=[1e-9],
        ),
        onnx.helper.make_tensor(
            name="sum_one_axis",
            data_type=onnx.TensorProto.INT64,
            dims=[1],
            vals=[1],
        ),
    ])

    # Nodes are listed in topological order, as ONNX requires.
    graph.node.extend([
        # Cast the mask to float so it can be multiplied with the hidden states.
        onnx.helper.make_node(
            "Cast",
            inputs=["attention_mask"],
            outputs=["attention_mask_fp32"],
            to=onnx.TensorProto.FLOAT,
        ),
        # Add a trailing axis, then broadcast to last_hidden_state's full shape.
        onnx.helper.make_node(
            "Unsqueeze",
            inputs=["attention_mask_fp32", "minus_one_axis"],
            outputs=["unsqueezed_attention_mask"],
        ),
        onnx.helper.make_node(
            "Shape",
            inputs=["last_hidden_state"],
            outputs=["last_hidden_state_shape"],
        ),
        onnx.helper.make_node(
            "Expand",
            inputs=["unsqueezed_attention_mask", "last_hidden_state_shape"],
            outputs=["expanded_attention_mask"],
        ),
        # Zero out masked positions, then sum the remaining ones along axis 1.
        onnx.helper.make_node(
            "Mul",
            inputs=["last_hidden_state", "expanded_attention_mask"],
            outputs=["last_hidden_state_x_expanded_attention_mask"],
        ),
        onnx.helper.make_node(
            "ReduceSum",
            inputs=["last_hidden_state_x_expanded_attention_mask", "sum_one_axis"],
            outputs=["sum_last_hidden_state_x_expanded_attention_mask"],
        ),
        # Number of unmasked positions, clipped from below so the Div is safe.
        onnx.helper.make_node(
            "ReduceSum",
            inputs=["expanded_attention_mask", "sum_one_axis"],
            outputs=["sum_expanded_attention_mask"],
        ),
        onnx.helper.make_node(
            "Clip",
            inputs=["sum_expanded_attention_mask", "mask_clip_lower_limit"],
            outputs=["clipped_sum_expanded_attention_mask"],
        ),
        onnx.helper.make_node(
            "Div",
            inputs=[
                "sum_last_hidden_state_x_expanded_attention_mask",
                "clipped_sum_expanded_attention_mask",
            ],
            outputs=["pooled_embeddings"],
        ),
        # ReduceSum keeps the reduced axis; drop it before normalising.
        onnx.helper.make_node(
            "Squeeze",
            inputs=["pooled_embeddings", "sum_one_axis"],
            outputs=["squeezed_pooled_embeddings"],
        ),
        onnx.helper.make_node(
            "Normalizer",
            domain="ai.onnx.ml",
            inputs=["squeezed_pooled_embeddings"],
            outputs=["sentence_embedding"],
            norm="L2",
        ),
    ])

    graph.output.append(
        onnx.helper.make_tensor_value_info(
            "sentence_embedding",
            onnx.TensorProto.FLOAT,
            shape=["batch_size", output_embeddings_number],
        )
    )

    # Iterate over a snapshot: removing from a repeated field while iterating
    # over it directly can skip elements.
    for stale in [out for out in graph.output if out.name == "last_hidden_state"]:
        graph.output.remove(stale)

    # Use the caller's IR version (previously hard-coded to 8) and keep every
    # domain the graph's nodes depend on.
    pooled_model = onnx.helper.make_model(
        graph,
        ir_version=IR,
        opset_imports=_pooling_opset_imports(model, op),
    )
    onnx.save(pooled_model, output_model, save_as_external_data=False)
|
|
|
|
|
|
|
# Export settings shared by every precision variant.
with open("conversion_config.json", encoding="utf-8") as json_file:
    conversion_config = json.load(json_file)

model_id = conversion_config["model_id"]
number_of_generated_embeddings = conversion_config["number_of_generated_embeddings"]
precision_to_filename_map = conversion_config["precision_to_filename_map"]
opset = conversion_config["opset"]
IR = conversion_config["IR"]

# Opset import for the default ONNX domain, passed to add_mean_pooling.
op = onnx.OperatorSetIdProto()
op.version = opset

os.makedirs("onnx", exist_ok=True)

print("Exporting the main model version")
try:
    main_export(
        model_name_or_path=model_id,
        output="./",
        opset=opset,
        trust_remote_code=True,
        task="feature-extraction",
        dtype="fp32",
    )
except Exception as export_error:
    # Fall back to an ONNX file published in the model repository, but report
    # why the local export failed instead of swallowing it silently.
    print(f"Export failed ({export_error!r}); downloading model.onnx from the hub instead")
    huggingface_hub.hf_hub_download(repo_id=model_id, filename="model.onnx", local_dir="./")

if "fp32" in precision_to_filename_map:
    print("Exporting the fp32 onnx file...")
    # add_mean_pooling writes the output file itself; no pre-copy is needed.
    add_mean_pooling(
        "model.onnx",
        precision_to_filename_map["fp32"],
        op,
        IR,
        number_of_generated_embeddings,
    )
    print("Done\n\n")

# Each quantized variant is built from the plain export, then the pooling
# head is appended in place.
for precision, weight_type in (("int8", QuantType.QInt8), ("uint8", QuantType.QUInt8)):
    if precision not in precision_to_filename_map:
        continue
    target_path = precision_to_filename_map[precision]
    print(f"Quantizing fp32 model to {precision}...")
    quantize_dynamic("model.onnx", target_path, weight_type=weight_type)
    add_mean_pooling(target_path, target_path, op, IR, number_of_generated_embeddings)
    print("Done\n\n")

os.remove("model.onnx")
|
|