{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Let's export the trained model in ONNX and safetensors formats for compatibility with downstream inference engines. First, we'll define some variables." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "model_name = \"lightgpt-small\"\n", "checkpoint_path = \"./checkpoints/checkpoint.pt\"\n", "lora_path = None # \"./checkpoints/lora_instruction.pt\"\n", "exports_path = \"./exports\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then, we'll load the base model checkpoint into memory from disk." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "ename": "TypeError", "evalue": "GPT.__init__() missing 1 required positional argument: 'feed_forward_ratio'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[3], line 7\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmodel\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m GPT, GPTWithLoRA\n\u001b[1;32m 5\u001b[0m checkpoint \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mload(checkpoint_path, map_location\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m\"\u001b[39m, weights_only\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m----> 7\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[43mGPT\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mcheckpoint\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmodel_args\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 9\u001b[0m model \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcompile(model)\n\u001b[1;32m 11\u001b[0m model\u001b[38;5;241m.\u001b[39mload_state_dict(checkpoint[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n", "\u001b[0;31mTypeError\u001b[0m: GPT.__init__() missing 1 required positional argument: 'feed_forward_ratio'" ] } ], "source": [ "import torch\n", "\n", "from model import GPT, GPTWithLoRA\n", "\n", "checkpoint = torch.load(checkpoint_path, map_location=\"cpu\", weights_only=True)\n", "\n", "model = GPT(**checkpoint[\"model_args\"])\n", "\n", "model = torch.compile(model)\n", "\n", "model.load_state_dict(checkpoint[\"model\"])\n", "\n", "print(\"Base checkpoint loaded successfully\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we'll load any LoRA checkpoints we wish to incorporate into the exported model." ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [], "source": [ "if lora_path != None:\n", " checkpoint = torch.load(lora_path, map_location=\"cpu\", weights_only=True)\n", "\n", " model = GPTWithLoRA(model, **checkpoint[\"lora_args\"])\n", "\n", " model = torch.compile(model)\n", "\n", " model.load_state_dict(checkpoint[\"lora\"], strict=False)\n", "\n", " model.merge_lora_parameters()\n", "\n", " print(\"LoRA checkpoint loaded successfully\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, export the model in Safetensors format." 
] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model saved to ./exports/lightgpt-small-turbo.safetensors\n" ] } ], "source": [ "from os import path\n", "\n", "from safetensors.torch import save_model\n", "\n", "safetensors_path = path.join(exports_path, f\"{model_name}.safetensors\")\n", "\n", "save_model(model, safetensors_path)\n", "\n", "print(f\"Model saved to {safetensors_path}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For ONNX format we'll use TorchDynamo to trace the FX Graph of our model using some example data and then translate the intermediate representation to ONNX format." ] }, { "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[torch.onnx] Obtain model graph for `OptimizedModule([...]` with `torch.export.export`...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "W0108 18:27:01.430000 5473 torch/onnx/_internal/exporter/_registration.py:73] torchvision is not installed. Skipping torchvision::nms\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[torch.onnx] Obtain model graph for `OptimizedModule([...]` with `torch.export.export`... ✅\n", "[torch.onnx] Translate the graph into ONNX...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "W0108 18:27:04.197000 5473 torch/onnx/_internal/exporter/_core.py:848] Skipping constant argument ConstantArgument(name='', value=None)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[torch.onnx] Translate the graph into ONNX... ✅\n", "Model saved to ./exports/lightgpt-small-turbo.onnx\n" ] } ], "source": [ "from torch.onnx import export\n", "\n", "example_input = torch.randint(0, model.vocabulary_size - 1, (1, model.block_size))\n", "\n", "model.eval() # Turn off dropout and other train-time operations\n", "\n", "example_output, _ = model(example_input)\n", "\n", "onnx_path = path.join(exports_path, f\"{model_name}.onnx\")\n", "\n", "export(\n", " model,\n", " example_input,\n", " onnx_path,\n", " input_names=[\"input_tokens\", \"labels\"],\n", " output_names=[\"logits\"],\n", " dynamo=True,\n", ")\n", "\n", "print(f\"Model saved to {onnx_path}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can verify the ONNX model with the ONNX API." ] }, { "cell_type": "code", "execution_count": 87, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looks OK\n" ] } ], "source": [ "import onnx\n", "\n", "onnx_model = onnx.load(onnx_path)\n", "\n", "onnx.checker.check_model(onnx_model)\n", "\n", "print(\"Looks OK\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Lastly, let's compare the output of PyTorch with the ONNX runtime to see if they are the same." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "ename": "NameError", "evalue": "name 'onnx_path' is not defined", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[1], line 7\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtesting\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m assert_allclose\n\u001b[0;32m----> 7\u001b[0m session \u001b[38;5;241m=\u001b[39m onnxruntime\u001b[38;5;241m.\u001b[39mInferenceSession(\u001b[43monnx_path\u001b[49m, providers\u001b[38;5;241m=\u001b[39m[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCPUExecutionProvider\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m 9\u001b[0m onnx_input \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minput_tokens\u001b[39m\u001b[38;5;124m\"\u001b[39m: example_input\u001b[38;5;241m.\u001b[39mnumpy()}\n\u001b[1;32m 11\u001b[0m output \u001b[38;5;241m=\u001b[39m session\u001b[38;5;241m.\u001b[39mrun(\u001b[38;5;28;01mNone\u001b[39;00m, onnx_input)\n", "\u001b[0;31mNameError\u001b[0m: name 'onnx_path' is not defined" ] } ], "source": [ "import onnxruntime\n", "\n", "import numpy as np\n", "\n", "from numpy.testing import assert_allclose\n", "\n", "session = onnxruntime.InferenceSession(onnx_path, providers=[\"CPUExecutionProvider\"])\n", "\n", "onnx_input = {\"input_tokens\": example_input.numpy()}\n", "\n", "output = session.run(None, onnx_input)\n", "\n", "onnx_output = output[0]\n", "pytorch_output = np.array(example_output.detach())\n", "\n", "assert_allclose(pytorch_output, onnx_output, rtol=1e-2, atol=1e-03)\n", "\n", "print(\"Looking good\")" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 2 }