import os
import shutil

import torch
import onnx
import onnxslim
from torch import nn
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from huggingface_hub import HfApi, Repository

# Constants
model_id = "google/paligemma2-3b-pt-448"
OUTPUT_FOLDER = os.path.join("output", model_id)
TEXT_MODEL_NAME = "decoder_model_merged.onnx"
VISION_MODEL_NAME = "vision_encoder.onnx"
EMBED_MODEL_NAME = "embed_tokens.onnx"
TEMP_MODEL_OUTPUT_FOLDER = os.path.join(OUTPUT_FOLDER, "temp")
FINAL_MODEL_OUTPUT_FOLDER = os.path.join(OUTPUT_FOLDER, "onnx")
os.makedirs(TEMP_MODEL_OUTPUT_FOLDER, exist_ok=True)

# Model and Processor Loading
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval()
processor = AutoProcessor.from_pretrained(model_id)
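
# For reference, the processor produces exactly the inputs the exported graphs
# will expect: pixel_values of shape (batch, 3, 448, 448) plus input_ids whose
# <image> placeholder tokens are later replaced by projected image features.
# A minimal sketch (assuming PIL is installed):
#   from PIL import Image
#   inputs = processor(text="<image>caption en", images=Image.new("RGB", (448, 448)), return_tensors="pt")
#   print(inputs["pixel_values"].shape, inputs["input_ids"].shape)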

# Define Vision Encoder Model
class VisionEncoder(nn.Module):
    def __init__(self, paligemma_model):
        super().__init__()
        self.config = paligemma_model.config
        self.vision_tower = paligemma_model.vision_tower
        self.multi_modal_projector = paligemma_model.multi_modal_projector

    def forward(self, pixel_values: torch.FloatTensor):
        image_outputs = self.vision_tower(pixel_values)
        selected_image_feature = image_outputs.last_hidden_state
        image_features = self.multi_modal_projector(selected_image_feature)
        image_features = image_features / (self.config.text_config.hidden_size**0.5)
        return image_features

vision_model = VisionEncoder(model)
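
# Quick sanity check (optional): the projector's output dimension must match
# the text model's hidden size, since these image features are consumed in
# place of token embeddings.
with torch.no_grad():
    _check_pixels = torch.randn(1, 3, model.config.vision_config.image_size, model.config.vision_config.image_size)
    assert vision_model(_check_pixels).shape[-1] == model.config.text_config.hidden_size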

# Export Models (ONNX)

# Dummy text inputs (needed for exporting the text model)
batch_size = 2
sequence_length = 32
past_length = 8
num_layers = model.config.text_config.num_hidden_layers
inputs_embeds = torch.randn((batch_size, sequence_length, model.config.text_config.hidden_size))
# Positions continue from the dummy past sequence
position_ids = torch.arange(past_length, past_length + sequence_length, dtype=torch.int64).expand(batch_size, sequence_length)
dummy_past_key_values = {
    f"past_key_values.{i}.{key}": torch.zeros(batch_size, model.config.text_config.num_key_value_heads, past_length, model.config.text_config.head_dim, dtype=torch.float32)
    for i in range(num_layers)
    for key in ["key", "value"]
}

# Wrapper that pins down a stable positional signature for the export.
# Exporting the full model with positional dummy inputs would bind them to the
# wrong forward() parameters (the first one is input_ids, not inputs_embeds).
class TextDecoder(nn.Module):
    def __init__(self, paligemma_model):
        super().__init__()
        self.model = paligemma_model

    def forward(self, inputs_embeds, position_ids, past_key_values):
        # past_key_values: one (key, value) pair per layer. This assumes a
        # transformers version that still accepts the legacy cache format.
        outputs = self.model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        present = outputs.past_key_values
        if not isinstance(present, (list, tuple)):  # Cache object on newer transformers
            present = present.to_legacy_cache()
        return outputs.logits, present

text_decoder = TextDecoder(model)
past_key_values_pairs = [
    (dummy_past_key_values[f"past_key_values.{i}.key"], dummy_past_key_values[f"past_key_values.{i}.value"])
    for i in range(num_layers)
]
text_inputs_positional = (inputs_embeds, position_ids, past_key_values_pairs)

# Export text model
TEXT_MODEL_OUTPUT_PATH = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, TEXT_MODEL_NAME)
torch.onnx.export(
    text_decoder,
    args=text_inputs_positional,
    f=TEXT_MODEL_OUTPUT_PATH,
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    # Names are assigned to the flattened graph inputs in order: embeddings
    # and positions first, then the per-layer key/value tensors
    input_names=['inputs_embeds', 'position_ids'] + list(dummy_past_key_values.keys()),
    output_names=["logits"] + [f"present.{i}.{key}" for i in range(num_layers) for key in ["key", "value"]],
    dynamic_axes={
        "inputs_embeds": {0: "batch_size", 1: "sequence_length"},
        "position_ids": {0: "batch_size", 1: "sequence_length"},
        **{
            f"past_key_values.{i}.{key}": {0: "batch_size", 2: "past_sequence_length"}
            for i in range(num_layers)
            for key in ["key", "value"]
        },
        "logits": {0: "batch_size", 1: "sequence_length"},
        **{
            f"present.{i}.{key}": {0: "batch_size", 2: "total_sequence_length"}
            for i in range(num_layers)
            for key in ["key", "value"]
        },
    }
)
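
# Optional: validate the exported decoder. Passing the file path (rather than
# a loaded ModelProto) lets onnx.checker also follow weights stored as
# external data, which recent torch versions write automatically when the
# graph exceeds the 2 GB protobuf limit (the case for a 3B fp32 export).
# This can be slow for multi-GB graphs.
onnx.checker.check_model(TEXT_MODEL_OUTPUT_PATH)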

# Export vision model
VISION_MODEL_OUTPUT_PATH = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, VISION_MODEL_NAME)
image_size = model.config.vision_config.image_size  # 448 for this checkpoint, not 224
torch.onnx.export(
    vision_model,
    args=(torch.randn(2, 3, image_size, image_size),),  # Dummy pixel values
    f=VISION_MODEL_OUTPUT_PATH,
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=['pixel_values'],
    output_names=['image_features'],
    dynamic_axes={'pixel_values': {0: 'batch_size'}, 'image_features': {0: 'batch_size'}}
)
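
# Optional parity check for the vision encoder, assuming onnxruntime and numpy
# are installed (neither is needed for the export itself): run the same dummy
# input through the ONNX graph and the PyTorch module and compare outputs.
import numpy as np
import onnxruntime as ort

_pixels = torch.randn(1, 3, image_size, image_size)
_session = ort.InferenceSession(VISION_MODEL_OUTPUT_PATH, providers=["CPUExecutionProvider"])
(_onnx_features,) = _session.run(None, {"pixel_values": _pixels.numpy()})
with torch.no_grad():
    _torch_features = vision_model(_pixels)
# fp32 export should agree closely; loosen the tolerance if this is too strict
np.testing.assert_allclose(_onnx_features, _torch_features.numpy(), atol=1e-4)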

# Export embedding model (optional)
embed_layer = model.language_model.model.embed_tokens
input_ids = torch.randint(0, embed_layer.num_embeddings, (batch_size, sequence_length))
EMBED_MODEL_OUTPUT_PATH = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, EMBED_MODEL_NAME)
torch.onnx.export(
    embed_layer,
    args=(input_ids,),
    f=EMBED_MODEL_OUTPUT_PATH,
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=['input_ids'],
    output_names=['inputs_embeds'],
    dynamic_axes={
        'input_ids': {0: 'batch_size', 1: 'sequence_length'},
        'inputs_embeds': {0: 'batch_size', 1: 'sequence_length'}
    },
)
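
# The imports above pull in onnxslim; a minimal (optional) slimming pass over
# the smaller exported graphs might look like this. onnxslim.slim accepts a
# path or ModelProto and returns an optimized ModelProto, written back in
# place here. The decoder is skipped: graphs above 2 GB would additionally
# need onnx.save(..., save_as_external_data=True).
for _name in [VISION_MODEL_NAME, EMBED_MODEL_NAME]:
    _path = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, _name)
    onnx.save(onnxslim.slim(_path), _path)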

# Post-processing and Upload to Hugging Face

# Create a repo on Hugging Face (if it doesn't exist)
repo_name = f"paligemma-onnx-{model_id.split('/')[-1]}"
api = HfApi()
repo_url = api.create_repo(repo_id=repo_name, exist_ok=True)

# Initialize a local clone of the repo for pushing to Hugging Face
repo = Repository(local_dir=repo_name, clone_from=repo_url)

# Move models to the repo, including any external-data files that
# torch.onnx.export writes next to graphs above the 2 GB protobuf limit
os.makedirs(os.path.join(repo_name, "onnx"), exist_ok=True)
for file_name in os.listdir(TEMP_MODEL_OUTPUT_FOLDER):
    os.rename(os.path.join(TEMP_MODEL_OUTPUT_FOLDER, file_name), os.path.join(repo_name, "onnx", file_name))

# Save the config and processor files (config.json, tokenizer.json, etc.)
# directly at the repo root, where downstream libraries expect them
model.config.save_pretrained(repo_name)
processor.save_pretrained(repo_name)

# Push to Hugging Face Model Hub
repo.push_to_hub(commit_message="Upload ONNX models")
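
# Note: the git-based Repository helper is deprecated in recent huggingface_hub
# releases. An HTTP-based alternative (a sketch, assuming a recent version)
# would be:
#   api.upload_folder(repo_id=repo_url.repo_id, folder_path=repo_name,
#                     commit_message="Upload ONNX models")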

# Clean up temporary files
shutil.rmtree(TEMP_MODEL_OUTPUT_FOLDER)