Upload deepseek_tflite.ipynb

#7
by chenxugz - opened
Files changed (1)
  1. deepseek_tflite.ipynb +130 -55
deepseek_tflite.ipynb CHANGED
@@ -26,13 +26,27 @@
  {
  "cell_type": "code",
  "source": [
- "!pip install ai-edge-litert-nightly"
+ "!pip install ai-edge-litert"
  ],
  "metadata": {
- "id": "43tAeO0AZ7zp"
+ "id": "43tAeO0AZ7zp",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "7ce4d1ef-7d6b-4855-b73b-22482e3c693d"
  },
- "execution_count": null,
- "outputs": []
+ "execution_count": 1,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Requirement already satisfied: ai-edge-litert in /usr/local/lib/python3.11/dist-packages (1.1.2)\n",
+ "Requirement already satisfied: flatbuffers in /usr/local/lib/python3.11/dist-packages (from ai-edge-litert) (25.2.10)\n",
+ "Requirement already satisfied: numpy>=1.23.2 in /usr/local/lib/python3.11/dist-packages (from ai-edge-litert) (1.26.4)\n"
+ ]
+ }
+ ]
  },
  {
  "cell_type": "code",
@@ -46,7 +60,7 @@
  "metadata": {
  "id": "i6PMkMVBPr1p"
  },
- "execution_count": null,
+ "execution_count": 2,
  "outputs": []
  },
  {
@@ -68,7 +82,7 @@
  "metadata": {
  "id": "3t47HAG2tvc3"
  },
- "execution_count": null,
+ "execution_count": 3,
  "outputs": []
  },
  {
@@ -93,7 +107,7 @@
  "metadata": {
  "id": "Rvdn3EIZhaQn"
  },
- "execution_count": null,
+ "execution_count": 4,
  "outputs": []
  },
  {
@@ -108,6 +122,7 @@
  {
  "cell_type": "code",
  "source": [
+ "\n",
  "\n",
  "class LiteRTLlmPipeline:\n",
  "\n",
@@ -133,7 +148,11 @@
  " Args:\n",
  " num_input_tokens: The number of input tokens.\n",
  " \"\"\"\n",
+ " if not self._interpreter:\n",
+ " raise ValueError(\"Interpreter is not initialized.\")\n",
  "\n",
+ " # Prefill runner related variables will be initialized in `predict_text` and\n",
+ " # `compute_log_likelihood`.\n",
  " self._prefill_runner = self._get_prefill_runner(num_input_tokens)\n",
  " # input_token_shape has shape (batch, max_seq_len)\n",
  " input_token_shape = self._prefill_runner.get_input_details()[\"tokens\"][\n",
@@ -203,62 +222,127 @@
  " )\n",
  " return self._interpreter.get_signature_runner(best_signature)\n",
  "\n",
- " def _greedy_sampler(self, logits: np.ndarray) -> int:\n",
- " return int(np.argmax(logits))\n",
+ " def _run_prefill(\n",
+ " self, prefill_token_ids: Sequence[int],\n",
+ " ) -> dict[str, np.ndarray]:\n",
+ " \"\"\"Runs prefill and returns the kv cache.\n",
  "\n",
- " def generate(self, prompt: str, max_decode_steps: int | None = None) -> str:\n",
- " messages=[{ 'role': 'user', 'content': prompt}]\n",
- " token_ids = self._tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)\n",
- " # Initialize the prefill runner with the suitable input size.\n",
- " self._init_prefill_runner(len(token_ids))\n",
+ " Args:\n",
+ " prefill_token_ids: The token ids of the prefill input.\n",
  "\n",
- " actual_max_decode_steps = self._max_kv_cache_seq_len - len(token_ids)\n",
- " if max_decode_steps is not None:\n",
- " actual_max_decode_steps = min(actual_max_decode_steps, max_decode_steps)\n",
+ " Returns:\n",
+ " The updated kv cache.\n",
+ " \"\"\"\n",
+ " if not self._prefill_runner:\n",
+ " raise ValueError(\"Prefill runner is not initialized.\")\n",
+ " prefill_token_length = len(prefill_token_ids)\n",
+ " if prefill_token_length == 0:\n",
+ " return self._init_kv_cache()\n",
  "\n",
+ " # Prepare the input to be [1, max_seq_len].\n",
  " input_token_ids = [0] * self._max_seq_len\n",
- " input_token_ids[:len(token_ids)] = token_ids\n",
- " model_inputs = self._init_kv_cache()\n",
- " model_inputs.update({\n",
- " \"tokens\": np.asarray([input_token_ids], dtype=np.int32),\n",
- " \"input_pos\": np.arange(self._max_seq_len, dtype=np.int32),\n",
+ " input_token_ids[:prefill_token_length] = prefill_token_ids\n",
+ " input_token_ids = np.asarray(input_token_ids, dtype=np.int32)\n",
+ " input_token_ids = np.expand_dims(input_token_ids, axis=0)\n",
+ "\n",
+ " # Prepare the input position to be [max_seq_len].\n",
+ " input_pos = [0] * self._max_seq_len\n",
+ " input_pos[:prefill_token_length] = range(prefill_token_length)\n",
+ " input_pos = np.asarray(input_pos, dtype=np.int32)\n",
+ "\n",
+ " # Initialize kv cache.\n",
+ " prefill_inputs = self._init_kv_cache()\n",
+ " prefill_inputs.update({\n",
+ " \"tokens\": input_token_ids,\n",
+ " \"input_pos\": input_pos,\n",
  " })\n",
+ " prefill_outputs = self._prefill_runner(**prefill_inputs)\n",
+ " if \"logits\" in prefill_outputs:\n",
+ " # Prefill outputs includes logits and kv cache. We only output kv cache.\n",
+ " prefill_outputs.pop(\"logits\")\n",
+ "\n",
+ " return prefill_outputs\n",
+ "\n",
+ " def _greedy_sampler(self, logits: np.ndarray) -> int:\n",
+ " return int(np.argmax(logits))\n",
+ "\n",
+ "\n",
+ " def _run_decode(\n",
+ " self,\n",
+ " start_pos: int,\n",
+ " start_token_id: int,\n",
+ " kv_cache: dict[str, np.ndarray],\n",
+ " max_decode_steps: int,\n",
+ " ) -> str:\n",
+ " \"\"\"Runs decode and outputs the token ids from greedy sampler.\n",
+ "\n",
+ " Args:\n",
+ " start_pos: The position of the first token of the decode input.\n",
+ " start_token_id: The token id of the first token of the decode input.\n",
+ " kv_cache: The kv cache from the prefill.\n",
+ " max_decode_steps: The max decode steps.\n",
+ "\n",
+ " Returns:\n",
+ " The token ids from the greedy sampler.\n",
+ " \"\"\"\n",
+ " next_pos = start_pos\n",
+ " next_token = start_token_id\n",
  " decode_text = []\n",
- " decode_step = 0\n",
- " print('Running prefill')\n",
- " for step in range(actual_max_decode_steps+1):\n",
- " signature_runner = self._prefill_runner if step == 0 else self._decode_runner\n",
- " model_outputs = signature_runner(**model_inputs)\n",
- " # At prefill stage, output logits has shape (batch=1, seq_size, vocab_size)\n",
- " # At decode stage, output logits has shape (batch=1, 1, vocab_size).\n",
- " selected_logit = len(token_ids)-1 if step == 0 else 0\n",
- " logits = model_outputs.pop(\"logits\")[0][selected_logit]\n",
- "\n",
- " if step == 0:\n",
- " print('Running decode')\n",
- "\n",
- " # Decode text output.\n",
+ " decode_inputs = kv_cache\n",
+ "\n",
+ " for _ in range(max_decode_steps):\n",
+ " decode_inputs.update({\n",
+ " \"tokens\": np.array([[next_token]], dtype=np.int32),\n",
+ " \"input_pos\": np.array([next_pos], dtype=np.int32),\n",
+ " })\n",
+ " decode_outputs = self._decode_runner(**decode_inputs)\n",
+ " # Output logits has shape (batch=1, 1, vocab_size). We only take the first\n",
+ " # element.\n",
+ " logits = decode_outputs.pop(\"logits\")[0][0]\n",
  " next_token = self._greedy_sampler(logits)\n",
  " if next_token == self._tokenizer.eos_token_id:\n",
  " break\n",
  " decode_text.append(self._tokenizer.decode(next_token, skip_special_tokens=False))\n",
  " print(decode_text[-1], end='', flush=True)\n",
- " # The rest of the outputs is the updated kv cache.\n",
- " model_inputs = model_outputs\n",
- " model_inputs.update({\n",
- " \"tokens\": np.array([[next_token]], dtype=np.int32),\n",
- " \"input_pos\": np.array([decode_step + len(token_ids)], dtype=np.int32),})\n",
- " decode_step += 1\n",
+ " # Decode outputs includes logits and kv cache. We already popped out\n",
+ " # logits, so the rest is kv cache. We pass the updated kv cache as input\n",
+ " # to the next decode step.\n",
+ " decode_inputs = decode_outputs\n",
+ " next_pos += 1\n",
+ "\n",
+ " print() # print a new line at the end.\n",
+ " return ''.join(decode_text)\n",
  "\n",
+ " def generate(self, prompt: str, max_decode_steps: int | None = None) -> str:\n",
+ " messages=[{ 'role': 'user', 'content': prompt}]\n",
+ " token_ids = self._tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)\n",
+ " # Initialize the prefill runner with the suitable input size.\n",
+ " self._init_prefill_runner(len(token_ids))\n",
  "\n",
+ " # Run prefill.\n",
+ " # Prefill up to the second to the last token of the prompt, because the last\n",
+ " # token of the prompt will be used to bootstrap decode.\n",
+ " prefill_token_length = len(token_ids) - 1\n",
  "\n",
- " print() # print a new line at the end.\n",
- " return ''.join(decode_text)\n"
+ " print('Running prefill')\n",
+ " kv_cache = self._run_prefill(token_ids[:prefill_token_length])\n",
+ " # Run decode.\n",
+ " print('Running decode')\n",
+ " actual_max_decode_steps = self._max_kv_cache_seq_len - prefill_token_length - 1\n",
+ " if max_decode_steps is not None:\n",
+ " actual_max_decode_steps = min(actual_max_decode_steps, max_decode_steps)\n",
+ " decode_text = self._run_decode(\n",
+ " prefill_token_length,\n",
+ " token_ids[prefill_token_length],\n",
+ " kv_cache,\n",
+ " actual_max_decode_steps,\n",
+ " )\n",
+ " return decode_text"
  ],
  "metadata": {
  "id": "UBSGrHrM4ANm"
  },
- "execution_count": null,
+ "execution_count": 7,
  "outputs": []
  },
  {
@@ -279,7 +363,7 @@
  "metadata": {
  "id": "AZhlDQWg61AL"
  },
- "execution_count": null,
+ "execution_count": 8,
  "outputs": []
  },
  {
@@ -293,15 +377,6 @@
  },
  "execution_count": null,
  "outputs": []
- },
- {
- "cell_type": "code",
- "source": [],
- "metadata": {
- "id": "GNzDBxDFEuAJ"
- },
- "execution_count": null,
- "outputs": []
  }
  ]
  }
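
For readers following the rewritten cell: the new `generate` splits inference into `_run_prefill`, which fills the KV cache with every prompt token except the last, and `_run_decode`, which starts from that last prompt token and samples greedily. A minimal sketch of the position bookkeeping is below; the token ids and cache length are made up for illustration and are not values from the notebook.

```python
# Illustrative sketch of the prefill/decode split used by generate().
token_ids = [51, 872, 25, 9707]            # hypothetical prompt token ids
prefill_token_length = len(token_ids) - 1  # prefill all but the last token -> 3

# _run_prefill fills KV-cache positions 0 .. prefill_token_length - 1.
prefill_ids = token_ids[:prefill_token_length]

# _run_decode starts at position prefill_token_length with the last prompt
# token, so its first step emits the first generated token.
start_pos = prefill_token_length           # 3
start_token_id = token_ids[prefill_token_length]

# Decode budget: remaining cache slots after the prompt; the -1 accounts for
# the bootstrap token written at start_pos.
max_kv_cache_seq_len = 1280                # hypothetical cache length
actual_max_decode_steps = max_kv_cache_seq_len - prefill_token_length - 1
assert actual_max_decode_steps == 1276
```

Compared with the previous single loop that reused the prefill signature for step 0, this keeps the decode runner's `tokens` and `input_pos` shapes fixed at `[1, 1]` and `[1]` for every step.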