Aleef commited on
Commit
531c08c
·
1 Parent(s): 728a30c

Upload Godel_finetunning.ipynb

Browse files

A jupyter notebook describing how to finetunne Godel on your custom dataset. Godel is a chatbot with large model consisting about 0.75 B trainable parameters. It's a cheat sheet to fine tunning so that you can get a head start and don't have to waste your time as I did. Cheers.

Files changed (1) hide show
  1. Godel_finetunning.ipynb +679 -0
Godel_finetunning.ipynb ADDED
@@ -0,0 +1,679 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {
7
+ "id": "cUzq1tXyk5Ga"
8
+ },
9
+ "outputs": [],
10
+ "source": [
11
+ "# !pip install transformers\n",
12
+ "# !pip install torch\n",
13
+ "# !pip install accelerate -U"
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "markdown",
18
+ "metadata": {},
19
+ "source": [
20
+ "#### Below is the funtion to find trainable parameters of the Model. "
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": 5,
26
+ "metadata": {},
27
+ "outputs": [
28
+ {
29
+ "data": {
30
+ "text/plain": [
31
+ "737641472"
32
+ ]
33
+ },
34
+ "execution_count": 5,
35
+ "metadata": {},
36
+ "output_type": "execute_result"
37
+ }
38
+ ],
39
+ "source": [
40
+ "sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": 1,
46
+ "metadata": {
47
+ "execution": {
48
+ "iopub.execute_input": "2023-09-12T05:38:18.853671Z",
49
+ "iopub.status.busy": "2023-09-12T05:38:18.853483Z",
50
+ "iopub.status.idle": "2023-09-12T05:38:20.511295Z",
51
+ "shell.execute_reply": "2023-09-12T05:38:20.510634Z",
52
+ "shell.execute_reply.started": "2023-09-12T05:38:18.853650Z"
53
+ },
54
+ "id": "_GqhK_n0JWC4"
55
+ },
56
+ "outputs": [],
57
+ "source": [
58
+ "import pandas as pd\n",
59
+ "import json\n",
60
+ "import torch\n"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": 2,
66
+ "metadata": {
67
+ "execution": {
68
+ "iopub.execute_input": "2023-09-12T05:38:21.617293Z",
69
+ "iopub.status.busy": "2023-09-12T05:38:21.616915Z",
70
+ "iopub.status.idle": "2023-09-12T05:38:34.474328Z",
71
+ "shell.execute_reply": "2023-09-12T05:38:34.473820Z",
72
+ "shell.execute_reply.started": "2023-09-12T05:38:21.617267Z"
73
+ },
74
+ "id": "FVBPeMW99Z7G"
75
+ },
76
+ "outputs": [],
77
+ "source": [
78
+ "\n",
79
+ "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AdamW, TrainingArguments, Trainer\n",
80
+ "from torch.utils.data import TensorDataset\n",
81
+ "\n",
82
+ "tokenizer = AutoTokenizer.from_pretrained(\"microsoft/GODEL-v1_1-large-seq2seq\", padding_side='right', truncation_side='left')\n"
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "code",
87
+ "execution_count": 5,
88
+ "metadata": {
89
+ "execution": {
90
+ "iopub.execute_input": "2023-09-12T05:38:37.343460Z",
91
+ "iopub.status.busy": "2023-09-12T05:38:37.343116Z",
92
+ "iopub.status.idle": "2023-09-12T05:38:43.015610Z",
93
+ "shell.execute_reply": "2023-09-12T05:38:43.015175Z",
94
+ "shell.execute_reply.started": "2023-09-12T05:38:37.343436Z"
95
+ },
96
+ "id": "Bee7KFF2MWQ_"
97
+ },
98
+ "outputs": [],
99
+ "source": [
100
+ "model = AutoModelForSeq2SeqLM.from_pretrained(\"microsoft/GODEL-v1_1-large-seq2seq\").to('cuda')"
101
+ ]
102
+ },
103
+ {
104
+ "cell_type": "markdown",
105
+ "metadata": {},
106
+ "source": [
107
+ "#### Here the data preprocessed, Note that the data loaded to this model is in the following format. It is in the form of mulit-turn conversation between two persons.\n",
108
+ "#### [[person1, person2, person1, person2, person1, person2],\n",
109
+ "#### [person1, person2, person1, person2, person1, person2],\n",
110
+ "#### [person1, person2, person1, person2, person1, person2],\n",
111
+ "#### [person1, person2, person1, person2, person1, person2],\n",
112
+ "#### [person1, person2, person1, person2, person1, person2]]"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": 6,
118
+ "metadata": {
119
+ "execution": {
120
+ "iopub.execute_input": "2023-09-12T05:38:44.400644Z",
121
+ "iopub.status.busy": "2023-09-12T05:38:44.400155Z",
122
+ "iopub.status.idle": "2023-09-12T05:38:44.405992Z",
123
+ "shell.execute_reply": "2023-09-12T05:38:44.405263Z",
124
+ "shell.execute_reply.started": "2023-09-12T05:38:44.400620Z"
125
+ },
126
+ "id": "Mjd9Us2Sr6Hq"
127
+ },
128
+ "outputs": [],
129
+ "source": [
130
+ "def read_data_from_txt(file_path):\n",
131
+ " try:\n",
132
+ " with open(file_path, 'rb') as file:\n",
133
+ " content = file.readlines()\n",
134
+ " content = [_.decode('utf-8').strip() for _ in content]\n",
135
+ " content = '\\n'.join(content)\n",
136
+ "\n",
137
+ " # Split the content based on the delimiter (triple single quotes)\n",
138
+ " data_list = content.split(\"''','''\")\n",
139
+ "\n",
140
+ " # Remove empty elements from the list\n",
141
+ " data_list = [section.strip(\"'''\") for section in data_list]\n",
142
+ " data_list = [_.strip().split('\\n') for _ in data_list]\n",
143
+ "\n",
144
+ " return data_list\n",
145
+ " except FileNotFoundError:\n",
146
+ " print(f\"File '{file_path}' not found.\")\n",
147
+ " return None\n",
148
+ " except Exception as e:\n",
149
+ " print(f\"Error occurred while reading the file: {e}\")\n",
150
+ " return None\n"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": 7,
156
+ "metadata": {
157
+ "execution": {
158
+ "iopub.execute_input": "2023-09-12T05:38:45.632305Z",
159
+ "iopub.status.busy": "2023-09-12T05:38:45.631923Z",
160
+ "iopub.status.idle": "2023-09-12T05:38:45.637764Z",
161
+ "shell.execute_reply": "2023-09-12T05:38:45.637089Z",
162
+ "shell.execute_reply.started": "2023-09-12T05:38:45.632280Z"
163
+ },
164
+ "id": "N4WTX9MfKTBX"
165
+ },
166
+ "outputs": [],
167
+ "source": [
168
+ "\n",
169
+ "file_path = 'your_data.txt'\n",
170
+ "data_list = read_data_from_txt(file_path)\n"
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "code",
175
+ "execution_count": 8,
176
+ "metadata": {
177
+ "execution": {
178
+ "iopub.execute_input": "2023-09-12T05:38:46.529136Z",
179
+ "iopub.status.busy": "2023-09-12T05:38:46.528726Z",
180
+ "iopub.status.idle": "2023-09-12T05:38:46.532045Z",
181
+ "shell.execute_reply": "2023-09-12T05:38:46.531505Z",
182
+ "shell.execute_reply.started": "2023-09-12T05:38:46.529112Z"
183
+ }
184
+ },
185
+ "outputs": [],
186
+ "source": [
187
+ "training_data = data_list\n"
188
+ ]
189
+ },
190
+ {
191
+ "cell_type": "code",
192
+ "execution_count": 10,
193
+ "metadata": {
194
+ "execution": {
195
+ "iopub.execute_input": "2023-09-12T05:38:52.640741Z",
196
+ "iopub.status.busy": "2023-09-12T05:38:52.639972Z",
197
+ "iopub.status.idle": "2023-09-12T05:38:52.646245Z",
198
+ "shell.execute_reply": "2023-09-12T05:38:52.645854Z",
199
+ "shell.execute_reply.started": "2023-09-12T05:38:52.640704Z"
200
+ },
201
+ "id": "fxgyXq64Q1GP"
202
+ },
203
+ "outputs": [],
204
+ "source": [
205
+ "\n",
206
+ "def create_input_output(data_list):\n",
207
+ " input_data = []\n",
208
+ " output_data = []\n",
209
+ " instructions = \"You are Woice AI. Answer the queires relevant to rev9 Solutions only. If not relevant, asnwer 'I applogize, I can't answer your question as I am just an AI chatbot.'\"\n",
210
+ " knowledge = \"\"\n",
211
+ " for lines in data_list:\n",
212
+ " for i in range(1, len(lines), 2):\n",
213
+ " input_lines = lines[:i]\n",
214
+ " input_text = ' EOS '.join(input_lines).strip()\n",
215
+ " input_data.append(f'[INSTRUCTION] {instructions} [CONTEXT] ' + input_text )\n",
216
+ " output_data.append(lines[i] + ' EOS')\n",
217
+ " return input_data, output_data\n"
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "code",
222
+ "execution_count": 11,
223
+ "metadata": {
224
+ "execution": {
225
+ "iopub.execute_input": "2023-09-12T05:38:54.366890Z",
226
+ "iopub.status.busy": "2023-09-12T05:38:54.366544Z",
227
+ "iopub.status.idle": "2023-09-12T05:38:54.371721Z",
228
+ "shell.execute_reply": "2023-09-12T05:38:54.371144Z",
229
+ "shell.execute_reply.started": "2023-09-12T05:38:54.366866Z"
230
+ }
231
+ },
232
+ "outputs": [],
233
+ "source": [
234
+ "\n",
235
+ "train_input, train_output = create_input_output(training_data)"
236
+ ]
237
+ },
238
+ {
239
+ "cell_type": "code",
240
+ "execution_count": 13,
241
+ "metadata": {
242
+ "execution": {
243
+ "iopub.execute_input": "2023-09-12T05:39:10.350357Z",
244
+ "iopub.status.busy": "2023-09-12T05:39:10.350006Z",
245
+ "iopub.status.idle": "2023-09-12T05:39:10.354580Z",
246
+ "shell.execute_reply": "2023-09-12T05:39:10.353920Z",
247
+ "shell.execute_reply.started": "2023-09-12T05:39:10.350333Z"
248
+ },
249
+ "id": "VyrEDi_G9NfY"
250
+ },
251
+ "outputs": [],
252
+ "source": [
253
+ "def generation_tokenized_dataset(input, output):\n",
254
+ " \n",
255
+ " input_tokens = tokenizer(input, padding=\"longest\", truncation=True, return_tensors=\"pt\", max_length=768)\n",
256
+ " output_tokens = tokenizer(output, padding=\"longest\", truncation=True, return_tensors=\"pt\", max_length=768)\n",
257
+ " dataset = TensorDataset(input_tokens.input_ids, input_tokens.attention_mask,\n",
258
+ " output_tokens.input_ids, output_tokens.attention_mask)\n",
259
+ "\n",
260
+ " return dataset\n"
261
+ ]
262
+ },
263
+ {
264
+ "cell_type": "code",
265
+ "execution_count": 14,
266
+ "metadata": {
267
+ "execution": {
268
+ "iopub.execute_input": "2023-09-12T05:39:11.118317Z",
269
+ "iopub.status.busy": "2023-09-12T05:39:11.117702Z",
270
+ "iopub.status.idle": "2023-09-12T05:39:11.459556Z",
271
+ "shell.execute_reply": "2023-09-12T05:39:11.459151Z",
272
+ "shell.execute_reply.started": "2023-09-12T05:39:11.118292Z"
273
+ },
274
+ "id": "Q0IjwcBPfVEm"
275
+ },
276
+ "outputs": [],
277
+ "source": [
278
+ "train_set = generation_tokenized_dataset(train_input, train_output)\n"
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": 15,
284
+ "metadata": {
285
+ "execution": {
286
+ "iopub.execute_input": "2023-09-12T05:39:12.526146Z",
287
+ "iopub.status.busy": "2023-09-12T05:39:12.525838Z",
288
+ "iopub.status.idle": "2023-09-12T05:39:12.530858Z",
289
+ "shell.execute_reply": "2023-09-12T05:39:12.530178Z",
290
+ "shell.execute_reply.started": "2023-09-12T05:39:12.526123Z"
291
+ },
292
+ "id": "hhz3a3j2Sa0P"
293
+ },
294
+ "outputs": [],
295
+ "source": [
296
+ "class CustomDataCollator:\n",
297
+ " def __call__(self, features):\n",
298
+ " input_ids = torch.stack([f[0] for f in features])\n",
299
+ " attention_mask = torch.stack([f[1] for f in features])\n",
300
+ " labels = torch.stack([f[2] for f in features])\n",
301
+ "\n",
302
+ " return {\n",
303
+ " 'input_ids': input_ids,\n",
304
+ " 'attention_mask': attention_mask,\n",
305
+ " 'labels': labels\n",
306
+ " }\n"
307
+ ]
308
+ },
309
+ {
310
+ "cell_type": "code",
311
+ "execution_count": null,
312
+ "metadata": {
313
+ "execution": {
314
+ "iopub.execute_input": "2023-09-12T05:39:13.295224Z",
315
+ "iopub.status.busy": "2023-09-12T05:39:13.294666Z",
316
+ "iopub.status.idle": "2023-09-12T05:39:13.307836Z",
317
+ "shell.execute_reply": "2023-09-12T05:39:13.307503Z",
318
+ "shell.execute_reply.started": "2023-09-12T05:39:13.295200Z"
319
+ },
320
+ "id": "CN5JWUqmS8wM"
321
+ },
322
+ "outputs": [],
323
+ "source": [
324
+ "import torch\n",
325
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
326
+ "model.to(device)\n",
327
+ "optimizer = AdamW(model.parameters(), lr=1e-5)"
328
+ ]
329
+ },
330
+ {
331
+ "cell_type": "code",
332
+ "execution_count": 17,
333
+ "metadata": {
334
+ "execution": {
335
+ "iopub.execute_input": "2023-09-12T05:39:14.655823Z",
336
+ "iopub.status.busy": "2023-09-12T05:39:14.655033Z",
337
+ "iopub.status.idle": "2023-09-12T05:39:14.659506Z",
338
+ "shell.execute_reply": "2023-09-12T05:39:14.658681Z",
339
+ "shell.execute_reply.started": "2023-09-12T05:39:14.655786Z"
340
+ },
341
+ "id": "zfsQaXAEWZLD"
342
+ },
343
+ "outputs": [],
344
+ "source": [
345
+ "from transformers import EarlyStoppingCallback"
346
+ ]
347
+ },
348
+ {
349
+ "cell_type": "code",
350
+ "execution_count": 18,
351
+ "metadata": {
352
+ "execution": {
353
+ "iopub.execute_input": "2023-09-12T05:39:15.342624Z",
354
+ "iopub.status.busy": "2023-09-12T05:39:15.342125Z",
355
+ "iopub.status.idle": "2023-09-12T05:39:15.345769Z",
356
+ "shell.execute_reply": "2023-09-12T05:39:15.345059Z",
357
+ "shell.execute_reply.started": "2023-09-12T05:39:15.342600Z"
358
+ },
359
+ "id": "zd7CDp3xXVMp"
360
+ },
361
+ "outputs": [],
362
+ "source": [
363
+ "from transformers import get_linear_schedule_with_warmup"
364
+ ]
365
+ },
366
+ {
367
+ "cell_type": "code",
368
+ "execution_count": 17,
369
+ "metadata": {
370
+ "execution": {
371
+ "iopub.execute_input": "2023-09-11T11:42:31.617024Z",
372
+ "iopub.status.busy": "2023-09-11T11:42:31.616702Z",
373
+ "iopub.status.idle": "2023-09-11T11:42:31.620157Z",
374
+ "shell.execute_reply": "2023-09-11T11:42:31.619476Z",
375
+ "shell.execute_reply.started": "2023-09-11T11:42:31.617001Z"
376
+ },
377
+ "id": "rcMlWRgMWcOA"
378
+ },
379
+ "outputs": [],
380
+ "source": [
381
+ "callbacks = [EarlyStoppingCallback(early_stopping_patience=4)]"
382
+ ]
383
+ },
384
+ {
385
+ "cell_type": "code",
386
+ "execution_count": 19,
387
+ "metadata": {
388
+ "execution": {
389
+ "iopub.execute_input": "2023-09-12T05:39:17.359370Z",
390
+ "iopub.status.busy": "2023-09-12T05:39:17.358967Z",
391
+ "iopub.status.idle": "2023-09-12T05:39:17.362640Z",
392
+ "shell.execute_reply": "2023-09-12T05:39:17.362096Z",
393
+ "shell.execute_reply.started": "2023-09-12T05:39:17.359346Z"
394
+ },
395
+ "id": "WgGbwECpXXwd"
396
+ },
397
+ "outputs": [],
398
+ "source": [
399
+ "lr_scheduler = get_linear_schedule_with_warmup(optimizer=optimizer,\n",
400
+ " num_warmup_steps=300,\n",
401
+ " num_training_steps=1200)"
402
+ ]
403
+ },
404
+ {
405
+ "cell_type": "code",
406
+ "execution_count": 20,
407
+ "metadata": {
408
+ "execution": {
409
+ "iopub.execute_input": "2023-09-12T05:39:26.782170Z",
410
+ "iopub.status.busy": "2023-09-12T05:39:26.781759Z",
411
+ "iopub.status.idle": "2023-09-12T05:39:26.788708Z",
412
+ "shell.execute_reply": "2023-09-12T05:39:26.788007Z",
413
+ "shell.execute_reply.started": "2023-09-12T05:39:26.782126Z"
414
+ },
415
+ "id": "UCpUorNtUTxJ"
416
+ },
417
+ "outputs": [],
418
+ "source": [
419
+ "training_args = TrainingArguments(\n",
420
+ " output_dir='./godel/v0.0.5',\n",
421
+ " num_train_epochs= 20,\n",
422
+ " per_device_train_batch_size=2,\n",
423
+ " warmup_steps=100,\n",
424
+ " weight_decay=0.01,\n",
425
+ " logging_dir='./godel/v0.0.5/logs',\n",
426
+ " logging_steps=50,\n",
427
+ " save_total_limit=1,\n",
428
+ " gradient_accumulation_steps=8,\n",
429
+ " learning_rate=0.001,\n",
430
+ " load_best_model_at_end=True,\n",
431
+ " metric_for_best_model='loss',\n",
432
+ " greater_is_better=False,\n",
433
+ " save_strategy='epoch',\n",
434
+ " evaluation_strategy='epoch'\n",
435
+ "\n",
436
+ ")\n",
437
+ "\n",
438
+ "training_args = training_args.set_lr_scheduler(name='linear',\n",
439
+ " num_epochs=40,\n",
440
+ " warmup_steps=100)\n"
441
+ ]
442
+ },
443
+ {
444
+ "cell_type": "markdown",
445
+ "metadata": {},
446
+ "source": [
447
+ "#### Here model is evaluated and trained on the same dataset as I was short on the dataset. If you have a large dataset, split them with the desired ratio (recommended= 15:85, respectively)"
448
+ ]
449
+ },
450
+ {
451
+ "cell_type": "code",
452
+ "execution_count": 21,
453
+ "metadata": {
454
+ "execution": {
455
+ "iopub.execute_input": "2023-09-12T05:39:27.630008Z",
456
+ "iopub.status.busy": "2023-09-12T05:39:27.629250Z",
457
+ "iopub.status.idle": "2023-09-12T05:39:27.642183Z",
458
+ "shell.execute_reply": "2023-09-12T05:39:27.641782Z",
459
+ "shell.execute_reply.started": "2023-09-12T05:39:27.629973Z"
460
+ },
461
+ "id": "KxAyHTuJOBIQ"
462
+ },
463
+ "outputs": [],
464
+ "source": [
465
+ "\n",
466
+ "\n",
467
+ "trainer = Trainer(\n",
468
+ " model=model,\n",
469
+ " args=training_args,\n",
470
+ " train_dataset=train_set,\n",
471
+ " eval_dataset=train_set,\n",
472
+ " data_collator=CustomDataCollator(),\n",
473
+ " callbacks=callbacks,\n",
474
+ "\n",
475
+ ")"
476
+ ]
477
+ },
478
+ {
479
+ "cell_type": "code",
480
+ "execution_count": null,
481
+ "metadata": {
482
+ "execution": {
483
+ "iopub.execute_input": "2023-09-12T05:39:29.327544Z",
484
+ "iopub.status.busy": "2023-09-12T05:39:29.327023Z",
485
+ "iopub.status.idle": "2023-09-12T09:31:20.343378Z",
486
+ "shell.execute_reply": "2023-09-12T09:31:20.343016Z",
487
+ "shell.execute_reply.started": "2023-09-12T05:39:29.327521Z"
488
+ },
489
+ "id": "brO0zCjN9U_P"
490
+ },
491
+ "outputs": [],
492
+ "source": [
493
+ "trainer.train()"
494
+ ]
495
+ },
496
+ {
497
+ "cell_type": "code",
498
+ "execution_count": 23,
499
+ "metadata": {
500
+ "execution": {
501
+ "iopub.execute_input": "2023-09-12T09:31:20.344170Z",
502
+ "iopub.status.busy": "2023-09-12T09:31:20.344000Z",
503
+ "iopub.status.idle": "2023-09-12T09:32:40.040850Z",
504
+ "shell.execute_reply": "2023-09-12T09:32:40.040458Z",
505
+ "shell.execute_reply.started": "2023-09-12T09:31:20.344157Z"
506
+ }
507
+ },
508
+ "outputs": [
509
+ {
510
+ "data": {
511
+ "text/html": [
512
+ "\n",
513
+ " <div>\n",
514
+ " \n",
515
+ " <progress value='160' max='160' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
516
+ " [160/160 01:19]\n",
517
+ " </div>\n",
518
+ " "
519
+ ],
520
+ "text/plain": [
521
+ "<IPython.core.display.HTML object>"
522
+ ]
523
+ },
524
+ "metadata": {},
525
+ "output_type": "display_data"
526
+ },
527
+ {
528
+ "data": {
529
+ "text/plain": [
530
+ "{'eval_loss': 0.00055647426052019,\n",
531
+ " 'eval_runtime': 79.6939,\n",
532
+ " 'eval_samples_per_second': 16.036,\n",
533
+ " 'eval_steps_per_second': 2.008,\n",
534
+ " 'epoch': 39.56}"
535
+ ]
536
+ },
537
+ "execution_count": 23,
538
+ "metadata": {},
539
+ "output_type": "execute_result"
540
+ }
541
+ ],
542
+ "source": [
543
+ "trainer.evaluate(train_set)"
544
+ ]
545
+ },
546
+ {
547
+ "cell_type": "code",
548
+ "execution_count": 24,
549
+ "metadata": {
550
+ "execution": {
551
+ "iopub.execute_input": "2023-09-12T09:33:05.820118Z",
552
+ "iopub.status.busy": "2023-09-12T09:33:05.819417Z",
553
+ "iopub.status.idle": "2023-09-12T09:33:08.026572Z",
554
+ "shell.execute_reply": "2023-09-12T09:33:08.026139Z",
555
+ "shell.execute_reply.started": "2023-09-12T09:33:05.820082Z"
556
+ }
557
+ },
558
+ "outputs": [
559
+ {
560
+ "data": {
561
+ "text/plain": [
562
+ "('./godel/v0.0.5/tokenizer_config.json',\n",
563
+ " './godel/v0.0.5/special_tokens_map.json',\n",
564
+ " './godel/v0.0.5/tokenizer.json')"
565
+ ]
566
+ },
567
+ "execution_count": 24,
568
+ "metadata": {},
569
+ "output_type": "execute_result"
570
+ }
571
+ ],
572
+ "source": [
573
+ "trainer.save_model()\n",
574
+ "trainer.save_state()\n",
575
+ "tokenizer.save_pretrained(trainer.args.output_dir)"
576
+ ]
577
+ },
578
+ {
579
+ "cell_type": "markdown",
580
+ "metadata": {},
581
+ "source": [
582
+ "#### You can chat with your model here. Pass in instrucions or knowledge as you desire."
583
+ ]
584
+ },
585
+ {
586
+ "cell_type": "code",
587
+ "execution_count": 25,
588
+ "metadata": {
589
+ "execution": {
590
+ "iopub.execute_input": "2023-09-12T09:33:11.243375Z",
591
+ "iopub.status.busy": "2023-09-12T09:33:11.242979Z",
592
+ "iopub.status.idle": "2023-09-12T09:33:11.246636Z",
593
+ "shell.execute_reply": "2023-09-12T09:33:11.246071Z",
594
+ "shell.execute_reply.started": "2023-09-12T09:33:11.243351Z"
595
+ }
596
+ },
597
+ "outputs": [],
598
+ "source": [
599
+ "from time import time "
600
+ ]
601
+ },
602
+ {
603
+ "cell_type": "code",
604
+ "execution_count": 26,
605
+ "metadata": {
606
+ "execution": {
607
+ "iopub.execute_input": "2023-09-12T09:33:11.802465Z",
608
+ "iopub.status.busy": "2023-09-12T09:33:11.802159Z",
609
+ "iopub.status.idle": "2023-09-12T09:33:11.807265Z",
610
+ "shell.execute_reply": "2023-09-12T09:33:11.806707Z",
611
+ "shell.execute_reply.started": "2023-09-12T09:33:11.802443Z"
612
+ }
613
+ },
614
+ "outputs": [],
615
+ "source": [
616
+ "def generate(instruction, dialog, knowledge):\n",
617
+ " if knowledge != '':\n",
618
+ " knowledge = '[KNOWLEDGE] ' + knowledge\n",
619
+ " dialog = ' EOS '.join(dialog)\n",
620
+ " query = f\"{instruction} [CONTEXT] {dialog} {knowledge}\"\n",
621
+ " t = time()\n",
622
+ " \n",
623
+ " input_ids = tokenizer(f\"{query}\", return_tensors=\"pt\").to('cuda').input_ids\n",
624
+ " outputs = model.generate(input_ids, max_length=32102, min_length=8, top_p=0.9, do_sample=True)\n",
625
+ " output = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
626
+ " print('time:', time() - t)\n",
627
+ " return output"
628
+ ]
629
+ },
630
+ {
631
+ "cell_type": "code",
632
+ "execution_count": null,
633
+ "metadata": {
634
+ "execution": {
635
+ "iopub.execute_input": "2023-09-12T09:41:13.476490Z",
636
+ "iopub.status.busy": "2023-09-12T09:41:13.476127Z"
637
+ }
638
+ },
639
+ "outputs": [],
640
+ "source": [
641
+ "dialog = list()\n",
642
+ "while True:\n",
643
+ " query = input(\"Human: \")\n",
644
+ " dialog.append(query)\n",
645
+ " instruction = \"You are Woice AI, you are here to answer queries emphatically. Don't be rude and say vulgar words. Any thing unrelated to your training, do not answer randomly. Be polite.\"\n",
646
+ " knowledge = ''\n",
647
+ " output = \"AI: \" + generate(instruction, dialog, knowledge)\n",
648
+ " dialog.append(output)\n",
649
+ " print(output)"
650
+ ]
651
+ }
652
+ ],
653
+ "metadata": {
654
+ "accelerator": "GPU",
655
+ "colab": {
656
+ "gpuType": "T4",
657
+ "provenance": []
658
+ },
659
+ "kernelspec": {
660
+ "display_name": "Python 3 (ipykernel)",
661
+ "language": "python",
662
+ "name": "python3"
663
+ },
664
+ "language_info": {
665
+ "codemirror_mode": {
666
+ "name": "ipython",
667
+ "version": 3
668
+ },
669
+ "file_extension": ".py",
670
+ "mimetype": "text/x-python",
671
+ "name": "python",
672
+ "nbconvert_exporter": "python",
673
+ "pygments_lexer": "ipython3",
674
+ "version": "3.11.4"
675
+ }
676
+ },
677
+ "nbformat": 4,
678
+ "nbformat_minor": 4
679
+ }