AmelieSchreiber committed on
Commit ce9ca82
Parent(s): 907da65

Upload cafa_5_finetune_v2.ipynb

Files changed (1):
  cafa_5_finetune_v2.ipynb +465 -0
cafa_5_finetune_v2.ipynb ADDED
@@ -0,0 +1,465 @@
+ {
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Fine-tuning ESM-2 Models for CAFA-5"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Fine-tune an ESM-2 Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from torch.utils.data import DataLoader, Dataset\n",
+ "from transformers import AutoTokenizer, EsmForSequenceClassification\n",
+ "from accelerate import Accelerator\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from torchmetrics.classification import MultilabelF1Score\n",
+ "from sklearn.metrics import accuracy_score, precision_score, recall_score, average_precision_score\n",
+ "import datetime\n",
+ "import pandas as pd\n",
+ "\n",
+ "# Load the data\n",
+ "data = pd.read_csv(\"C:/Users/OWO/Desktop/amelie_vscode/cafa5/data/merged_protein_data.tsv\", sep=\"\\t\")\n",
+ "# Optionally use only the first 100 entries for a quick test run\n",
+ "# data = data.head(100)\n",
+ "\n",
+ "# Initialize the accelerator\n",
+ "accelerator = Accelerator()\n",
+ "device = accelerator.device\n",
+ "\n",
+ "# Data Preprocessing\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\"facebook/esm2_t6_8M_UR50D\")\n",
+ "MAX_LENGTH = tokenizer.model_max_length\n",
+ "NUM_EPOCHS = 3\n",
+ "LR = 5e-4\n",
+ "BATCH_SIZE = 2\n",
+ "\n",
+ "class ProteinDataset(Dataset):\n",
+ "    def __init__(self, sequences, labels):\n",
+ "        self.sequences = sequences\n",
+ "        self.labels = labels\n",
+ "\n",
+ "    def __len__(self):\n",
+ "        return len(self.sequences)\n",
+ "\n",
+ "    def __getitem__(self, idx):\n",
+ "        sequence = self.sequences[idx]\n",
+ "        label = self.labels[idx]\n",
+ "        encoding = tokenizer(sequence, return_tensors=\"pt\", padding='max_length', truncation=True, max_length=MAX_LENGTH)\n",
+ "        return {\n",
+ "            'input_ids': encoding['input_ids'].flatten(),\n",
+ "            'attention_mask': encoding['attention_mask'].flatten(),\n",
+ "            'labels': torch.tensor(label, dtype=torch.float)\n",
+ "        }\n",
+ "\n",
+ "def encode_labels(go_terms, unique_terms):\n",
+ "    encoded = []\n",
+ "    for terms in go_terms:\n",
+ "        encoding = [1 if term in terms else 0 for term in unique_terms]\n",
+ "        encoded.append(encoding)\n",
+ "    return encoded\n",
+ "\n",
+ "train_sequences, val_sequences, train_labels, val_labels = train_test_split(data['sequence'], data['term'], test_size=0.1)\n",
+ "\n",
+ "# Reset the indices\n",
+ "train_sequences = train_sequences.reset_index(drop=True)\n",
+ "val_sequences = val_sequences.reset_index(drop=True)\n",
+ "train_labels = train_labels.reset_index(drop=True)\n",
+ "val_labels = val_labels.reset_index(drop=True)\n",
+ "\n",
+ "unique_terms = list(set(term for sublist in data['term'] for term in sublist))\n",
+ "train_labels_encoded = encode_labels(train_labels, unique_terms)\n",
+ "val_labels_encoded = encode_labels(val_labels, unique_terms)\n",
+ "\n",
+ "train_dataset = ProteinDataset(train_sequences, train_labels_encoded)\n",
+ "val_dataset = ProteinDataset(val_sequences, val_labels_encoded)\n",
+ "\n",
+ "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
+ "val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)\n",
+ "\n",
+ "# Model Training\n",
+ "model = EsmForSequenceClassification.from_pretrained(\"facebook/esm2_t6_8M_UR50D\", num_labels=len(unique_terms), problem_type=\"multi_label_classification\")\n",
+ "model = model.to(device)\n",
+ "\n",
+ "optimizer = torch.optim.AdamW(model.parameters(), lr=LR)\n",
+ "optimizer, model = accelerator.prepare(optimizer, model)\n",
+ "\n",
+ "# Initialize metrics\n",
+ "f1_metric = MultilabelF1Score(num_labels=len(unique_terms), threshold=0.5)\n",
+ "f1_metric = f1_metric.to(device)\n",
+ "\n",
+ "num_epochs = NUM_EPOCHS\n",
+ "\n",
+ "for epoch in range(num_epochs):\n",
+ "    model.train()  # re-enable training mode (model.eval() is set during validation below)\n",
+ "    total_loss = 0\n",
+ "    for batch in train_loader:\n",
+ "        optimizer.zero_grad()\n",
+ "        input_ids = batch['input_ids'].to(device)\n",
+ "        attention_mask = batch['attention_mask'].to(device)\n",
+ "        labels = batch['labels'].to(device)\n",
+ "\n",
+ "        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)\n",
+ "        loss = outputs.loss\n",
+ "        accelerator.backward(loss)\n",
+ "        optimizer.step()\n",
+ "\n",
+ "        total_loss += loss.item()\n",
+ "\n",
+ "    print(f'Epoch {epoch + 1}/{num_epochs}, Training loss: {total_loss/len(train_loader)}')\n",
+ "\n",
+ "    model.eval()\n",
+ "    predictions = []\n",
+ "    true_labels_list = []\n",
+ "    with torch.no_grad():\n",
+ "        for batch in val_loader:\n",
+ "            input_ids = batch['input_ids'].to(device)\n",
+ "            attention_mask = batch['attention_mask'].to(device)\n",
+ "            labels = batch['labels'].to(device)\n",
+ "\n",
+ "            outputs = model(input_ids=input_ids, attention_mask=attention_mask)\n",
+ "            logits = outputs.logits\n",
+ "            predictions.append(torch.sigmoid(logits))\n",
+ "            true_labels_list.append(labels)\n",
+ "\n",
+ "    predictions_tensor = torch.cat(predictions, dim=0).cpu().numpy()\n",
+ "    true_labels_tensor = torch.cat(true_labels_list, dim=0).cpu().numpy()\n",
+ "\n",
+ "    threshold = 0.5\n",
+ "    predictions_bin = (predictions_tensor > threshold).astype(int)\n",
+ "\n",
+ "    # Compute metrics\n",
+ "    val_f1 = f1_metric(torch.tensor(predictions_tensor).to(device), torch.tensor(true_labels_tensor).to(device))\n",
+ "    val_accuracy = accuracy_score(true_labels_tensor.flatten(), predictions_bin.flatten())\n",
+ "    val_precision = precision_score(true_labels_tensor.flatten(), predictions_bin.flatten(), average='micro')\n",
+ "    val_recall = recall_score(true_labels_tensor.flatten(), predictions_bin.flatten(), average='micro')\n",
+ "    val_auc = average_precision_score(true_labels_tensor, predictions_tensor, average='micro')\n",
+ "\n",
+ "    # Print metrics\n",
+ "    print(f'Validation F1 Score: {val_f1}')\n",
+ "    print(f'Validation Accuracy: {val_accuracy}')\n",
+ "    print(f'Validation Precision: {val_precision}')\n",
+ "    print(f'Validation Recall: {val_recall}')\n",
+ "    print(f'Validation AUPRC (micro): {val_auc}')\n",
+ "\n",
+ "    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')\n",
+ "    model_path = f'./esm2_t6_8M_finetuned_cafa5_{timestamp}'\n",
+ "    model.save_pretrained(model_path)\n",
+ "    tokenizer.save_pretrained(model_path)\n",
+ "\n",
+ "    print(f'Model checkpoint saved to {model_path}')\n"
+ ]
+ },
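+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a quick sanity check, the cell below runs `encode_labels` on a couple of made-up annotation lists (a minimal sketch; the GO IDs are illustrative placeholders, not taken from the dataset). Each protein's terms become a multi-hot vector indexed by the term list."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Toy check of the multi-hot encoding used above (assumes the previous cell\n",
+ "# defining encode_labels has been run; the GO IDs here are placeholders).\n",
+ "toy_unique = [\"GO:0003674\", \"GO:0005575\", \"GO:0008150\"]\n",
+ "toy_annotations = [[\"GO:0003674\"], [\"GO:0005575\", \"GO:0008150\"]]\n",
+ "print(encode_labels(toy_annotations, toy_unique))\n",
+ "# Expected output: [[1, 0, 0], [0, 1, 1]]\n"
+ ]
+ },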
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Save the Train/Validation Split Data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pickle\n",
+ "\n",
+ "# After you've created the train and validation splits:\n",
+ "data_splits = {\n",
+ "    \"train_sequences\": train_sequences,\n",
+ "    \"val_sequences\": val_sequences,\n",
+ "    \"train_labels\": train_labels,\n",
+ "    \"val_labels\": val_labels\n",
+ "}\n",
+ "\n",
+ "with open('data_splits.pkl', 'wb') as file:\n",
+ "    pickle.dump(data_splits, file)\n"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Reload the Data Later"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pickle\n",
+ "\n",
+ "# Load the data splits\n",
+ "with open('data_splits.pkl', 'rb') as file:\n",
+ "    data_splits = pickle.load(file)\n",
+ "\n",
+ "train_sequences = data_splits[\"train_sequences\"]\n",
+ "val_sequences = data_splits[\"val_sequences\"]\n",
+ "train_labels = data_splits[\"train_labels\"]\n",
+ "val_labels = data_splits[\"val_labels\"]\n",
+ "\n",
+ "# The rest of the notebook can now proceed as-is,\n",
+ "# with the train and validation sets loaded from the pickle file."
+ ]
+ },
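+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A quick check after reloading (a sketch, assuming the pickle above was written from the `test_size=0.1` split): sequence and label lengths should match, with roughly a 9:1 train/validation ratio."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sanity-check the reloaded splits: lengths should agree and reflect\n",
+ "# the test_size=0.1 split used when the pickle was created.\n",
+ "assert len(train_sequences) == len(train_labels)\n",
+ "assert len(val_sequences) == len(val_labels)\n",
+ "print(len(train_sequences), len(val_sequences))\n"
+ ]
+ },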
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Data Preprocessing"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from torch.utils.data import DataLoader, Dataset\n",
+ "from transformers import AutoTokenizer, EsmForSequenceClassification\n",
+ "from accelerate import Accelerator\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from torchmetrics.classification import MultilabelF1Score\n",
+ "from sklearn.metrics import accuracy_score, precision_score, recall_score, average_precision_score\n",
+ "import datetime\n",
+ "import pandas as pd\n",
+ "\n",
+ "# Load the data\n",
+ "data = pd.read_csv(\"C:/Users/OWO/Desktop/amelie_vscode/cafa5/data/merged_protein_data.tsv\", sep=\"\\t\")\n",
+ "# Use only the first 100 entries for a quick test run\n",
+ "data = data.head(100)\n",
+ "\n",
+ "# Initialize the accelerator\n",
+ "accelerator = Accelerator()\n",
+ "device = accelerator.device\n",
+ "\n",
+ "# Data Preprocessing\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\"facebook/esm2_t6_8M_UR50D\")\n",
+ "MAX_LENGTH = tokenizer.model_max_length\n",
+ "\n",
+ "class ProteinDataset(Dataset):\n",
+ "    def __init__(self, sequences, labels):\n",
+ "        self.sequences = sequences\n",
+ "        self.labels = labels\n",
+ "\n",
+ "    def __len__(self):\n",
+ "        return len(self.sequences)\n",
+ "\n",
+ "    def __getitem__(self, idx):\n",
+ "        sequence = self.sequences[idx]\n",
+ "        label = self.labels[idx]\n",
+ "        encoding = tokenizer(sequence, return_tensors=\"pt\", padding='max_length', truncation=True, max_length=MAX_LENGTH)\n",
+ "        return {\n",
+ "            'input_ids': encoding['input_ids'].flatten(),\n",
+ "            'attention_mask': encoding['attention_mask'].flatten(),\n",
+ "            'labels': torch.tensor(label, dtype=torch.float)\n",
+ "        }\n",
+ "\n",
+ "def encode_labels(go_terms, unique_terms):\n",
+ "    encoded = []\n",
+ "    for terms in go_terms:\n",
+ "        encoding = [1 if term in terms else 0 for term in unique_terms]\n",
+ "        encoded.append(encoding)\n",
+ "    return encoded\n",
+ "\n",
+ "# The train/validation split is loaded from data_splits.pkl above,\n",
+ "# so the splitting and index-reset lines below stay commented out.\n",
+ "# train_sequences, val_sequences, train_labels, val_labels = train_test_split(data['sequence'], data['term'], test_size=0.1)\n",
+ "\n",
+ "# Reset the indices\n",
+ "# train_sequences = train_sequences.reset_index(drop=True)\n",
+ "# val_sequences = val_sequences.reset_index(drop=True)\n",
+ "# train_labels = train_labels.reset_index(drop=True)\n",
+ "# val_labels = val_labels.reset_index(drop=True)\n",
+ "\n",
+ "unique_terms = list(set(term for sublist in data['term'] for term in sublist))\n",
+ "train_labels_encoded = encode_labels(train_labels, unique_terms)\n",
+ "val_labels_encoded = encode_labels(val_labels, unique_terms)\n",
+ "\n",
+ "train_dataset = ProteinDataset(train_sequences, train_labels_encoded)\n",
+ "val_dataset = ProteinDataset(val_sequences, val_labels_encoded)\n",
+ "\n",
+ "train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)\n",
+ "val_loader = DataLoader(val_dataset, batch_size=2)"
+ ]
+ },
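+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Before training, it can help to pull a single batch from the loader and confirm the tensor shapes (a minimal sketch; it assumes the loaders defined in the cell above)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Inspect one batch: input_ids and attention_mask should be\n",
+ "# (batch_size, MAX_LENGTH); labels should be (batch_size, len(unique_terms)).\n",
+ "batch = next(iter(train_loader))\n",
+ "print(batch['input_ids'].shape)\n",
+ "print(batch['attention_mask'].shape)\n",
+ "print(batch['labels'].shape)\n"
+ ]
+ },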
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Fine-tune with LoRA"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from collections import Counter\n",
+ "from peft import get_peft_model, LoraConfig\n",
+ "import datetime\n",
+ "from sklearn.metrics import accuracy_score, precision_score, recall_score, average_precision_score\n",
+ "from torchmetrics.classification import MultilabelF1Score\n",
+ "\n",
+ "# Constants\n",
+ "MODEL_NAME = \"facebook/esm2_t6_8M_UR50D\"  # or replace with a checkpoint path saved by the fine-tuning cell above\n",
+ "BATCH_SIZE = 4\n",
+ "NUM_EPOCHS = 7\n",
+ "LR = 3e-5\n",
+ "\n",
+ "# Initialize model with LoRA\n",
+ "peft_config = LoraConfig(\n",
+ "    task_type=\"SEQ_CLS\",\n",
+ "    inference_mode=False,\n",
+ "    r=16,\n",
+ "    bias=\"none\",\n",
+ "    lora_alpha=16,\n",
+ "    lora_dropout=0.1,\n",
+ "    target_modules=[\"query\", \"key\", \"value\"]\n",
+ ")\n",
+ "\n",
+ "base_model = EsmForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=len(unique_terms), problem_type=\"multi_label_classification\")\n",
+ "model = get_peft_model(base_model, peft_config)\n",
+ "model = model.to(accelerator.device)\n",
+ "\n",
+ "optimizer = torch.optim.AdamW(model.parameters(), lr=LR)\n",
+ "optimizer, model = accelerator.prepare(optimizer, model)\n",
+ "\n",
+ "f1_metric = MultilabelF1Score(num_labels=len(unique_terms), threshold=0.5)\n",
+ "f1_metric = f1_metric.to(device)\n",
+ "\n",
+ "# Compute Class Weights\n",
+ "def compute_class_weights(terms, term_to_id):\n",
+ "    all_terms = [term for terms_list in terms for term in terms_list]\n",
+ "    term_counts = Counter(all_terms)\n",
+ "    total_terms = sum(term_counts.values())\n",
+ "    class_weights = {term: total_terms / count for term, count in term_counts.items()}\n",
+ "    weights = torch.tensor([class_weights[term] for term in term_to_id.keys()], dtype=torch.float)\n",
+ "    normalized_weights = weights / weights.sum()\n",
+ "    return normalized_weights\n",
+ "\n",
+ "term_to_id = {term: idx for idx, term in enumerate(unique_terms)}\n",
+ "all_terms_combined = train_labels.tolist() + val_labels.tolist()\n",
+ "weights = compute_class_weights(all_terms_combined, term_to_id)\n",
+ "weights = weights.to(accelerator.device)\n",
+ "loss_criterion = torch.nn.BCEWithLogitsLoss(pos_weight=weights)\n",
+ "\n",
+ "# Training loop\n",
+ "for epoch in range(NUM_EPOCHS):\n",
+ "    # Training Phase\n",
+ "    model.train()\n",
+ "    total_train_loss = 0\n",
+ "    for batch in train_loader:\n",
+ "        optimizer.zero_grad()\n",
+ "        input_ids = batch['input_ids'].to(accelerator.device)\n",
+ "        attention_mask = batch['attention_mask'].to(accelerator.device)\n",
+ "        labels = batch['labels'].to(accelerator.device)\n",
+ "\n",
+ "        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)\n",
+ "        logits = outputs.logits\n",
+ "        loss = loss_criterion(logits, labels)\n",
+ "        accelerator.backward(loss)\n",
+ "        optimizer.step()\n",
+ "\n",
+ "        total_train_loss += loss.item()\n",
+ "\n",
+ "    avg_train_loss = total_train_loss / len(train_loader)\n",
+ "\n",
+ "    # Validation Phase\n",
+ "    model.eval()\n",
+ "    total_val_loss = 0\n",
+ "    predictions = []\n",
+ "    true_labels = []\n",
+ "    with torch.no_grad():\n",
+ "        for batch in val_loader:\n",
+ "            input_ids = batch['input_ids'].to(accelerator.device)\n",
+ "            attention_mask = batch['attention_mask'].to(accelerator.device)\n",
+ "            labels = batch['labels'].to(accelerator.device)\n",
+ "\n",
+ "            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)\n",
+ "            logits = outputs.logits\n",
+ "            loss = loss_criterion(logits, labels)\n",
+ "\n",
+ "            total_val_loss += loss.item()\n",
+ "            predictions.append(torch.sigmoid(logits).detach())\n",
+ "            true_labels.append(labels.detach())\n",
+ "\n",
+ "    avg_val_loss = total_val_loss / len(val_loader)\n",
+ "\n",
+ "    predictions_tensor = torch.cat(predictions, dim=0).cpu().numpy()\n",
+ "    true_labels_tensor = torch.cat(true_labels, dim=0).cpu().numpy()\n",
+ "\n",
+ "    threshold = 0.5\n",
+ "    predictions_bin = (predictions_tensor > threshold).astype(int)\n",
+ "\n",
+ "    val_f1 = f1_metric(torch.tensor(predictions_tensor).to(device), torch.tensor(true_labels_tensor).to(device))\n",
+ "    val_accuracy = accuracy_score(true_labels_tensor.flatten(), predictions_bin.flatten())\n",
+ "    val_precision = precision_score(true_labels_tensor.flatten(), predictions_bin.flatten(), average='micro')\n",
+ "    val_recall = recall_score(true_labels_tensor.flatten(), predictions_bin.flatten(), average='micro')\n",
+ "    val_auc = average_precision_score(true_labels_tensor, predictions_tensor, average='micro')\n",
+ "\n",
+ "    print(f\"Epoch {epoch + 1}/{NUM_EPOCHS} - Training Loss: {avg_train_loss:.4f} - Validation Loss: {avg_val_loss:.4f}\")\n",
+ "    print(f\"Validation Metrics - Accuracy: {val_accuracy:.4f} - Precision (Micro): {val_precision:.4f} - Recall (Micro): {val_recall:.4f} - AUPRC (Micro): {val_auc:.4f} - F1 Score: {val_f1:.4f}\")\n",
+ "\n",
+ "    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')\n",
+ "    # Save model and tokenizer. Note that Accelerator also has a save method for models.\n",
+ "    model_path = f'./esm2_t6_8M_cafa5_lora_{timestamp}'\n",
+ "    model.save_pretrained(model_path)\n",
+ "    tokenizer.save_pretrained(model_path)\n",
+ "    model.base_model.save_pretrained(model_path)\n",
+ "    print(f'Model checkpoint saved to {model_path}')\n"
+ ]
+ },
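+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To use a saved LoRA checkpoint later, the adapter can be loaded back onto a fresh base model with `PeftModel.from_pretrained`. The cell below is a minimal inference sketch; the checkpoint path is a placeholder for one of the timestamped directories written by the training loop, and the 0.5 threshold mirrors the one used during validation."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from peft import PeftModel\n",
+ "\n",
+ "# Placeholder: point this at a checkpoint directory saved above\n",
+ "checkpoint_dir = './esm2_t6_8M_cafa5_lora_<timestamp>'\n",
+ "\n",
+ "# Rebuild the base model and attach the trained LoRA adapter\n",
+ "base = EsmForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=len(unique_terms), problem_type=\"multi_label_classification\")\n",
+ "inference_model = PeftModel.from_pretrained(base, checkpoint_dir).to(device)\n",
+ "inference_model.eval()\n",
+ "\n",
+ "# Predict GO terms for one validation sequence\n",
+ "with torch.no_grad():\n",
+ "    enc = tokenizer(val_sequences[0], return_tensors='pt', truncation=True, max_length=MAX_LENGTH).to(device)\n",
+ "    probs = torch.sigmoid(inference_model(**enc).logits)[0]\n",
+ "\n",
+ "predicted_terms = [unique_terms[i] for i, p in enumerate(probs) if p > 0.5]\n",
+ "print(predicted_terms[:10])\n"
+ ]
+ },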
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "cafa_5",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.17"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }