{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Let's export the trained model in ONNX and safetensors formats for compatibility with downstream inference engines. First, we'll define some variables." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "model_name = \"lightgpt-small\"\n", "checkpoint_path = \"./checkpoints/checkpoint.pt\"\n", "lora_path = None # \"./checkpoints/lora_instruction.pt\"\n", "exports_path = \"./exports\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then, we'll load the base model checkpoint into memory from disk." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Base checkpoint loaded successfully\n" ] } ], "source": [ "import torch\n", "\n", "from model import LightGPT\n", "\n", "checkpoint = torch.load(checkpoint_path, map_location=\"cpu\", weights_only=True)\n", "\n", "model = LightGPT(**checkpoint[\"model_args\"])\n", "\n", "model = torch.compile(model)\n", "\n", "model.load_state_dict(checkpoint[\"model\"])\n", "\n", "print(\"Base checkpoint loaded successfully\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we'll load any LoRA checkpoints we wish to incorporate into the exported model." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "from model import LightGPTInstruct\n", "\n", "if lora_path != None:\n", " checkpoint = torch.load(lora_path, map_location=\"cpu\", weights_only=True)\n", "\n", " model = LightGPTInstruct(model, **checkpoint[\"lora_args\"])\n", "\n", " model = torch.compile(model)\n", "\n", " model.load_state_dict(checkpoint[\"lora\"], strict=False)\n", "\n", " model.merge_lora_parameters()\n", "\n", " print(\"LoRA checkpoint loaded successfully\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, export the model in Safetensors format." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model saved to ./exports/lightgpt-small.safetensors\n" ] } ], "source": [ "from os import path\n", "\n", "from safetensors.torch import save_model\n", "\n", "safetensors_path = path.join(exports_path, f\"{model_name}.safetensors\")\n", "\n", "save_model(model, safetensors_path)\n", "\n", "print(f\"Model saved to {safetensors_path}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For ONNX format we'll use TorchDynamo to trace the FX Graph of our model using some example data and then translate the intermediate representation to ONNX format." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/andrew/Workspace/LightGPT/.venv/lib/python3.12/site-packages/torch/onnx/_internal/_exporter_legacy.py:116: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "For the ONNX export, we'll use TorchDynamo to trace the FX graph of the model over some example input and then translate the intermediate representation to ONNX." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/andrew/Workspace/LightGPT/.venv/lib/python3.12/site-packages/torch/onnx/_internal/_exporter_legacy.py:116: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.\n", "  warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Applied 73 of general pattern rewrite rules.\n", "Model saved to ./exports/lightgpt-small.onnx\n" ] } ], "source": [ "from model import ONNXModel\n", "\n", "from torch.onnx import dynamo_export, ExportOptions\n", "\n", "example_input = torch.randint(0, model.vocabulary_size - 1, (1, 1024))\n", "\n", "model = ONNXModel(model)  # Wrap the network in a nicer inferencing API\n", "\n", "model.eval()  # Turn off dropout and other train-time operations\n", "\n", "export_options = ExportOptions(\n", "    dynamic_shapes=True\n", ")  # Necessary for variable batch and sequence lengths\n", "\n", "onnx_model = dynamo_export(model, example_input, export_options=export_options)\n", "\n", "onnx_path = path.join(exports_path, f\"{model_name}.onnx\")\n", "\n", "onnx_model.save(onnx_path)\n", "\n", "print(f\"Model saved to {onnx_path}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Lastly, let's compare the output of PyTorch with ONNX Runtime to verify that the two agree within tolerance." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looks good!\n" ] } ], "source": [ "import onnxruntime\n", "\n", "from numpy.testing import assert_allclose\n", "\n", "pytorch_logits = model(example_input).detach().numpy()\n", "\n", "session = onnxruntime.InferenceSession(onnx_path, providers=[\"CPUExecutionProvider\"])\n", "\n", "onnx_input = {\"l_x_\": example_input.numpy()}  # \"l_x_\" is the input name assigned by the Dynamo exporter\n", "\n", "onnx_logits = session.run(None, onnx_input)[0]\n", "\n", "assert_allclose(pytorch_logits, onnx_logits, rtol=1e-2, atol=1e-3)\n", "\n", "print(\"Looks good!\")" ] },
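{ "cell_type": "markdown", "metadata": {}, "source": [ "Optionally, we can also validate the structure of the exported graph with the `onnx` package's checker. A minimal sketch, assuming the `onnx` package is installed alongside `onnxruntime`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import onnx\n", "\n", "# Raises if the graph fails ONNX's structural validation\n", "onnx.checker.check_model(onnx.load(onnx_path))\n", "\n", "print(\"ONNX graph is structurally valid\")" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 2 }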