# Install dependencies

In [1]:
!pip install ai-edge-litert



In [2]:
from ai_edge_litert import interpreter as interpreter_lib
from transformers import AutoTokenizer
import numpy as np
from collections.abc import Sequence
import sys

# Download model files

In [3]:
from huggingface_hub import hf_hub_download

# Fetch the quantized LiteRT model file from the Hugging Face Hub; the value
# stored in `model_path` is what the interpreter is later constructed with.
model_path = hf_hub_download(
    repo_id="litert-community/DeepSeek-R1-Distill-Qwen-1.5B",
    filename="deepseek_q8_seq128_ekv1280.tflite",
)

# Create LiteRT interpreter and tokenizer

In [4]:
# Build the LiteRT interpreter for the downloaded model, registering the
# GenAI custom ops by name.
interpreter = interpreter_lib.InterpreterWithCustomOps(
    custom_op_registerers=["pywrap_genai_ops.GenAIOpsRegisterer"],
    model_path=model_path,
    num_threads=2,
    experimental_default_delegate_latest_features=True,
)

# Load the tokenizer published with the original (non-LiteRT) checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
)

# Create pipeline with LiteRT models

In [7]:


class LiteRTLlmPipeline:

 def __init__(self, interpreter, tokenizer):
 """Initializes the pipeline."""
 self._interpreter = interpreter
 self._tokenizer = tokenizer

 self._prefill_runner = None
 self._decode_runner = self._interpreter.get_signature_runner("decode")


 def _init_prefill_runner(self, num_input_tokens: int):
 """Initializes all the variables related to the prefill runner.

 This method initializes the following variables:
 - self._prefill_runner: The prefill runner based on the input size.
 - self._max_seq_len: The maximum sequence length supported by the model.
 - self._max_kv_cache_seq_len: The maximum sequence length supported by the
 KV cache.
 - self._num_layers: The number of layers in the model.

 Args:
 num_input_tokens: The number of input tokens.
 """
 if not self._interpreter:
 raise ValueError("Interpreter is not initialized.")

 # Prefill runner related variables will be initialized in `predict_text` and
 # `compute_log_likelihood`.
 self._prefill_runner = self._get_prefill_runner(num_input_tokens)
 # input_token_shape has shape (batch, max_seq_len)
 input_token_shape = self._prefill_runner.get_input_details()["tokens"][
 "shape"
 ]
 if len(input_token_shape) == 1:
 self._max_seq_len = input_token_shape[0]
 else:
 self._max_seq_len = input_token_shape[1]

 # kv cache input has shape [batch=1, seq_len, num_heads, dim].
 kv_cache_shape = self._prefill_runner.get_input_details()["kv_cache_k_0"][
 "shape"
 ]
 self._max_kv_cache_seq_len = kv_cache_shape[1]

 # The two arguments excluded are `tokens` and `input_pos`. Dividing by 2
 # because each layer has key and value caches.
 self._num_layers = (
 len(self._prefill_runner.get_input_details().keys()) - 2
 ) // 2


 def _init_kv_cache(self) -> dict[str, np.ndarray]:
 if self._prefill_runner is None:
 raise ValueError("Prefill runner is not initialized.")
 kv_cache = {}
 for i in range(self._num_layers):
 kv_cache[f"kv_cache_k_{i}"] = np.zeros(
 self._prefill_runner.get_input_details()[f"kv_cache_k_{i}"]["shape"],
 dtype=np.float32,
 )
 kv_cache[f"kv_cache_v_{i}"] = np.zeros(
 self._prefill_runner.get_input_details()[f"kv_cache_v_{i}"]["shape"],
 dtype=np.float32,
 )
 return kv_cache

 def _get_prefill_runner(self, num_input_tokens: int) :
 """Gets the prefill runner with the best suitable input size.

 Args:
 num_input_tokens: The number of input tokens.

 Returns:
 The prefill runner with the smallest input size.
 """
 best_signature = None
 delta = sys.maxsize
 max_prefill_len = -1
 for key in self._interpreter.get_signature_list().keys():
 if "prefill" not in key:
 continue
 input_pos = self._interpreter.get_signature_runner(key).get_input_details()[
 "input_pos"
 ]
 # input_pos["shape"] has shape (max_seq_len, )
 seq_size = input_pos["shape"][0]
 max_prefill_len = max(max_prefill_len, seq_size)
 if num_input_tokens <= seq_size and seq_size - num_input_tokens < delta:
 delta = seq_size - num_input_tokens
 best_signature = key
 if best_signature is None:
 raise ValueError(
 "The largest prefill length supported is %d, but we have %d number of input tokens"
 %(max_prefill_len, num_input_tokens)
 )
 return self._interpreter.get_signature_runner(best_signature)

 def _run_prefill(
 self, prefill_token_ids: Sequence[int],
 ) -> dict[str, np.ndarray]:
 """Runs prefill and returns the kv cache.

 Args:
 prefill_token_ids: The token ids of the prefill input.

 Returns:
 The updated kv cache.
 """
 if not self._prefill_runner:
 raise ValueError("Prefill runner is not initialized.")
 prefill_token_length = len(prefill_token_ids)
 if prefill_token_length == 0:
 return self._init_kv_cache()

 # Prepare the input to be [1, max_seq_len].
 input_token_ids = [0] * self._max_seq_len
 input_token_ids[:prefill_token_length] = prefill_token_ids
 input_token_ids = np.asarray(input_token_ids, dtype=np.int32)
 input_token_ids = np.expand_dims(input_token_ids, axis=0)

 # Prepare the input position to be [max_seq_len].
 input_pos = [0] * self._max_seq_len
 input_pos[:prefill_token_length] = range(prefill_token_length)
 input_pos = np.asarray(input_pos, dtype=np.int32)

 # Initialize kv cache.
 prefill_inputs = self._init_kv_cache()
 prefill_inputs.update({
 "tokens": input_token_ids,
 "input_pos": input_pos,
 })
 prefill_outputs = self._prefill_runner(**prefill_inputs)
 if "logits" in prefill_outputs:
 # Prefill outputs includes logits and kv cache. We only output kv cache.
 prefill_outputs.pop("logits")

 return prefill_outputs

 def _greedy_sampler(self, logits: np.ndarray) -> int:
 return int(np.argmax(logits))


 def _run_decode(
 self,
 start_pos: int,
 start_token_id: int,
 kv_cache: dict[str, np.ndarray],
 max_decode_steps: int,
 ) -> str:
 """Runs decode and outputs the token ids from greedy sampler.

 Args:
 start_pos: The position of the first token of the decode input.
 start_token_id: The token id of the first token of the decode input.
 kv_cache: The kv cache from the prefill.
 max_decode_steps: The max decode steps.

 Returns:
 The token ids from the greedy sampler.
 """
 next_pos = start_pos
 next_token = start_token_id
 decode_text = []
 decode_inputs = kv_cache

 for _ in range(max_decode_steps):
 decode_inputs.update({
 "tokens": np.array([[next_token]], dtype=np.int32),
 "input_pos": np.array([next_pos], dtype=np.int32),
 })
 decode_outputs = self._decode_runner(**decode_inputs)
 # Output logits has shape (batch=1, 1, vocab_size). We only take the first
 # element.
 logits = decode_outputs.pop("logits")[0][0]
 next_token = self._greedy_sampler(logits)
 if next_token == self._tokenizer.eos_token_id:
 break
 decode_text.append(self._tokenizer.decode(next_token, skip_special_tokens=False))
 print(decode_text[-1], end='', flush=True)
 # Decode outputs includes logits and kv cache. We already poped out
 # logits, so the rest is kv cache. We pass the updated kv cache as input
 # to the next decode step.
 decode_inputs = decode_outputs
 next_pos += 1

 print() # print a new line at the end.
 return ''.join(decode_text)

 def generate(self, prompt: str, max_decode_steps: int | None = None) -> str:
 messages=[{ 'role': 'user', 'content': prompt}]
 token_ids = self._tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)
 # Initialize the prefill runner with the suitable input size.
 self._init_prefill_runner(len(token_ids))

 # Run prefill.
 # Prefill up to the seond to the last token of the prompt, because the last
 # token of the prompt will be used to bootstrap decode.
 prefill_token_length = len(token_ids) - 1

 print('Running prefill')
 kv_cache = self._run_prefill(token_ids[:prefill_token_length])
 # Run decode.
 print('Running decode')
 actual_max_decode_steps = self._max_kv_cache_seq_len - prefill_token_length - 1
 if max_decode_steps is not None:
 actual_max_decode_steps = min(actual_max_decode_steps, max_decode_steps)
 decode_text = self._run_decode(
 prefill_token_length,
 token_ids[prefill_token_length],
 kv_cache,
 actual_max_decode_steps,
 )
 return decode_text

# Generate text from model

In [8]:
# Disclaimer: Model performance demonstrated with the Python API in this
# notebook is not representative of performance on a local device.
pipeline = LiteRTLlmPipeline(interpreter=interpreter, tokenizer=tokenizer)

In [None]:
# Ask a single question; passing max_decode_steps=None applies no extra cap
# on top of the limit the pipeline derives from the kv cache size.
prompt = "what is 8 mod 5"
output = pipeline.generate(prompt=prompt, max_decode_steps=None)