{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "markdown", "source": [ "#Install dependencies" ], "metadata": { "id": "39AMoCOa1ckc" } }, { "cell_type": "code", "source": [ "!pip install ai-edge-litert" ], "metadata": { "id": "43tAeO0AZ7zp", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "7ce4d1ef-7d6b-4855-b73b-22482e3c693d" }, "execution_count": 1, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: ai-edge-litert in /usr/local/lib/python3.11/dist-packages (1.1.2)\n", "Requirement already satisfied: flatbuffers in /usr/local/lib/python3.11/dist-packages (from ai-edge-litert) (25.2.10)\n", "Requirement already satisfied: numpy>=1.23.2 in /usr/local/lib/python3.11/dist-packages (from ai-edge-litert) (1.26.4)\n" ] } ] }, { "cell_type": "code", "source": [ "from ai_edge_litert import interpreter as interpreter_lib\n", "from transformers import AutoTokenizer\n", "import numpy as np\n", "from collections.abc import Sequence\n", "import sys" ], "metadata": { "id": "i6PMkMVBPr1p" }, "execution_count": 2, "outputs": [] }, { "cell_type": "markdown", "source": [ "# Download model files" ], "metadata": { "id": "K5okZCTgYpUd" } }, { "cell_type": "code", "source": [ "from huggingface_hub import hf_hub_download\n", "\n", "model_path = hf_hub_download(repo_id=\"litert-community/DeepSeek-R1-Distill-Qwen-1.5B\", filename=\"deepseek_q8_seq128_ekv1280.tflite\")" ], "metadata": { "id": "3t47HAG2tvc3" }, "execution_count": 3, "outputs": [] }, { "cell_type": "markdown", "source": [ "# Create LiteRT interpreter and tokenizer" ], "metadata": { "id": "n5Xa4s6XhWqk" } }, { "cell_type": "code", "source": [ "interpreter = interpreter_lib.InterpreterWithCustomOps(\n", " custom_op_registerers=[\"pywrap_genai_ops.GenAIOpsRegisterer\"],\n", " model_path=model_path,\n", " num_threads=2,\n", " experimental_default_delegate_latest_features=True)\n", "tokenizer = AutoTokenizer.from_pretrained(\"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\")" ], "metadata": { "id": "Rvdn3EIZhaQn" }, "execution_count": 4, "outputs": [] }, { "cell_type": "markdown", "source": [ "# Create pipeline with LiteRT models" ], "metadata": { "id": "AM6rDABTXt2F" } }, { "cell_type": "code", "source": [ "\n", "\n", "class LiteRTLlmPipeline:\n", "\n", " def __init__(self, interpreter, tokenizer):\n", " \"\"\"Initializes the pipeline.\"\"\"\n", " self._interpreter = interpreter\n", " self._tokenizer = tokenizer\n", "\n", " self._prefill_runner = None\n", " self._decode_runner = self._interpreter.get_signature_runner(\"decode\")\n", "\n", "\n", " def _init_prefill_runner(self, num_input_tokens: int):\n", " \"\"\"Initializes all the variables related to the prefill runner.\n", "\n", " This method initializes the following variables:\n", " - self._prefill_runner: The prefill runner based on the input size.\n", " - self._max_seq_len: The maximum sequence length supported by the model.\n", " - self._max_kv_cache_seq_len: The maximum sequence length supported by the\n", " KV cache.\n", " - self._num_layers: The number of layers in the model.\n", "\n", " Args:\n", " num_input_tokens: The number of input tokens.\n", " \"\"\"\n", " if not self._interpreter:\n", " raise ValueError(\"Interpreter is not initialized.\")\n", "\n", " # Prefill runner related variables will be initialized in `predict_text` and\n", " # `compute_log_likelihood`.\n", " self._prefill_runner = self._get_prefill_runner(num_input_tokens)\n", " # input_token_shape has shape (batch, max_seq_len)\n", " input_token_shape = self._prefill_runner.get_input_details()[\"tokens\"][\n", " \"shape\"\n", " ]\n", " if len(input_token_shape) == 1:\n", " self._max_seq_len = input_token_shape[0]\n", " else:\n", " self._max_seq_len = input_token_shape[1]\n", "\n", " # kv cache input has shape [batch=1, seq_len, num_heads, dim].\n", " kv_cache_shape = self._prefill_runner.get_input_details()[\"kv_cache_k_0\"][\n", " \"shape\"\n", " ]\n", " self._max_kv_cache_seq_len = kv_cache_shape[1]\n", "\n", " # The two arguments excluded are `tokens` and `input_pos`. Dividing by 2\n", " # because each layer has key and value caches.\n", " self._num_layers = (\n", " len(self._prefill_runner.get_input_details().keys()) - 2\n", " ) // 2\n", "\n", "\n", " def _init_kv_cache(self) -> dict[str, np.ndarray]:\n", " if self._prefill_runner is None:\n", " raise ValueError(\"Prefill runner is not initialized.\")\n", " kv_cache = {}\n", " for i in range(self._num_layers):\n", " kv_cache[f\"kv_cache_k_{i}\"] = np.zeros(\n", " self._prefill_runner.get_input_details()[f\"kv_cache_k_{i}\"][\"shape\"],\n", " dtype=np.float32,\n", " )\n", " kv_cache[f\"kv_cache_v_{i}\"] = np.zeros(\n", " self._prefill_runner.get_input_details()[f\"kv_cache_v_{i}\"][\"shape\"],\n", " dtype=np.float32,\n", " )\n", " return kv_cache\n", "\n", " def _get_prefill_runner(self, num_input_tokens: int) :\n", " \"\"\"Gets the prefill runner with the best suitable input size.\n", "\n", " Args:\n", " num_input_tokens: The number of input tokens.\n", "\n", " Returns:\n", " The prefill runner with the smallest input size.\n", " \"\"\"\n", " best_signature = None\n", " delta = sys.maxsize\n", " max_prefill_len = -1\n", " for key in self._interpreter.get_signature_list().keys():\n", " if \"prefill\" not in key:\n", " continue\n", " input_pos = self._interpreter.get_signature_runner(key).get_input_details()[\n", " \"input_pos\"\n", " ]\n", " # input_pos[\"shape\"] has shape (max_seq_len, )\n", " seq_size = input_pos[\"shape\"][0]\n", " max_prefill_len = max(max_prefill_len, seq_size)\n", " if num_input_tokens <= seq_size and seq_size - num_input_tokens < delta:\n", " delta = seq_size - num_input_tokens\n", " best_signature = key\n", " if best_signature is None:\n", " raise ValueError(\n", " \"The largest prefill length supported is %d, but we have %d number of input tokens\"\n", " %(max_prefill_len, num_input_tokens)\n", " )\n", " return self._interpreter.get_signature_runner(best_signature)\n", "\n", " def _run_prefill(\n", " self, prefill_token_ids: Sequence[int],\n", " ) -> dict[str, np.ndarray]:\n", " \"\"\"Runs prefill and returns the kv cache.\n", "\n", " Args:\n", " prefill_token_ids: The token ids of the prefill input.\n", "\n", " Returns:\n", " The updated kv cache.\n", " \"\"\"\n", " if not self._prefill_runner:\n", " raise ValueError(\"Prefill runner is not initialized.\")\n", " prefill_token_length = len(prefill_token_ids)\n", " if prefill_token_length == 0:\n", " return self._init_kv_cache()\n", "\n", " # Prepare the input to be [1, max_seq_len].\n", " input_token_ids = [0] * self._max_seq_len\n", " input_token_ids[:prefill_token_length] = prefill_token_ids\n", " input_token_ids = np.asarray(input_token_ids, dtype=np.int32)\n", " input_token_ids = np.expand_dims(input_token_ids, axis=0)\n", "\n", " # Prepare the input position to be [max_seq_len].\n", " input_pos = [0] * self._max_seq_len\n", " input_pos[:prefill_token_length] = range(prefill_token_length)\n", " input_pos = np.asarray(input_pos, dtype=np.int32)\n", "\n", " # Initialize kv cache.\n", " prefill_inputs = self._init_kv_cache()\n", " prefill_inputs.update({\n", " \"tokens\": input_token_ids,\n", " \"input_pos\": input_pos,\n", " })\n", " prefill_outputs = self._prefill_runner(**prefill_inputs)\n", " if \"logits\" in prefill_outputs:\n", " # Prefill outputs includes logits and kv cache. We only output kv cache.\n", " prefill_outputs.pop(\"logits\")\n", "\n", " return prefill_outputs\n", "\n", " def _greedy_sampler(self, logits: np.ndarray) -> int:\n", " return int(np.argmax(logits))\n", "\n", "\n", " def _run_decode(\n", " self,\n", " start_pos: int,\n", " start_token_id: int,\n", " kv_cache: dict[str, np.ndarray],\n", " max_decode_steps: int,\n", " ) -> str:\n", " \"\"\"Runs decode and outputs the token ids from greedy sampler.\n", "\n", " Args:\n", " start_pos: The position of the first token of the decode input.\n", " start_token_id: The token id of the first token of the decode input.\n", " kv_cache: The kv cache from the prefill.\n", " max_decode_steps: The max decode steps.\n", "\n", " Returns:\n", " The token ids from the greedy sampler.\n", " \"\"\"\n", " next_pos = start_pos\n", " next_token = start_token_id\n", " decode_text = []\n", " decode_inputs = kv_cache\n", "\n", " for _ in range(max_decode_steps):\n", " decode_inputs.update({\n", " \"tokens\": np.array([[next_token]], dtype=np.int32),\n", " \"input_pos\": np.array([next_pos], dtype=np.int32),\n", " })\n", " decode_outputs = self._decode_runner(**decode_inputs)\n", " # Output logits has shape (batch=1, 1, vocab_size). We only take the first\n", " # element.\n", " logits = decode_outputs.pop(\"logits\")[0][0]\n", " next_token = self._greedy_sampler(logits)\n", " if next_token == self._tokenizer.eos_token_id:\n", " break\n", " decode_text.append(self._tokenizer.decode(next_token, skip_special_tokens=False))\n", " print(decode_text[-1], end='', flush=True)\n", " # Decode outputs includes logits and kv cache. We already poped out\n", " # logits, so the rest is kv cache. We pass the updated kv cache as input\n", " # to the next decode step.\n", " decode_inputs = decode_outputs\n", " next_pos += 1\n", "\n", " print() # print a new line at the end.\n", " return ''.join(decode_text)\n", "\n", " def generate(self, prompt: str, max_decode_steps: int | None = None) -> str:\n", " messages=[{ 'role': 'user', 'content': prompt}]\n", " token_ids = self._tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)\n", " # Initialize the prefill runner with the suitable input size.\n", " self._init_prefill_runner(len(token_ids))\n", "\n", " # Run prefill.\n", " # Prefill up to the seond to the last token of the prompt, because the last\n", " # token of the prompt will be used to bootstrap decode.\n", " prefill_token_length = len(token_ids) - 1\n", "\n", " print('Running prefill')\n", " kv_cache = self._run_prefill(token_ids[:prefill_token_length])\n", " # Run decode.\n", " print('Running decode')\n", " actual_max_decode_steps = self._max_kv_cache_seq_len - prefill_token_length - 1\n", " if max_decode_steps is not None:\n", " actual_max_decode_steps = min(actual_max_decode_steps, max_decode_steps)\n", " decode_text = self._run_decode(\n", " prefill_token_length,\n", " token_ids[prefill_token_length],\n", " kv_cache,\n", " actual_max_decode_steps,\n", " )\n", " return decode_text" ], "metadata": { "id": "UBSGrHrM4ANm" }, "execution_count": 7, "outputs": [] }, { "cell_type": "markdown", "source": [ "# Generate text from model" ], "metadata": { "id": "dASKx_JtYXwe" } }, { "cell_type": "code", "source": [ "# Disclaimer: Model performance demonstrated with the Python API in this notebook is not representative of performance on a local device.\n", "pipeline = LiteRTLlmPipeline(interpreter, tokenizer)" ], "metadata": { "id": "AZhlDQWg61AL" }, "execution_count": 8, "outputs": [] }, { "cell_type": "code", "source": [ "prompt = \"what is 8 mod 5\"\n", "output = pipeline.generate(prompt, max_decode_steps = None)" ], "metadata": { "id": "wT9BIiATkjzL" }, "execution_count": null, "outputs": [] } ] }