import os
import json
import shutil

import torch
from torch import nn

import onnx
import onnxslim
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from huggingface_hub import HfApi, HfFolder, Repository

model_id = "google/paligemma2-3b-pt-448"

OUTPUT_FOLDER = os.path.join("output", model_id)
TEXT_MODEL_NAME = "decoder_model_merged.onnx"
VISION_MODEL_NAME = "vision_encoder.onnx"
EMBED_MODEL_NAME = "embed_tokens.onnx"
TEMP_MODEL_OUTPUT_FOLDER = os.path.join(OUTPUT_FOLDER, "temp")
FINAL_MODEL_OUTPUT_FOLDER = os.path.join(OUTPUT_FOLDER, "onnx")

# torch.onnx.export does not create missing directories, so create them up front.
os.makedirs(TEMP_MODEL_OUTPUT_FOLDER, exist_ok=True)
os.makedirs(FINAL_MODEL_OUTPUT_FOLDER, exist_ok=True)

# Load the model in eval mode along with its processor.
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval()
processor = AutoProcessor.from_pretrained(model_id)
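
# Quick sanity check (a sketch, not required for the export): run the processor
# once on a blank image to confirm the resolution this checkpoint expects
# (448x448 for paligemma2-3b-pt-448) and the prompt shape it produces.
from PIL import Image

sanity_inputs = processor(text="caption en", images=Image.new("RGB", (448, 448)), return_tensors="pt")
print("pixel_values:", tuple(sanity_inputs["pixel_values"].shape))
print("input_ids:", tuple(sanity_inputs["input_ids"].shape))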

class VisionEncoder(nn.Module):
    """Vision tower + projector, producing image features scaled for the text model."""

    def __init__(self, paligemma_model):
        super().__init__()
        self.config = paligemma_model.config
        self.vision_tower = paligemma_model.vision_tower
        self.multi_modal_projector = paligemma_model.multi_modal_projector

    def forward(self, pixel_values: torch.FloatTensor):
        image_outputs = self.vision_tower(pixel_values)
        selected_image_feature = image_outputs.last_hidden_state
        image_features = self.multi_modal_projector(selected_image_feature)
        # PaliGemma scales the projected image features by 1/sqrt(text hidden size).
        image_features = image_features / (self.config.text_config.hidden_size**0.5)
        return image_features


vision_model = VisionEncoder(model)
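
# The decoder is exported through a small wrapper so that the positional ONNX
# inputs (inputs_embeds, position_ids, flat per-layer KV tensors) map cleanly
# onto the language model's keyword arguments. This wrapper and its cache
# re-packing are a sketch: it assumes the installed transformers version lets
# the Gemma2 decoder consume a DynamicCache built from (key, value) pairs, so
# adjust it if your version expects a different cache class.
from transformers import DynamicCache


class DecoderWrapper(nn.Module):
    """Text decoder with flattened KV-cache inputs and outputs."""

    def __init__(self, language_model):
        super().__init__()
        self.language_model = language_model

    def forward(self, inputs_embeds, position_ids, *flat_past_key_values):
        # Re-pack the flat tensors into per-layer (key, value) pairs.
        legacy_cache = tuple(
            (flat_past_key_values[i], flat_past_key_values[i + 1])
            for i in range(0, len(flat_past_key_values), 2)
        )
        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            past_key_values=DynamicCache.from_legacy_cache(legacy_cache),
            use_cache=True,
        )
        # Flatten the updated cache back into per-layer key/value tensors.
        present = outputs.past_key_values.to_legacy_cache()
        flat_present = tuple(tensor for layer in present for tensor in layer)
        return (outputs.logits,) + flat_present


decoder_model = DecoderWrapper(model.language_model)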

# Dummy inputs for tracing the decoder: embeddings for the new tokens, their
# positions (continuing after the cached tokens), and one
# (batch, num_kv_heads, past_len, head_dim) tensor per layer for keys and values.
batch_size = 2
sequence_length = 32
past_sequence_length = 8
text_config = model.config.text_config

inputs_embeds = torch.randn((batch_size, sequence_length, text_config.hidden_size))
position_ids = torch.arange(
    past_sequence_length, past_sequence_length + sequence_length, dtype=torch.int64
).expand(batch_size, sequence_length)
dummy_past_key_values_kwargs = {
    f"past_key_values.{i}.{key}": torch.zeros(
        batch_size, text_config.num_key_value_heads, past_sequence_length, text_config.head_dim, dtype=torch.float32
    )
    for i in range(text_config.num_hidden_layers)
    for key in ["key", "value"]
}

text_inputs_positional = tuple([inputs_embeds, position_ids] + list(dummy_past_key_values_kwargs.values()))

# Export the text decoder with KV-cache inputs and outputs.
TEXT_MODEL_OUTPUT_PATH = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, TEXT_MODEL_NAME)
torch.onnx.export(
    decoder_model,
    args=text_inputs_positional,
    f=TEXT_MODEL_OUTPUT_PATH,
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    # Input names must follow the order of the positional args above.
    input_names=["inputs_embeds", "position_ids"] + list(dummy_past_key_values_kwargs.keys()),
    output_names=["logits"] + [
        f"present.{i}.{key}"
        for i in range(model.config.text_config.num_hidden_layers)
        for key in ["key", "value"]
    ],
    dynamic_axes={
        "inputs_embeds": {0: "batch_size", 1: "sequence_length"},
        "position_ids": {0: "batch_size", 1: "sequence_length"},
        **{
            f"past_key_values.{i}.{key}": {0: "batch_size", 2: "past_sequence_length"}
            for i in range(model.config.text_config.num_hidden_layers)
            for key in ["key", "value"]
        },
        "logits": {0: "batch_size", 1: "sequence_length"},
        **{
            f"present.{i}.{key}": {0: "batch_size", 2: "total_sequence_length"}
            for i in range(model.config.text_config.num_hidden_layers)
            for key in ["key", "value"]
        },
    },
)

# Export the vision encoder. The dummy input must match the checkpoint's native
# resolution (448x448 for paligemma2-3b-pt-448), not a hard-coded 224.
image_size = model.config.vision_config.image_size
VISION_MODEL_OUTPUT_PATH = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, VISION_MODEL_NAME)
torch.onnx.export(
    vision_model,
    args=(torch.randn(batch_size, 3, image_size, image_size),),
    f=VISION_MODEL_OUTPUT_PATH,
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=["pixel_values"],
    output_names=["image_features"],
    dynamic_axes={"pixel_values": {0: "batch_size"}, "image_features": {0: "batch_size"}},
)
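
# Optional parity check (a sketch; assumes onnxruntime is installed): run the
# exported vision encoder once and compare it against the PyTorch wrapper.
import onnxruntime as ort

check_pixels = torch.randn(1, 3, image_size, image_size)
ort_session = ort.InferenceSession(VISION_MODEL_OUTPUT_PATH, providers=["CPUExecutionProvider"])
(ort_features,) = ort_session.run(None, {"pixel_values": check_pixels.numpy()})
with torch.no_grad():
    torch_features = vision_model(check_pixels)
print("max abs diff:", float((torch.from_numpy(ort_features) - torch_features).abs().max()))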

# Export the token-embedding layer so input ids can be embedded outside the
# decoder and merged with the image features before each decoder call.
embed_layer = model.language_model.model.embed_tokens
input_ids = torch.randint(0, embed_layer.num_embeddings, (batch_size, sequence_length))
EMBED_MODEL_OUTPUT_PATH = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, EMBED_MODEL_NAME)
torch.onnx.export(
    embed_layer,
    args=(input_ids,),
    f=EMBED_MODEL_OUTPUT_PATH,
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=["input_ids"],
    output_names=["inputs_embeds"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "inputs_embeds": {0: "batch_size", 1: "sequence_length"},
    },
)
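
# Optionally slim and validate the exported graphs before uploading. This is a
# sketch of how the onnx/onnxslim imports above could be used; the large decoder
# graph may exceed onnx's 2 GB protobuf limit, in which case extra external-data
# handling is needed (or skip the slim step for that file).
for model_file in [VISION_MODEL_NAME, EMBED_MODEL_NAME]:
    model_path = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, model_file)
    slimmed = onnxslim.slim(model_path)
    onnx.save(slimmed, model_path)
    onnx.checker.check_model(model_path)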

# Create (or reuse) a Hub repo and clone it locally.
repo_name = f"paligemma-onnx-{model_id.split('/')[-1]}"
api = HfApi()
repo_url = api.create_repo(repo_id=repo_name, exist_ok=True)

repo = Repository(local_dir=repo_name, clone_from=repo_url)

# Move the exported ONNX files into the repo's onnx/ subfolder.
os.makedirs(os.path.join(repo_name, "onnx"), exist_ok=True)
for model_file in [TEXT_MODEL_NAME, VISION_MODEL_NAME, EMBED_MODEL_NAME]:
    model_file_path = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, model_file)
    os.rename(model_file_path, os.path.join(repo_name, "onnx", model_file))

# Write the config and tokenizer files before moving them into the repo;
# nothing earlier in the script has saved them to OUTPUT_FOLDER.
model.config.save_pretrained(OUTPUT_FOLDER)
processor.save_pretrained(OUTPUT_FOLDER)

os.makedirs(os.path.join(repo_name, "config"), exist_ok=True)
os.rename(os.path.join(OUTPUT_FOLDER, "config.json"), os.path.join(repo_name, "config", "config.json"))
os.rename(os.path.join(OUTPUT_FOLDER, "tokenizer.json"), os.path.join(repo_name, "tokenizer.json"))

# Push everything to the Hub, then remove the temporary export folder.
repo.push_to_hub(commit_message="Upload ONNX models")

shutil.rmtree(TEMP_MODEL_OUTPUT_FOLDER)