import os
import shutil

import onnx
import onnxslim
import torch
from huggingface_hub import HfApi
from torch import nn
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

# Constants
model_id = "google/paligemma2-3b-pt-448"
OUTPUT_FOLDER = os.path.join("output", model_id)
TEXT_MODEL_NAME = "decoder_model_merged.onnx"
VISION_MODEL_NAME = "vision_encoder.onnx"
EMBED_MODEL_NAME = "embed_tokens.onnx"
TEMP_MODEL_OUTPUT_FOLDER = os.path.join(OUTPUT_FOLDER, "temp")
os.makedirs(TEMP_MODEL_OUTPUT_FOLDER, exist_ok=True)

# Model and Processor Loading
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval()
processor = AutoProcessor.from_pretrained(model_id)
text_config = model.config.text_config


# Define Vision Encoder Model
class VisionEncoder(nn.Module):
    def __init__(self, paligemma_model):
        super().__init__()
        self.config = paligemma_model.config
        self.vision_tower = paligemma_model.vision_tower
        self.multi_modal_projector = paligemma_model.multi_modal_projector

    def forward(self, pixel_values: torch.FloatTensor):
        image_outputs = self.vision_tower(pixel_values)
        selected_image_feature = image_outputs.last_hidden_state
        image_features = self.multi_modal_projector(selected_image_feature)
        # PaliGemma scales the projected features by 1/sqrt(text hidden size).
        image_features = image_features / (self.config.text_config.hidden_size**0.5)
        return image_features


vision_model = VisionEncoder(model)


# Define Text Decoder Model
# torch.onnx.export binds positional args in signature order, so the full
# PaliGemmaForConditionalGeneration cannot be exported directly (its first
# positional argument is input_ids, not inputs_embeds). This wrapper exposes
# a flat-tensor signature instead.
class TextDecoder(nn.Module):
    def __init__(self, paligemma_model):
        super().__init__()
        self.config = paligemma_model.config.text_config
        # Attribute layout follows the transformers release this script
        # targets, where language_model is the causal-LM head model.
        self.language_model = paligemma_model.language_model

    def forward(self, inputs_embeds, position_ids, *past_tensors):
        # Re-nest the flat past tensors into one (key, value) pair per layer;
        # this assumes a transformers version that still accepts the legacy
        # tuple cache format for this model.
        past_key_values = tuple(
            (past_tensors[2 * i], past_tensors[2 * i + 1])
            for i in range(self.config.num_hidden_layers)
        )
        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        # Flatten the updated cache so every present tensor is a graph output.
        presents = [t for layer in outputs.past_key_values for t in layer]
        return (outputs.logits, *presents)


text_model = TextDecoder(model)

# Export Models (ONNX)
# Dummy text inputs (needed for exporting the text model)
batch_size = 2
sequence_length = 32
past_sequence_length = 8
inputs_embeds = torch.randn((batch_size, sequence_length, text_config.hidden_size))
# Positions continue from the dummy past length.
position_ids = torch.arange(
    past_sequence_length, past_sequence_length + sequence_length, dtype=torch.int64
).expand(batch_size, sequence_length)
past_key_values_names = [
    f"past_key_values.{i}.{key}"
    for i in range(text_config.num_hidden_layers)
    for key in ("key", "value")
]
dummy_past_key_values = [
    torch.zeros(
        batch_size,
        text_config.num_key_value_heads,
        past_sequence_length,
        text_config.head_dim,
        dtype=torch.float32,
    )
    for _ in past_key_values_names
]

# Export text model
TEXT_MODEL_OUTPUT_PATH = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, TEXT_MODEL_NAME)
torch.onnx.export(
    text_model,
    args=(inputs_embeds, position_ids, *dummy_past_key_values),
    f=TEXT_MODEL_OUTPUT_PATH,
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    # Input names must match the positional order of args above.
    input_names=["inputs_embeds", "position_ids"] + past_key_values_names,
    output_names=["logits"]
    + [
        f"present.{i}.{key}"
        for i in range(text_config.num_hidden_layers)
        for key in ("key", "value")
    ],
    dynamic_axes={
        "inputs_embeds": {0: "batch_size", 1: "sequence_length"},
        "position_ids": {0: "batch_size", 1: "sequence_length"},
        **{
            name: {0: "batch_size", 2: "past_sequence_length"}
            for name in past_key_values_names
        },
        "logits": {0: "batch_size", 1: "sequence_length"},
        **{
            f"present.{i}.{key}": {0: "batch_size", 2: "total_sequence_length"}
            for i in range(text_config.num_hidden_layers)
            for key in ("key", "value")
        },
    },
)

# Export vision model
# The dummy input must match the checkpoint's resolution (448x448 for
# paligemma2-3b-pt-448), so read it from the config instead of hard-coding 224.
image_size = model.config.vision_config.image_size
VISION_MODEL_OUTPUT_PATH = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, VISION_MODEL_NAME)
torch.onnx.export(
    vision_model,
    args=(torch.randn(batch_size, 3, image_size, image_size),),
    f=VISION_MODEL_OUTPUT_PATH,
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=["pixel_values"],
    output_names=["image_features"],
    dynamic_axes={"pixel_values": {0: "batch_size"}, "image_features": {0: "batch_size"}},
)
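# Sanity check (a minimal sketch using the onnx package imported above):
# onnx.checker validates graph structure, and given a file path it also
# resolves the side-car external-data files that torch.onnx.export emits for
# graphs above the 2 GB protobuf limit. This step can be dropped without
# affecting the export.
onnx.checker.check_model(TEXT_MODEL_OUTPUT_PATH)
onnx.checker.check_model(VISION_MODEL_OUTPUT_PATH)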
# Export embedding model (optional)
# embed_tokens lives here in the transformers release this script targets;
# newer releases may relocate it.
embed_layer = model.language_model.model.embed_tokens
input_ids = torch.randint(0, embed_layer.num_embeddings, (batch_size, sequence_length))
EMBED_MODEL_OUTPUT_PATH = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, EMBED_MODEL_NAME)
torch.onnx.export(
    embed_layer,
    args=(input_ids,),
    f=EMBED_MODEL_OUTPUT_PATH,
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=["input_ids"],
    output_names=["inputs_embeds"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "inputs_embeds": {0: "batch_size", 1: "sequence_length"},
    },
)

# Post-processing: slim the vision encoder with onnxslim (constant folding,
# dead-node removal). The decoder and embedding graphs are left as exported:
# their weights exceed onnx's 2 GB protobuf limit and live in side-car
# external-data files, which complicates re-saving.
onnx.save(onnxslim.slim(VISION_MODEL_OUTPUT_PATH), VISION_MODEL_OUTPUT_PATH)

# Upload to Hugging Face
# Create a repo (if it doesn't exist). create_repo takes the repo id as its
# first argument and returns a RepoUrl carrying the namespace-qualified id.
repo_name = f"paligemma-onnx-{model_id.split('/')[-1]}"
api = HfApi()
repo_id = api.create_repo(repo_name, exist_ok=True).repo_id

# Stage models locally. Move every file in the temp folder, not just the three
# .onnx graphs: exports above 2 GB come with external-data files that must stay
# next to their graph.
os.makedirs(os.path.join(repo_name, "onnx"), exist_ok=True)
for file_name in os.listdir(TEMP_MODEL_OUTPUT_FOLDER):
    os.rename(
        os.path.join(TEMP_MODEL_OUTPUT_FOLDER, file_name),
        os.path.join(repo_name, "onnx", file_name),
    )

# Save the config and tokenizer/processor files alongside the models.
model.config.save_pretrained(repo_name)
processor.save_pretrained(repo_name)

# Push to Hugging Face Model Hub. upload_folder replaces the deprecated
# Repository workflow and needs no local git clone.
api.upload_folder(folder_path=repo_name, repo_id=repo_id, commit_message="Upload ONNX models")

# Clean up temporary files
shutil.rmtree(TEMP_MODEL_OUTPUT_FOLDER)
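# Optional smoke test (a minimal sketch; assumes onnxruntime is installed,
# which the pipeline above does not otherwise require): run the staged vision
# encoder on a dummy image and confirm the output shape before trusting the
# upload.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(
    os.path.join(repo_name, "onnx", VISION_MODEL_NAME),
    providers=["CPUExecutionProvider"],
)
dummy_pixels = np.random.randn(1, 3, image_size, image_size).astype(np.float32)
(image_features,) = session.run(None, {"pixel_values": dummy_pixels})
# Expected shape: (1, num_image_tokens, projection_dim); a 448x448 input with
# 14x14 patches yields 1024 image tokens.
print("vision encoder output shape:", image_features.shape)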