Let's export the trained model in the ONNX and safetensors formats so it can be used by downstream inference engines. First, we'll define a few configuration variables.

In [2]:
model_name = "lightgpt-small"
checkpoint_path = "./checkpoints/checkpoint.pt"
lora_path = None  # Set to e.g. "./checkpoints/lora_instruction.pt" to merge a LoRA adapter
exports_path = "./exports"

Then, we'll load the base model checkpoint into memory from disk.

In [3]:
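import torch

from model import GPT, GPTWithLoRA

# Load onto the CPU; weights_only=True avoids unpickling arbitrary Python objects
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

model = GPT(**checkpoint["model_args"])

# Compile before loading so the state-dict keys (prefixed with "_orig_mod.")
# match a checkpoint that was saved from a compiled model
model = torch.compile(model)

model.load_state_dict(checkpoint["model"])

print("Base checkpoint loaded successfully")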

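Because the model is wrapped by `torch.compile` before `load_state_dict`, its parameter names carry an `_orig_mod.` prefix (the attribute under which the compiled wrapper stores the original module). If you want to load the same checkpoint into an uncompiled model, the keys need adjusting. Below is a minimal sketch, assuming that prefix is the only mismatch; `strip_compile_prefix` is a hypothetical helper, not part of the project's `model` module.

```python
from typing import Dict, Mapping, TypeVar

V = TypeVar("V")

# Prefix that torch.compile's wrapper adds to every state-dict key
COMPILE_PREFIX = "_orig_mod."


def strip_compile_prefix(state_dict: Mapping[str, V]) -> Dict[str, V]:
    """Return a copy of state_dict with a leading '_orig_mod.' removed from each key.

    Keys without the prefix pass through unchanged, so the helper is safe to call
    on checkpoints saved from either compiled or uncompiled models.
    """
    cleaned: Dict[str, V] = {}
    for key, value in state_dict.items():
        if key.startswith(COMPILE_PREFIX):
            key = key[len(COMPILE_PREFIX):]
        cleaned[key] = value
    return cleaned


example = {"_orig_mod.token_embeddings.weight": 1, "_orig_mod.output.bias": 2}
print(list(strip_compile_prefix(example)))
```

With such a helper, `GPT(**checkpoint["model_args"]).load_state_dict(strip_compile_prefix(checkpoint["model"]))` would load the weights first, leaving compilation as an optional last step.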

Next, we'll optionally load a LoRA checkpoint and merge its adapter weights into the base model so they are baked into the exported model.

In [58]:
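if lora_path is not None:
    checkpoint = torch.load(lora_path, map_location="cpu", weights_only=True)

    # Wrap the base model with LoRA adapters using the hyperparameters saved at training time
    model = GPTWithLoRA(model, **checkpoint["lora_args"])

    model = torch.compile(model)

    # strict=False because the LoRA checkpoint only contains the adapter weights
    model.load_state_dict(checkpoint["lora"], strict=False)

    # Fold the low-rank updates into the base weights so the export has no adapter branches
    model.merge_lora_parameters()

    print("LoRA checkpoint loaded successfully")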
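Merging folds the low-rank update into the frozen weights, so the exported graph contains plain linear layers. Under the common LoRA convention, a layer computes y = W x + (alpha / r) B A x, so the merged weight is W' = W + (alpha / r) B A. The NumPy sketch below checks that identity numerically. The `alpha / rank` scaling is an assumption about how `GPTWithLoRA` parameterizes its adapters and may differ in this project; `merge_lora` is a hypothetical helper.

```python
import numpy as np


def merge_lora(W: np.ndarray, A: np.ndarray, B: np.ndarray, alpha: float) -> np.ndarray:
    """Fold a LoRA update into a dense weight: W' = W + (alpha / rank) * B @ A.

    Shapes: W is (out, in), A is (rank, in), B is (out, rank).
    """
    rank = A.shape[0]
    return W + (alpha / rank) * (B @ A)


rng = np.random.default_rng(0)
out_features, in_features, rank, alpha = 6, 4, 2, 8.0

W = rng.standard_normal((out_features, in_features))
A = rng.standard_normal((rank, in_features))
B = rng.standard_normal((out_features, rank))
x = rng.standard_normal(in_features)

# Unmerged forward pass: frozen path plus the scaled low-rank adapter path
y_adapter = W @ x + (alpha / rank) * (B @ (A @ x))

# Merged forward pass: a single matrix-vector product
W_merged = merge_lora(W, A, B, alpha)
y_merged = W_merged @ x

print(np.allclose(y_adapter, y_merged))  # True
```

Because the merge is exact up to floating-point rounding, inference after merging costs the same as the base model.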

Now, let's export the model in the safetensors format.

In [59]:
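from os import makedirs, path

from safetensors.torch import save_model

# Create the export directory if it doesn't exist yet
makedirs(exports_path, exist_ok=True)

safetensors_path = path.join(exports_path, f"{model_name}.safetensors")

# Unlike save_file, save_model handles shared tensors (e.g. tied embedding/output weights)
save_model(model, safetensors_path)

print(f"Model saved to {safetensors_path}")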

Model saved to ./exports/lightgpt-small-turbo.safetensors
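A safetensors file consists of an 8-byte little-endian unsigned integer giving the header length, then a JSON header that maps each tensor name to its `dtype`, `shape`, and `data_offsets`, then the raw tensor bytes. Reading only the header is a cheap way to confirm which tensor names were exported, for example whether they still carry the `_orig_mod.` prefix from `torch.compile`. The following is a minimal stdlib-only sketch; `read_safetensors_header` and the tiny hand-built file are illustrative, and in practice `safetensors.safe_open` exposes the same information.

```python
import json
import struct
import tempfile
from pathlib import Path


def read_safetensors_header(path):
    """Return the JSON header of a .safetensors file without reading tensor data."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # u64, little-endian
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)  # optional string-to-string metadata
    return header


def write_tiny_safetensors(path):
    """Write a one-tensor file by hand: a 2x2 float32 tensor of zeros."""
    data = struct.pack("<4f", 0.0, 0.0, 0.0, 0.0)
    header = {"lm_head.weight": {"dtype": "F32", "shape": [2, 2], "data_offsets": [0, len(data)]}}
    header_bytes = json.dumps(header).encode("utf-8")
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))
        f.write(header_bytes)
        f.write(data)


with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "demo.safetensors"
    write_tiny_safetensors(demo_path)
    header = read_safetensors_header(demo_path)

for name, info in header.items():
    print(name, info["dtype"], info["shape"])  # lm_head.weight F32 [2, 2]
```

Calling `read_safetensors_header(safetensors_path)` on the real export would list every saved parameter the same way.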


For the ONNX format, we'll use the TorchDynamo-based exporter, which traces the model into an FX graph using some example input and then translates that intermediate representation into ONNX.

In [86]:
from torch.onnx import export

# randint's upper bound is exclusive, so this samples from every valid token ID
example_input = torch.randint(0, model.vocabulary_size, (1, model.block_size))

model.eval()  # Turn off dropout and other train-time operations

# Keep a reference output for the comparison with ONNX Runtime later
with torch.no_grad():
    example_output, _ = model(example_input)

onnx_path = path.join(exports_path, f"{model_name}.onnx")

export(
    model,
    (example_input,),  # Positional arguments to the model's forward method
    onnx_path,
    input_names=["input_tokens"],
    output_names=["logits"],
    dynamo=True,
)

print(f"Model saved to {onnx_path}")

[torch.onnx] Obtain model graph for `OptimizedModule([...]` with `torch.export.export`...


W0108 18:27:01.430000 5473 torch/onnx/_internal/exporter/_registration.py:73] torchvision is not installed. Skipping torchvision::nms


[torch.onnx] Obtain model graph for `OptimizedModule([...]` with `torch.export.export`... ✅
[torch.onnx] Translate the graph into ONNX...


W0108 18:27:04.197000 5473 torch/onnx/_internal/exporter/_core.py:848] Skipping constant argument ConstantArgument(name='', value=None)


[torch.onnx] Translate the graph into ONNX... ✅
Model saved to ./exports/lightgpt-small-turbo.onnx


We can validate the structure of the exported model with the ONNX checker.

In [87]:
import onnx

onnx_model = onnx.load(onnx_path)

onnx.checker.check_model(onnx_model)

print("Looks OK")

Looks OK


Lastly, let's run the exported model with ONNX Runtime and check that its output matches PyTorch's within a small numerical tolerance.

In [None]:
import onnxruntime

from numpy.testing import assert_allclose

session = onnxruntime.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

# Feed the same tokens used for the PyTorch reference; the key must match the ONNX input name
onnx_input = {"input_tokens": example_input.numpy()}

# None requests every output; the first one is the logits
output = session.run(None, onnx_input)

onnx_output = output[0]
pytorch_output = example_output.detach().numpy()

# Allow for small floating-point differences between the two runtimes
assert_allclose(pytorch_output, onnx_output, rtol=1e-2, atol=1e-3)

print("Looking good")

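`assert_allclose` passes when every element satisfies |actual - desired| <= atol + rtol * |desired|. Small drift between PyTorch and ONNX Runtime is expected, so it can also help to report the largest absolute difference and to check that the greedy next-token prediction (the arg-max over the vocabulary at each position) agrees. Below is a NumPy sketch of that extra check; `compare_logits` is a hypothetical helper, and the arrays are synthetic stand-ins for `pytorch_output` and `onnx_output`.

```python
import numpy as np


def compare_logits(reference: np.ndarray, candidate: np.ndarray,
                   rtol: float = 1e-2, atol: float = 1e-3) -> dict:
    """Summarize how closely two logit arrays of shape (batch, seq, vocab) agree."""
    abs_diff = np.abs(reference - candidate)
    # Same elementwise criterion as numpy.testing.assert_allclose, with candidate as "desired"
    within_tol = abs_diff <= atol + rtol * np.abs(candidate)
    # Greedy decoding only depends on the top-scoring token at each position
    same_argmax = reference.argmax(axis=-1) == candidate.argmax(axis=-1)
    return {
        "max_abs_diff": float(abs_diff.max()),
        "fraction_within_tol": float(within_tol.mean()),
        "argmax_agreement": float(same_argmax.mean()),
    }


rng = np.random.default_rng(0)
reference = rng.standard_normal((1, 8, 32))
candidate = reference + 1e-5  # Simulate tiny runtime drift

report = compare_logits(reference, candidate)
print(report)
```

With the real arrays, `compare_logits(pytorch_output, onnx_output)` gives a quick read on whether a failing tolerance check actually changes the model's predictions.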