helloxm committed
Commit 063fb27 · verified · 1 parent: 145af7c

Upload the updated Colab file

Files changed (1)
  1. deepseek_tflite.ipynb +307 -0
deepseek_tflite.ipynb ADDED
@@ -0,0 +1,307 @@
+ {
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Install dependencies"
+ ],
+ "metadata": {
+ "id": "39AMoCOa1ckc"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install ai-edge-litert-nightly"
+ ],
+ "metadata": {
+ "id": "43tAeO0AZ7zp"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from ai_edge_litert import interpreter as interpreter_lib\n",
+ "from transformers import AutoTokenizer\n",
+ "import numpy as np\n",
+ "from collections.abc import Sequence\n",
+ "import sys"
+ ],
+ "metadata": {
+ "id": "i6PMkMVBPr1p"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Download model files"
+ ],
+ "metadata": {
+ "id": "K5okZCTgYpUd"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from huggingface_hub import hf_hub_download\n",
+ "\n",
+ "model_path = hf_hub_download(repo_id=\"litert-community/DeepSeek-R1-Distill-Qwen-1.5B\", filename=\"deepseek_q8_seq128_ekv1280.tflite\")"
+ ],
+ "metadata": {
+ "id": "3t47HAG2tvc3"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Create LiteRT interpreter and tokenizer"
+ ],
+ "metadata": {
+ "id": "n5Xa4s6XhWqk"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "interpreter = interpreter_lib.InterpreterWithCustomOps(\n",
+ "    custom_op_registerers=[\"pywrap_genai_ops.GenAIOpsRegisterer\"],\n",
+ "    model_path=model_path,\n",
+ "    num_threads=2,\n",
+ "    experimental_default_delegate_latest_features=True)\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\")"
+ ],
+ "metadata": {
+ "id": "Rvdn3EIZhaQn"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Create pipeline with LiteRT models"
+ ],
+ "metadata": {
+ "id": "AM6rDABTXt2F"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "\n",
+ "class LiteRTLlmPipeline:\n",
+ "\n",
+ "  def __init__(self, interpreter, tokenizer):\n",
+ "    \"\"\"Initializes the pipeline.\"\"\"\n",
+ "    self._interpreter = interpreter\n",
+ "    self._tokenizer = tokenizer\n",
+ "\n",
+ "    self._prefill_runner = None\n",
+ "    self._decode_runner = self._interpreter.get_signature_runner(\"decode\")\n",
+ "\n",
+ "\n",
+ "  def _init_prefill_runner(self, num_input_tokens: int):\n",
+ "    \"\"\"Initializes all the variables related to the prefill runner.\n",
+ "\n",
+ "    This method initializes the following variables:\n",
+ "      - self._prefill_runner: The prefill runner based on the input size.\n",
+ "      - self._max_seq_len: The maximum sequence length supported by the model.\n",
+ "      - self._max_kv_cache_seq_len: The maximum sequence length supported by the\n",
+ "        KV cache.\n",
+ "      - self._num_layers: The number of layers in the model.\n",
+ "\n",
+ "    Args:\n",
+ "      num_input_tokens: The number of input tokens.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    self._prefill_runner = self._get_prefill_runner(num_input_tokens)\n",
+ "    # input_token_shape has shape (batch, max_seq_len)\n",
+ "    input_token_shape = self._prefill_runner.get_input_details()[\"tokens\"][\n",
+ "        \"shape\"\n",
+ "    ]\n",
+ "    if len(input_token_shape) == 1:\n",
+ "      self._max_seq_len = input_token_shape[0]\n",
+ "    else:\n",
+ "      self._max_seq_len = input_token_shape[1]\n",
+ "\n",
+ "    # kv cache input has shape [batch=1, seq_len, num_heads, dim].\n",
+ "    kv_cache_shape = self._prefill_runner.get_input_details()[\"kv_cache_k_0\"][\n",
+ "        \"shape\"\n",
+ "    ]\n",
+ "    self._max_kv_cache_seq_len = kv_cache_shape[1]\n",
+ "\n",
+ "    # The two arguments excluded are `tokens` and `input_pos`. Dividing by 2\n",
+ "    # because each layer has key and value caches.\n",
+ "    self._num_layers = (\n",
+ "        len(self._prefill_runner.get_input_details().keys()) - 2\n",
+ "    ) // 2\n",
+ "\n",
+ "\n",
+ "  def _init_kv_cache(self) -> dict[str, np.ndarray]:\n",
+ "    if self._prefill_runner is None:\n",
+ "      raise ValueError(\"Prefill runner is not initialized.\")\n",
+ "    kv_cache = {}\n",
+ "    for i in range(self._num_layers):\n",
+ "      kv_cache[f\"kv_cache_k_{i}\"] = np.zeros(\n",
+ "          self._prefill_runner.get_input_details()[f\"kv_cache_k_{i}\"][\"shape\"],\n",
+ "          dtype=np.float32,\n",
+ "      )\n",
+ "      kv_cache[f\"kv_cache_v_{i}\"] = np.zeros(\n",
+ "          self._prefill_runner.get_input_details()[f\"kv_cache_v_{i}\"][\"shape\"],\n",
+ "          dtype=np.float32,\n",
+ "      )\n",
+ "    return kv_cache\n",
+ "\n",
+ "  def _get_prefill_runner(self, num_input_tokens: int):\n",
+ "    \"\"\"Gets the prefill runner with the best suitable input size.\n",
+ "\n",
+ "    Args:\n",
+ "      num_input_tokens: The number of input tokens.\n",
+ "\n",
+ "    Returns:\n",
+ "      The prefill runner with the smallest input size.\n",
+ "    \"\"\"\n",
+ "    best_signature = None\n",
+ "    delta = sys.maxsize\n",
+ "    max_prefill_len = -1\n",
+ "    for key in self._interpreter.get_signature_list().keys():\n",
+ "      if \"prefill\" not in key:\n",
+ "        continue\n",
+ "      input_pos = self._interpreter.get_signature_runner(key).get_input_details()[\n",
+ "          \"input_pos\"\n",
+ "      ]\n",
+ "      # input_pos[\"shape\"] has shape (max_seq_len, )\n",
+ "      seq_size = input_pos[\"shape\"][0]\n",
+ "      max_prefill_len = max(max_prefill_len, seq_size)\n",
+ "      if num_input_tokens <= seq_size and seq_size - num_input_tokens < delta:\n",
+ "        delta = seq_size - num_input_tokens\n",
+ "        best_signature = key\n",
+ "    if best_signature is None:\n",
+ "      raise ValueError(\n",
+ "          \"The largest prefill length supported is %d, but %d input tokens were given\"\n",
+ "          %(max_prefill_len, num_input_tokens)\n",
+ "      )\n",
+ "    return self._interpreter.get_signature_runner(best_signature)\n",
+ "\n",
+ "  def _greedy_sampler(self, logits: np.ndarray) -> int:\n",
+ "    return int(np.argmax(logits))\n",
+ "\n",
+ "  def generate(self, prompt: str, max_decode_steps: int | None = None) -> str:\n",
+ "    messages=[{ 'role': 'user', 'content': prompt}]\n",
+ "    token_ids = self._tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)\n",
+ "    # Initialize the prefill runner with the suitable input size.\n",
+ "    self._init_prefill_runner(len(token_ids))\n",
+ "\n",
+ "    actual_max_decode_steps = self._max_kv_cache_seq_len - len(token_ids)\n",
+ "    if max_decode_steps is not None:\n",
+ "      actual_max_decode_steps = min(actual_max_decode_steps, max_decode_steps)\n",
+ "\n",
+ "    input_token_ids = [0] * self._max_seq_len\n",
+ "    input_token_ids[:len(token_ids)] = token_ids\n",
+ "    model_inputs = self._init_kv_cache()\n",
+ "    model_inputs.update({\n",
+ "        \"tokens\": np.asarray([input_token_ids], dtype=np.int32),\n",
+ "        \"input_pos\": np.arange(self._max_seq_len, dtype=np.int32),\n",
+ "    })\n",
+ "    decode_text = []\n",
+ "    decode_step = 0\n",
+ "    print('Running prefill')\n",
+ "    for step in range(actual_max_decode_steps+1):\n",
+ "      signature_runner = self._prefill_runner if step == 0 else self._decode_runner\n",
+ "      model_outputs = signature_runner(**model_inputs)\n",
+ "      # At prefill stage, output logits has shape (batch=1, seq_size, vocab_size)\n",
+ "      # At decode stage, output logits has shape (batch=1, 1, vocab_size).\n",
+ "      selected_logit = len(token_ids)-1 if step == 0 else 0\n",
+ "      logits = model_outputs.pop(\"logits\")[0][selected_logit]\n",
+ "\n",
+ "      if step == 0:\n",
+ "        print('Running decode')\n",
+ "\n",
+ "      # Decode text output.\n",
+ "      next_token = self._greedy_sampler(logits)\n",
+ "      if next_token == self._tokenizer.eos_token_id:\n",
+ "        break\n",
+ "      decode_text.append(self._tokenizer.decode(next_token, skip_special_tokens=False))\n",
+ "      print(decode_text[-1], end='', flush=True)\n",
+ "      # The rest of the outputs is the updated kv cache.\n",
+ "      model_inputs = model_outputs\n",
+ "      model_inputs.update({\n",
+ "          \"tokens\": np.array([[next_token]], dtype=np.int32),\n",
+ "          \"input_pos\": np.array([decode_step + len(token_ids)], dtype=np.int32),})\n",
+ "      decode_step += 1\n",
+ "\n",
+ "\n",
+ "\n",
+ "    print() # print a new line at the end.\n",
+ "    return ''.join(decode_text)\n"
+ ],
+ "metadata": {
+ "id": "UBSGrHrM4ANm"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Generate text from model"
+ ],
+ "metadata": {
+ "id": "dASKx_JtYXwe"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Disclaimer: Model performance demonstrated with the Python API in this notebook is not representative of performance on a local device.\n",
+ "pipeline = LiteRTLlmPipeline(interpreter, tokenizer)"
+ ],
+ "metadata": {
+ "id": "AZhlDQWg61AL"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "prompt = \"what is 8 mod 5\"\n",
+ "output = pipeline.generate(prompt, max_decode_steps = None)"
+ ],
+ "metadata": {
+ "id": "wT9BIiATkjzL"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [],
+ "metadata": {
+ "id": "GNzDBxDFEuAJ"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+ }
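
For readers who want to run the notebook end to end outside Colab, the sketch below strings the cells together as a single Python script. It only uses packages and calls that already appear in the diff above, and it assumes the LiteRTLlmPipeline class is defined as in the "Create pipeline with LiteRT models" cell; the max_decode_steps cap of 128 is illustrative, not a value taken from the notebook.

# Minimal end-to-end sketch of the notebook flow. Assumes the packages from the
# "!pip install" cell (plus transformers and huggingface_hub) are installed and
# that LiteRTLlmPipeline is defined exactly as in the notebook cell above.
from ai_edge_litert import interpreter as interpreter_lib
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

# Download the 8-bit quantized TFLite model from the Hugging Face Hub.
model_path = hf_hub_download(
    repo_id="litert-community/DeepSeek-R1-Distill-Qwen-1.5B",
    filename="deepseek_q8_seq128_ekv1280.tflite",
)

# Create the LiteRT interpreter with the GenAI custom ops, plus the matching tokenizer.
interpreter = interpreter_lib.InterpreterWithCustomOps(
    custom_op_registerers=["pywrap_genai_ops.GenAIOpsRegisterer"],
    model_path=model_path,
    num_threads=2,
    experimental_default_delegate_latest_features=True,
)
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")

# Prefill once, then decode greedily token by token; 128 is an illustrative cap.
pipeline = LiteRTLlmPipeline(interpreter, tokenizer)
print(pipeline.generate("what is 8 mod 5", max_decode_steps=128))

Note that generate() picks the smallest prefill signature whose sequence length fits the tokenized prompt, and each step's output KV cache is fed back as the next step's input, so every decode call processes exactly one new token.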