# ONNX / app.py
import os
import torch
import onnx
import onnxslim
import json
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from torch import nn
from huggingface_hub import HfApi, HfFolder, Repository
# Constants
model_id = "google/paligemma2-3b-pt-448"
OUTPUT_FOLDER = os.path.join("output", model_id)
TEXT_MODEL_NAME = "decoder_model_merged.onnx"
VISION_MODEL_NAME = "vision_encoder.onnx"
EMBED_MODEL_NAME = "embed_tokens.onnx"
TEMP_MODEL_OUTPUT_FOLDER = os.path.join(OUTPUT_FOLDER, "temp")
FINAL_MODEL_OUTPUT_FOLDER = os.path.join(OUTPUT_FOLDER, "onnx")
# Model and Processor Loading
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval()
processor = AutoProcessor.from_pretrained(model_id)
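# A quick reference example (a sketch, not required for the export): the processor
# produces `pixel_values` for the vision encoder and `input_ids` for the embedding
# layer exported below. The blank PIL image and the "caption en" prompt are
# illustrative placeholders only.
from PIL import Image
_example_inputs = processor(text="caption en", images=Image.new("RGB", (448, 448)), return_tensors="pt")
print("pixel_values:", tuple(_example_inputs["pixel_values"].shape),
      "input_ids:", tuple(_example_inputs["input_ids"].shape))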
# Define Vision Encoder Model
class VisionEncoder(nn.Module):
    def __init__(self, paligemma_model):
        super().__init__()
        self.config = paligemma_model.config
        self.vision_tower = paligemma_model.vision_tower
        self.multi_modal_projector = paligemma_model.multi_modal_projector

    def forward(self, pixel_values: torch.FloatTensor):
        # SigLIP backbone -> last hidden state -> linear projection into the text
        # embedding space, scaled by 1/sqrt(hidden_size) as in PaliGemma.
        image_outputs = self.vision_tower(pixel_values)
        selected_image_feature = image_outputs.last_hidden_state
        image_features = self.multi_modal_projector(selected_image_feature)
        image_features = image_features / (self.config.text_config.hidden_size**0.5)
        return image_features
vision_model = VisionEncoder(model)
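# Optional sanity check (a sketch, not part of the original flow): run the wrapper
# once on a random image at the checkpoint's native resolution and confirm the
# projected features land in the text embedding space (the last dimension should
# equal text_config.hidden_size, assuming the projector's projection_dim matches it).
with torch.no_grad():
    _size = model.config.vision_config.image_size
    _probe_features = vision_model(torch.randn(1, 3, _size, _size))
print("vision features:", tuple(_probe_features.shape))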
# Export Models (ONNX)
# Dummy text inputs (needed for exporting the text model)
batch_size = 2
sequence_length = 32
inputs_embeds = torch.randn((batch_size, sequence_length, model.config.text_config.hidden_size))
position_ids = torch.arange(1, sequence_length + 1, dtype=torch.int64).expand(batch_size, sequence_length)
# One (key, value) pair per decoder layer, shaped
# (batch_size, num_key_value_heads, past_sequence_length, head_dim);
# the dummy past length of 8 is used only for tracing.
dummy_past_key_values_kwargs = {
    f"past_key_values.{i}.{key}": torch.zeros(
        batch_size,
        model.config.text_config.num_key_value_heads,
        8,
        model.config.text_config.head_dim,
        dtype=torch.float32,
    )
    for i in range(model.config.text_config.num_hidden_layers)
    for key in ["key", "value"]
}
text_inputs_positional = tuple([inputs_embeds, position_ids] + list(dummy_past_key_values_kwargs.values()))
# Export text model
os.makedirs(TEMP_MODEL_OUTPUT_FOLDER, exist_ok=True)  # ensure the export directory exists
TEXT_MODEL_OUTPUT_PATH = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, TEXT_MODEL_NAME)
torch.onnx.export(
    model,
    args=text_inputs_positional,
    f=TEXT_MODEL_OUTPUT_PATH,
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    # Input names must follow the order of the args tuple: embeddings and positions
    # first, then the flattened per-layer key/value cache tensors.
    input_names=['inputs_embeds', 'position_ids'] + list(dummy_past_key_values_kwargs.keys()),
    output_names=["logits"] + [f"present.{i}.{key}" for i in range(model.config.text_config.num_hidden_layers) for key in ["key", "value"]],
    dynamic_axes={
        "inputs_embeds": {0: "batch_size", 1: "sequence_length"},
        "position_ids": {0: "batch_size", 1: "sequence_length"},
        **{
            f"past_key_values.{i}.{key}": {0: "batch_size", 2: "past_sequence_length"}
            for i in range(model.config.text_config.num_hidden_layers)
            for key in ["key", "value"]
        },
        "logits": {0: "batch_size", 1: "sequence_length"},
        **{
            f"present.{i}.{key}": {0: "batch_size", 2: "total_sequence_length"}
            for i in range(model.config.text_config.num_hidden_layers)
            for key in ["key", "value"]
        },
    },
)
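# Optional structural validation of the exported decoder graph, using the
# already-imported `onnx` package. Passing the file path (rather than a loaded
# ModelProto) lets the checker handle graphs that exceed the 2 GB protobuf limit.
onnx.checker.check_model(TEXT_MODEL_OUTPUT_PATH)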
# Export vision model
VISION_MODEL_OUTPUT_PATH = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, VISION_MODEL_NAME)
image_size = model.config.vision_config.image_size  # 448 for paligemma2-3b-pt-448
torch.onnx.export(
    vision_model,
    args=(torch.randn(batch_size, 3, image_size, image_size),),  # dummy pixel_values
    f=VISION_MODEL_OUTPUT_PATH,
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=['pixel_values'],
    output_names=['image_features'],
    dynamic_axes={'pixel_values': {0: 'batch_size'}, 'image_features': {0: 'batch_size'}},
)
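# Optional parity check for the vision encoder (a sketch; assumes `onnxruntime`
# and `numpy` are installed). It compares the ONNX graph against the PyTorch
# wrapper on one dummy batch.
import numpy as np
import onnxruntime as ort

_pixels = torch.randn(1, 3, image_size, image_size)
with torch.no_grad():
    _ref = vision_model(_pixels).numpy()
_sess = ort.InferenceSession(VISION_MODEL_OUTPUT_PATH, providers=["CPUExecutionProvider"])
_onnx_out = _sess.run(None, {"pixel_values": _pixels.numpy()})[0]
print("vision encoder max abs diff:", float(np.abs(_ref - _onnx_out).max()))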
# Export embedding model (optional)
embed_layer = model.language_model.model.embed_tokens
input_ids = torch.randint(0, embed_layer.num_embeddings, (batch_size, sequence_length))
EMBED_MODEL_OUTPUT_PATH = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, EMBED_MODEL_NAME)
torch.onnx.export(
    embed_layer,
    args=(input_ids,),
    f=EMBED_MODEL_OUTPUT_PATH,
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=['input_ids'],
    output_names=['inputs_embeds'],
    dynamic_axes={
        'input_ids': {0: 'batch_size', 1: 'sequence_length'},
        'inputs_embeds': {0: 'batch_size', 1: 'sequence_length'},
    },
)
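# Optional graph simplification with the already-imported `onnxslim`. A minimal
# sketch assuming the onnxslim.slim(input_model, output_model) Python API; graphs
# larger than the 2 GB protobuf limit may additionally need external-data handling.
for _name in [TEXT_MODEL_NAME, VISION_MODEL_NAME, EMBED_MODEL_NAME]:
    _path = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, _name)
    onnxslim.slim(_path, _path)  # simplify in place before uploading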
# Post-processing and Upload to Hugging Face
# Create a repo on Hugging Face (if it doesn't exist)
repo_name = f"paligemma-onnx-{model_id.split('/')[-1]}"
api = HfApi()
repo_url = api.create_repo(repo_id=repo_name, exist_ok=True)  # HfApi.create_repo expects repo_id
# Initialize a local clone for pushing to the Hub
repo = Repository(local_dir=repo_name, clone_from=repo_url)
# Move models to repo
os.makedirs(os.path.join(repo_name, "onnx"), exist_ok=True)
for model_file in [TEXT_MODEL_NAME, VISION_MODEL_NAME, EMBED_MODEL_NAME]:
    model_file_path = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, model_file)
    os.rename(model_file_path, os.path.join(repo_name, "onnx", model_file))
# Save the config and tokenizer locally, then move them into the repo
model.config.save_pretrained(OUTPUT_FOLDER)   # writes config.json
processor.save_pretrained(OUTPUT_FOLDER)      # writes tokenizer.json and preprocessor files
os.makedirs(os.path.join(repo_name, "config"), exist_ok=True)
os.rename(os.path.join(OUTPUT_FOLDER, "config.json"), os.path.join(repo_name, "config", "config.json"))
os.rename(os.path.join(OUTPUT_FOLDER, "tokenizer.json"), os.path.join(repo_name, "tokenizer.json"))
# Push to Hugging Face Model Hub
repo.push_to_hub(commit_message="Upload ONNX models")
# Clean up temporary files
import shutil
shutil.rmtree(TEMP_MODEL_OUTPUT_FOLDER)