import argparse
import onnx
import os
import requests
import shutil
import subprocess
import sys
import torch

from onnxruntime_genai.models.builder import create_model
from PIL import Image
from transformers import AutoConfig, AutoProcessor, AutoModelForCausalLM


def _torch_device(execution_provider):
    """Map an execution-provider name to the PyTorch device used for export.

    "dml" has no PyTorch device here, so DirectML builds are traced on CUDA.
    """
    return execution_provider.replace("dml", "cuda")


def _run_python_module(module, *module_args):
    """Run ``python -m <module> <module_args...>`` with the current interpreter.

    ``check=True`` makes a failing tool raise ``subprocess.CalledProcessError``
    instead of letting the pipeline continue (and delete its inputs) with a
    missing or partial output file.
    """
    subprocess.run([sys.executable, "-m", module, *module_args], check=True)


def build_vision(args):
    """Export, optimize and 4-bit quantize the Phi-3 vision encoder to ONNX.

    Reads the module-level ``model``, ``processor``, ``user_prompt``,
    ``prompt_suffix`` and ``assistant_prompt`` created in ``__main__``.
    Writes ``phi-3-v-128k-instruct-vision.onnx`` (plus its external data) to
    ``args.output``. The intermediate folders created under ``args.output``
    are removed again.

    Raises:
        requests.HTTPError: if the sample image cannot be downloaded.
        subprocess.CalledProcessError: if the ORT optimizer or quantizer fails.
    """
    device = _torch_device(args.execution_provider)

    # Single image: a sample prompt + image, used only to trace the export.
    prompt = f"{user_prompt}<|image_1|>\nWhat is shown in this image?{prompt_suffix}{assistant_prompt}"
    url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()
        image = Image.open(response.raw)
        image.load()  # read the pixels before the response stream is closed
    inputs = processor(prompt, image, return_tensors="pt").to(device)
    inputs["pixel_values"] = inputs["pixel_values"].to(args.precision)

    # ONNX export of the vision embedding module via torch.onnx.export
    dummy_inputs = (
        inputs["pixel_values"],  # pixel_values
        inputs["image_sizes"],  # image_sizes
    )
    dynamic_axes = {
        "pixel_values": {0: "num_images", 1: "max_num_crops", 3: "height", 4: "width"},
        "image_sizes": {0: "num_images"},
        "visual_features": {0: "batch_size", 1: "num_img_tokens"},
    }
    filename = "phi-3-v-128k-instruct-vision.onnx"

    # Stage 1: raw export, then checker + in-place shape inference
    temp_folder_1 = os.path.join(args.output, "vision_init_export")
    os.makedirs(temp_folder_1, exist_ok=True)
    fpath_1 = os.path.join(temp_folder_1, filename)
    torch.onnx.export(
        model.model.vision_embed_tokens,
        args=dummy_inputs,
        f=fpath_1,
        export_params=True,
        input_names=["pixel_values", "image_sizes"],
        output_names=["visual_features"],
        dynamic_axes=dynamic_axes,
        opset_version=14,
        do_constant_folding=True,
    )
    onnx.checker.check_model(fpath_1)
    onnx.shape_inference.infer_shapes_path(fpath_1)
    onnx_model = onnx.load_model(fpath_1, load_external_data=True)

    # Stage 2: re-save with every initializer in a single external data file
    temp_folder_2 = os.path.join(args.output, "vision_after_export")
    os.makedirs(temp_folder_2, exist_ok=True)
    fpath_2 = os.path.join(temp_folder_2, filename)
    onnx.save_model(
        onnx_model,
        fpath_2,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=f"{filename}.data",
        size_threshold=0,
        convert_attribute=False,
    )
    del onnx_model  # no longer needed; release it before the subprocess stages
    shutil.rmtree(temp_folder_1)

    # Stage 3: ORT transformer optimizer (clip model type; head count and
    # hidden size are hard-coded for this vision encoder)
    temp_folder_3 = os.path.join(args.output, "vision_after_opt")
    os.makedirs(temp_folder_3, exist_ok=True)
    fpath_3 = os.path.join(temp_folder_3, filename)
    _run_python_module(
        "onnxruntime.transformers.optimizer",
        "--input", fpath_2,
        "--output", fpath_3,
        "--model_type", "clip",
        "--num_heads", str(16),
        "--hidden_size", str(1024),
        "--use_external_data_format",
        "--opt_level", str(0),
    )
    shutil.rmtree(temp_folder_2)

    # Stage 4: ORT 4-bit MatMul quantizer, writing the final model
    fpath_4 = os.path.join(args.output, filename)
    quant_args = [
        "--input_model", fpath_3,
        "--output_model", fpath_4,
        "--block_size", str(32),
    ]
    if args.precision == torch.float32:
        # Same accuracy level that build_text passes for fp32 builds.
        quant_args.extend(["--accuracy_level", str(4)])
    _run_python_module("onnxruntime.quantization.matmul_4bits_quantizer", *quant_args)
    shutil.rmtree(temp_folder_3)


def build_text_embedding(args):
    """Build a standalone ONNX model for the token-embedding lookup.

    The graph has a single ``Gather`` node. It indexes ``input_ids`` into
    ``model.embed_tokens.weight``, taken from the module-level PyTorch
    ``model``, and produces ``inputs_embeds`` in the requested precision.
    The result is saved to
    ``args.output/phi-3-v-128k-instruct-text-embedding.onnx`` with one
    external data file.

    The module-level PyTorch ``model`` is only read, never rebound.
    """
    #########################################
    # Functions/variables from model builder
    #########################################
    from onnx import helper, numpy_helper, TensorProto, external_data_helper, save_model
    import numpy as np

    # User inputs
    io_dtype = TensorProto.FLOAT16 if args.precision == torch.float16 else TensorProto.FLOAT
    os.makedirs(args.cache_dir, exist_ok=True)

    # Map TensorProto dtypes
    to_torch_dtype = {
        TensorProto.FLOAT16: torch.float16,
        TensorProto.FLOAT: torch.float32,
    }
    to_numpy_dtype = {
        TensorProto.FLOAT16: np.float16,
        TensorProto.FLOAT: np.float32,
    }

    # Temporary external-data files written below. Only these are deleted
    # afterwards: args.cache_dir is also the Hugging Face cache directory,
    # so a blanket "*.bin" sweep could remove files this script did not create.
    written_files = []

    def make_external_tensor(np_data, name):
        """Return an initializer whose bytes live in ``<cache_dir>/<name>.bin``."""
        tensor = numpy_helper.from_array(np_data)
        tensor.name = name
        filename = f"{name}.bin"
        external_data_helper.set_external_data(tensor, location=filename)
        path = os.path.join(args.cache_dir, filename)
        with open(path, "wb") as f:
            f.write(tensor.raw_data)
        written_files.append(path)
        tensor.ClearField("raw_data")
        tensor.data_location = TensorProto.EXTERNAL
        return tensor

    # Make model (read the module-level PyTorch model; do not rebind it)
    embedding = model.model.embed_tokens.weight.to(to_torch_dtype[io_dtype]).detach().cpu().numpy()
    weight_name = "model.embed_tokens.weight"
    embed_weight = make_external_tensor(embedding.astype(to_numpy_dtype[io_dtype]), weight_name)
    onnx_model = helper.make_model(
        opset_imports=[helper.make_operatorsetid('', 14), helper.make_operatorsetid('com.microsoft', 1)],
        ir_version=7,
        producer_name="onnxruntime-genai",
        producer_version="0.0.0",
        graph=helper.make_graph(
            name="main_graph",
            inputs=[helper.make_tensor_value_info("input_ids", TensorProto.INT64, shape=["batch_size", "sequence_length"])],
            outputs=[helper.make_tensor_value_info("inputs_embeds", io_dtype, shape=["batch_size", "sequence_length", config.hidden_size])],
            initializer=[embed_weight],
            value_info=[],
            nodes=[helper.make_node('Gather', inputs=[weight_name, 'input_ids'], outputs=['inputs_embeds'], name="/model/embed_tokens/Gather")],
        ),
    )
    external_data_helper.load_external_data_for_model(onnx_model, args.cache_dir)

    # Delete the external data files we wrote, now that they are loaded in memory
    for path in written_files:
        os.remove(path)

    # Delete temporary cache dir if empty
    if not os.listdir(args.cache_dir):
        os.rmdir(args.cache_dir)

    # Save ONNX model with only one external data file
    filename = "phi-3-v-128k-instruct-text-embedding.onnx"
    output_path = os.path.join(args.output, filename)
    save_model(
        onnx_model,
        output_path,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=f"{filename}.data",
        size_threshold=0,
        convert_attribute=False,
    )


def build_text(args):
    """Build the int4 text decoder with the GenAI model builder.

    Embeddings are excluded (``exclude_embeds``) because they are produced by
    the separate text-embedding model. For fp32 builds, int4 accuracy level 4
    is requested.
    """
    # Create ONNX model
    model_name = None
    precision = "int4"
    extra_options = {
        "exclude_embeds": "true",
        "filename": "phi-3-v-128k-instruct-text.onnx",
    }
    if args.precision == torch.float32:
        extra_options["int4_accuracy_level"] = 4
    create_model(model_name, args.input, args.output, precision, args.execution_provider, args.cache_dir, **extra_options)


def get_args():
    """Parse CLI arguments; ``args.precision`` is returned as a torch dtype."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--input",
        required=True,
        help="Path to folder on disk containing the Hugging Face config, model, tokenizer, etc.",
    )
    parser.add_argument(
        "-o",
        "--output",
        required=True,
        help="Path to folder to store ONNX model and additional files (e.g. GenAI config, external data files, etc.)",
    )
    parser.add_argument(
        "-p",
        "--precision",
        required=True,
        choices=["fp16", "fp32"],
        help="Precision to export PyTorch components with",
    )
    parser.add_argument(
        "-e",
        "--execution_provider",
        required=True,
        choices=["cpu", "cuda", "dml"],
        help="Execution provider for Phi-3 vision components",
    )
    parser.add_argument(
        "-c",
        "--cache_dir",
        required=False,
        default=os.path.join('.', 'cache_dir'),
        help="Cache directory for Hugging Face files and temporary ONNX external data files",
    )
    args = parser.parse_args()
    args.precision = torch.float16 if args.precision == "fp16" else torch.float32
    return args


if __name__ == "__main__":
    user_prompt = '<|user|>\n'
    assistant_prompt = '<|assistant|>\n'
    prompt_suffix = "<|end|>\n"
    args = get_args()
    config = AutoConfig.from_pretrained(args.input, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(args.input, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(args.input, trust_remote_code=True, torch_dtype=args.precision).to(_torch_device(args.execution_provider))

    # Build model components
    build_vision(args)
    build_text_embedding(args)

    # build_text does not use the PyTorch model (create_model is given
    # args.input), so release it, and any device memory it holds, first.
    del model

    build_text(args)