AmelieSchreiber
commited on
Commit
•
ce9ca82
1
Parent(s):
907da65
Upload cafa_5_finetune_v2.ipynb
Browse files- cafa_5_finetune_v2.ipynb +465 -0
cafa_5_finetune_v2.ipynb
ADDED
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"attachments": {},
|
5 |
+
"cell_type": "markdown",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"# Finetuneing ESM-2 Models for CAFA-5"
|
9 |
+
]
|
10 |
+
},
|
11 |
+
{
|
12 |
+
"attachments": {},
|
13 |
+
"cell_type": "markdown",
|
14 |
+
"metadata": {},
|
15 |
+
"source": [
|
16 |
+
"## Finetune an ESM-2 Model"
|
17 |
+
]
|
18 |
+
},
|
19 |
+
{
|
20 |
+
"cell_type": "code",
|
21 |
+
"execution_count": null,
|
22 |
+
"metadata": {},
|
23 |
+
"outputs": [],
|
24 |
+
"source": [
|
25 |
+
"import torch\n",
|
26 |
+
"from torch.utils.data import DataLoader, Dataset\n",
|
27 |
+
"from transformers import AutoTokenizer, EsmForSequenceClassification\n",
|
28 |
+
"from accelerate import Accelerator\n",
|
29 |
+
"from sklearn.model_selection import train_test_split\n",
|
30 |
+
"from torchmetrics.classification import MultilabelF1Score\n",
|
31 |
+
"from sklearn.metrics import accuracy_score, precision_score, recall_score, average_precision_score\n",
|
32 |
+
"import datetime\n",
|
33 |
+
"import pandas as pd\n",
|
34 |
+
"\n",
|
35 |
+
"# Load the data\n",
|
36 |
+
"data = pd.read_csv(\"C:/Users/OWO/Desktop/amelie_vscode/cafa5/data/merged_protein_data.tsv\", sep=\"\\t\")\n",
|
37 |
+
"# Use only the first 100 entries\n",
|
38 |
+
"# data = data.head(100)\n",
|
39 |
+
"\n",
|
40 |
+
"# Initialize the accelerator\n",
|
41 |
+
"accelerator = Accelerator()\n",
|
42 |
+
"device = accelerator.device\n",
|
43 |
+
"\n",
|
44 |
+
"# Data Preprocessing\n",
|
45 |
+
"tokenizer = AutoTokenizer.from_pretrained(\"facebook/esm2_t6_8M_UR50D\")\n",
|
46 |
+
"MAX_LENGTH = tokenizer.model_max_length\n",
|
47 |
+
"NUM_EPOCHS = 3\n",
|
48 |
+
"LR = 5e-4\n",
|
49 |
+
"BATCH_SIZE = 2\n",
|
50 |
+
"\n",
|
51 |
+
"class ProteinDataset(Dataset):\n",
|
52 |
+
" def __init__(self, sequences, labels):\n",
|
53 |
+
" self.sequences = sequences\n",
|
54 |
+
" self.labels = labels\n",
|
55 |
+
"\n",
|
56 |
+
" def __len__(self):\n",
|
57 |
+
" return len(self.sequences)\n",
|
58 |
+
"\n",
|
59 |
+
" def __getitem__(self, idx):\n",
|
60 |
+
" sequence = self.sequences[idx]\n",
|
61 |
+
" label = self.labels[idx]\n",
|
62 |
+
" encoding = tokenizer(sequence, return_tensors=\"pt\", padding='max_length', truncation=True, max_length=MAX_LENGTH)\n",
|
63 |
+
" return {\n",
|
64 |
+
" 'input_ids': encoding['input_ids'].flatten(),\n",
|
65 |
+
" 'attention_mask': encoding['attention_mask'].flatten(),\n",
|
66 |
+
" 'labels': torch.tensor(label, dtype=torch.float)\n",
|
67 |
+
" }\n",
|
68 |
+
"\n",
|
69 |
+
"def encode_labels(go_terms, unique_terms):\n",
|
70 |
+
" encoded = []\n",
|
71 |
+
" for terms in go_terms:\n",
|
72 |
+
" encoding = [1 if term in terms else 0 for term in unique_terms]\n",
|
73 |
+
" encoded.append(encoding)\n",
|
74 |
+
" return encoded\n",
|
75 |
+
"\n",
|
76 |
+
"train_sequences, val_sequences, train_labels, val_labels = train_test_split(data['sequence'], data['term'], test_size=0.1)\n",
|
77 |
+
"\n",
|
78 |
+
"# Reset the indices\n",
|
79 |
+
"train_sequences = train_sequences.reset_index(drop=True)\n",
|
80 |
+
"val_sequences = val_sequences.reset_index(drop=True)\n",
|
81 |
+
"train_labels = train_labels.reset_index(drop=True)\n",
|
82 |
+
"val_labels = val_labels.reset_index(drop=True)\n",
|
83 |
+
"\n",
|
84 |
+
"unique_terms = list(set(term for sublist in data['term'] for term in sublist))\n",
|
85 |
+
"train_labels_encoded = encode_labels(train_labels, unique_terms)\n",
|
86 |
+
"val_labels_encoded = encode_labels(val_labels, unique_terms)\n",
|
87 |
+
"\n",
|
88 |
+
"train_dataset = ProteinDataset(train_sequences, train_labels_encoded)\n",
|
89 |
+
"val_dataset = ProteinDataset(val_sequences, val_labels_encoded)\n",
|
90 |
+
"\n",
|
91 |
+
"train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
|
92 |
+
"val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)\n",
|
93 |
+
"\n",
|
94 |
+
"# Model Training\n",
|
95 |
+
"model = EsmForSequenceClassification.from_pretrained(\"facebook/esm2_t6_8M_UR50D\", num_labels=len(unique_terms), problem_type=\"multi_label_classification\")\n",
|
96 |
+
"model = model.to(device)\n",
|
97 |
+
"model.train()\n",
|
98 |
+
"\n",
|
99 |
+
"optimizer = torch.optim.AdamW(model.parameters(), lr=LR)\n",
|
100 |
+
"optimizer, model = accelerator.prepare(optimizer, model)\n",
|
101 |
+
"\n",
|
102 |
+
"# Initialize metrics\n",
|
103 |
+
"f1_metric = MultilabelF1Score(num_labels=len(unique_terms), threshold=0.5)\n",
|
104 |
+
"f1_metric = f1_metric.to(device)\n",
|
105 |
+
"\n",
|
106 |
+
"num_epochs = NUM_EPOCHS\n",
|
107 |
+
"\n",
|
108 |
+
"for epoch in range(num_epochs):\n",
|
109 |
+
" total_loss = 0\n",
|
110 |
+
" for batch in train_loader:\n",
|
111 |
+
" optimizer.zero_grad()\n",
|
112 |
+
" input_ids = batch['input_ids'].to(device)\n",
|
113 |
+
" attention_mask = batch['attention_mask'].to(device)\n",
|
114 |
+
" labels = batch['labels'].to(device)\n",
|
115 |
+
"\n",
|
116 |
+
" outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)\n",
|
117 |
+
" loss = outputs.loss\n",
|
118 |
+
" accelerator.backward(loss)\n",
|
119 |
+
" optimizer.step()\n",
|
120 |
+
"\n",
|
121 |
+
" total_loss += loss.item()\n",
|
122 |
+
"\n",
|
123 |
+
" print(f'Epoch {epoch + 1}/{num_epochs}, Training loss: {total_loss/len(train_loader)}')\n",
|
124 |
+
"\n",
|
125 |
+
" model.eval()\n",
|
126 |
+
" predictions = []\n",
|
127 |
+
" true_labels_list = []\n",
|
128 |
+
" with torch.no_grad():\n",
|
129 |
+
" for batch in val_loader:\n",
|
130 |
+
" input_ids = batch['input_ids'].to(device)\n",
|
131 |
+
" attention_mask = batch['attention_mask'].to(device)\n",
|
132 |
+
" labels = batch['labels'].to(device)\n",
|
133 |
+
"\n",
|
134 |
+
" outputs = model(input_ids=input_ids, attention_mask=attention_mask)\n",
|
135 |
+
" logits = outputs.logits\n",
|
136 |
+
" predictions.append(torch.sigmoid(logits))\n",
|
137 |
+
" true_labels_list.append(labels)\n",
|
138 |
+
"\n",
|
139 |
+
" predictions_tensor = torch.cat(predictions, dim=0).cpu().numpy()\n",
|
140 |
+
" true_labels_tensor = torch.cat(true_labels_list, dim=0).cpu().numpy()\n",
|
141 |
+
"\n",
|
142 |
+
" threshold = 0.5\n",
|
143 |
+
" predictions_bin = (predictions_tensor > threshold).astype(int)\n",
|
144 |
+
"\n",
|
145 |
+
" # Compute metrics\n",
|
146 |
+
" val_f1 = f1_metric(torch.tensor(predictions_tensor).to(device), torch.tensor(true_labels_tensor).to(device))\n",
|
147 |
+
" val_accuracy = accuracy_score(true_labels_tensor.flatten(), predictions_bin.flatten())\n",
|
148 |
+
" val_precision = precision_score(true_labels_tensor.flatten(), predictions_bin.flatten(), average='micro')\n",
|
149 |
+
" val_recall = recall_score(true_labels_tensor.flatten(), predictions_bin.flatten(), average='micro')\n",
|
150 |
+
" val_auc = average_precision_score(true_labels_tensor, predictions_tensor, average='micro')\n",
|
151 |
+
"\n",
|
152 |
+
" # Print metrics\n",
|
153 |
+
" print(f'Validation F1 Score: {val_f1}')\n",
|
154 |
+
" print(f'Validation Accuracy: {val_accuracy}')\n",
|
155 |
+
" print(f'Validation Precision: {val_precision}')\n",
|
156 |
+
" print(f'Validation Recall: {val_recall}')\n",
|
157 |
+
" print(f'Validation AUC: {val_auc}')\n",
|
158 |
+
"\n",
|
159 |
+
" timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')\n",
|
160 |
+
" model_path = f'./esm2_t6_8M_finetuned_cafa5_{timestamp}'\n",
|
161 |
+
" model.save_pretrained(model_path)\n",
|
162 |
+
" tokenizer.save_pretrained(model_path)\n",
|
163 |
+
"\n",
|
164 |
+
" print(f'Model checkpoint saved to {model_path}')\n"
|
165 |
+
]
|
166 |
+
},
|
167 |
+
{
|
168 |
+
"attachments": {},
|
169 |
+
"cell_type": "markdown",
|
170 |
+
"metadata": {},
|
171 |
+
"source": [
|
172 |
+
"## Save the Train/Validation Split Data"
|
173 |
+
]
|
174 |
+
},
|
175 |
+
{
|
176 |
+
"cell_type": "code",
|
177 |
+
"execution_count": null,
|
178 |
+
"metadata": {},
|
179 |
+
"outputs": [],
|
180 |
+
"source": [
|
181 |
+
"import pickle\n",
|
182 |
+
"\n",
|
183 |
+
"# After you've created the train and validation splits:\n",
|
184 |
+
"data_splits = {\n",
|
185 |
+
" \"train_sequences\": train_sequences,\n",
|
186 |
+
" \"val_sequences\": val_sequences,\n",
|
187 |
+
" \"train_labels\": train_labels,\n",
|
188 |
+
" \"val_labels\": val_labels\n",
|
189 |
+
"}\n",
|
190 |
+
"\n",
|
191 |
+
"with open('data_splits.pkl', 'wb') as file:\n",
|
192 |
+
" pickle.dump(data_splits, file)\n"
|
193 |
+
]
|
194 |
+
},
|
195 |
+
{
|
196 |
+
"attachments": {},
|
197 |
+
"cell_type": "markdown",
|
198 |
+
"metadata": {},
|
199 |
+
"source": [
|
200 |
+
"## Reload the Data Later"
|
201 |
+
]
|
202 |
+
},
|
203 |
+
{
|
204 |
+
"cell_type": "code",
|
205 |
+
"execution_count": null,
|
206 |
+
"metadata": {},
|
207 |
+
"outputs": [],
|
208 |
+
"source": [
|
209 |
+
"import pickle\n",
|
210 |
+
"\n",
|
211 |
+
"# Load the data splits\n",
|
212 |
+
"with open('data_splits.pkl', 'rb') as file:\n",
|
213 |
+
" data_splits = pickle.load(file)\n",
|
214 |
+
"\n",
|
215 |
+
"train_sequences = data_splits[\"train_sequences\"]\n",
|
216 |
+
"val_sequences = data_splits[\"val_sequences\"]\n",
|
217 |
+
"train_labels = data_splits[\"train_labels\"]\n",
|
218 |
+
"val_labels = data_splits[\"val_labels\"]\n",
|
219 |
+
"\n",
|
220 |
+
"# Now, the rest of your code can proceed as it is, \n",
|
221 |
+
"# with the train and validation sets loaded from the pickle file."
|
222 |
+
]
|
223 |
+
},
|
224 |
+
{
|
225 |
+
"attachments": {},
|
226 |
+
"cell_type": "markdown",
|
227 |
+
"metadata": {},
|
228 |
+
"source": [
|
229 |
+
"## Data Preprocessing"
|
230 |
+
]
|
231 |
+
},
|
232 |
+
{
|
233 |
+
"cell_type": "code",
|
234 |
+
"execution_count": null,
|
235 |
+
"metadata": {},
|
236 |
+
"outputs": [],
|
237 |
+
"source": [
|
238 |
+
"import torch\n",
|
239 |
+
"from torch.utils.data import DataLoader, Dataset\n",
|
240 |
+
"from transformers import AutoTokenizer, EsmForSequenceClassification\n",
|
241 |
+
"from accelerate import Accelerator\n",
|
242 |
+
"from sklearn.model_selection import train_test_split\n",
|
243 |
+
"from torchmetrics.classification import MultilabelF1Score\n",
|
244 |
+
"from sklearn.metrics import accuracy_score, precision_score, recall_score, average_precision_score\n",
|
245 |
+
"import datetime\n",
|
246 |
+
"import pandas as pd\n",
|
247 |
+
"\n",
|
248 |
+
"# Load the data\n",
|
249 |
+
"data = pd.read_csv(\"C:/Users/OWO/Desktop/amelie_vscode/cafa5/data/merged_protein_data.tsv\", sep=\"\\t\")\n",
|
250 |
+
"# Use only the first 100 entries\n",
|
251 |
+
"data = data.head(100)\n",
|
252 |
+
"\n",
|
253 |
+
"# Initialize the accelerator\n",
|
254 |
+
"accelerator = Accelerator()\n",
|
255 |
+
"device = accelerator.device\n",
|
256 |
+
"\n",
|
257 |
+
"# Data Preprocessing\n",
|
258 |
+
"tokenizer = AutoTokenizer.from_pretrained(\"facebook/esm2_t6_8M_UR50D\")\n",
|
259 |
+
"MAX_LENGTH = tokenizer.model_max_length\n",
|
260 |
+
"\n",
|
261 |
+
"class ProteinDataset(Dataset):\n",
|
262 |
+
" def __init__(self, sequences, labels):\n",
|
263 |
+
" self.sequences = sequences\n",
|
264 |
+
" self.labels = labels\n",
|
265 |
+
"\n",
|
266 |
+
" def __len__(self):\n",
|
267 |
+
" return len(self.sequences)\n",
|
268 |
+
"\n",
|
269 |
+
" def __getitem__(self, idx):\n",
|
270 |
+
" sequence = self.sequences[idx]\n",
|
271 |
+
" label = self.labels[idx]\n",
|
272 |
+
" encoding = tokenizer(sequence, return_tensors=\"pt\", padding='max_length', truncation=True, max_length=MAX_LENGTH)\n",
|
273 |
+
" return {\n",
|
274 |
+
" 'input_ids': encoding['input_ids'].flatten(),\n",
|
275 |
+
" 'attention_mask': encoding['attention_mask'].flatten(),\n",
|
276 |
+
" 'labels': torch.tensor(label, dtype=torch.float)\n",
|
277 |
+
" }\n",
|
278 |
+
"\n",
|
279 |
+
"def encode_labels(go_terms, unique_terms):\n",
|
280 |
+
" encoded = []\n",
|
281 |
+
" for terms in go_terms:\n",
|
282 |
+
" encoding = [1 if term in terms else 0 for term in unique_terms]\n",
|
283 |
+
" encoded.append(encoding)\n",
|
284 |
+
" return encoded\n",
|
285 |
+
"\n",
|
286 |
+
"# train_sequences, val_sequences, train_labels, val_labels = train_test_split(data['sequence'], data['term'], test_size=0.1)\n",
|
287 |
+
"\n",
|
288 |
+
"# Reset the indices\n",
|
289 |
+
"# train_sequences = train_sequences.reset_index(drop=True)\n",
|
290 |
+
"# val_sequences = val_sequences.reset_index(drop=True)\n",
|
291 |
+
"# train_labels = train_labels.reset_index(drop=True)\n",
|
292 |
+
"# val_labels = val_labels.reset_index(drop=True)\n",
|
293 |
+
"\n",
|
294 |
+
"unique_terms = list(set(term for sublist in data['term'] for term in sublist))\n",
|
295 |
+
"train_labels_encoded = encode_labels(train_labels, unique_terms)\n",
|
296 |
+
"val_labels_encoded = encode_labels(val_labels, unique_terms)\n",
|
297 |
+
"\n",
|
298 |
+
"train_dataset = ProteinDataset(train_sequences, train_labels_encoded)\n",
|
299 |
+
"val_dataset = ProteinDataset(val_sequences, val_labels_encoded)\n",
|
300 |
+
"\n",
|
301 |
+
"train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)\n",
|
302 |
+
"val_loader = DataLoader(val_dataset, batch_size=2)"
|
303 |
+
]
|
304 |
+
},
|
305 |
+
{
|
306 |
+
"attachments": {},
|
307 |
+
"cell_type": "markdown",
|
308 |
+
"metadata": {},
|
309 |
+
"source": [
|
310 |
+
"## Fine-tune with LoRA"
|
311 |
+
]
|
312 |
+
},
|
313 |
+
{
|
314 |
+
"cell_type": "code",
|
315 |
+
"execution_count": null,
|
316 |
+
"metadata": {},
|
317 |
+
"outputs": [],
|
318 |
+
"source": [
|
319 |
+
"from collections import Counter\n",
|
320 |
+
"from peft import get_peft_config, get_peft_model, LoraConfig\n",
|
321 |
+
"import datetime\n",
|
322 |
+
"from sklearn.metrics import accuracy_score, precision_score, recall_score, hamming_loss, average_precision_score\n",
|
323 |
+
"from torchmetrics.classification import MultilabelF1Score\n",
|
324 |
+
"\n",
|
325 |
+
"# Constants\n",
|
326 |
+
"MODEL_NAME = \"facebook/esm2_t6_8M_UR50D\" # Replace with your trained model above\n",
|
327 |
+
"BATCH_SIZE = 4\n",
|
328 |
+
"NUM_EPOCHS = 7\n",
|
329 |
+
"LR = 3e-5\n",
|
330 |
+
"\n",
|
331 |
+
"# Initialize model with LoRA\n",
|
332 |
+
"peft_config = LoraConfig(\n",
|
333 |
+
" task_type=\"SEQ_CLS\", \n",
|
334 |
+
" inference_mode=False, \n",
|
335 |
+
" r=16, \n",
|
336 |
+
" bias=\"none\",\n",
|
337 |
+
" lora_alpha=16, \n",
|
338 |
+
" lora_dropout=0.1, \n",
|
339 |
+
" target_modules=[\"query\", \"key\", \"value\"]\n",
|
340 |
+
")\n",
|
341 |
+
"\n",
|
342 |
+
"base_model = EsmForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=len(unique_terms), problem_type=\"multi_label_classification\")\n",
|
343 |
+
"model = get_peft_model(base_model, peft_config)\n",
|
344 |
+
"model = model.to(accelerator.device)\n",
|
345 |
+
"\n",
|
346 |
+
"optimizer = torch.optim.AdamW(model.parameters(), lr=LR)\n",
|
347 |
+
"optimizer, model = accelerator.prepare(optimizer, model)\n",
|
348 |
+
"\n",
|
349 |
+
"f1_metric = MultilabelF1Score(num_labels=len(unique_terms), threshold=0.5)\n",
|
350 |
+
"f1_metric = f1_metric.to(device)\n",
|
351 |
+
"\n",
|
352 |
+
"# Compute Class Weights\n",
|
353 |
+
"def compute_class_weights(terms, term_to_id):\n",
|
354 |
+
" all_terms = [term for terms_list in terms for term in terms_list]\n",
|
355 |
+
" term_counts = Counter(all_terms)\n",
|
356 |
+
" total_terms = sum(term_counts.values())\n",
|
357 |
+
" class_weights = {term: total_terms / count for term, count in term_counts.items()}\n",
|
358 |
+
" weights = torch.tensor([class_weights[term] for term in term_to_id.keys()], dtype=torch.float)\n",
|
359 |
+
" normalized_weights = weights / weights.sum()\n",
|
360 |
+
" return normalized_weights\n",
|
361 |
+
"\n",
|
362 |
+
"term_to_id = {term: idx for idx, term in enumerate(unique_terms)}\n",
|
363 |
+
"all_terms_combined = train_labels.tolist() + val_labels.tolist()\n",
|
364 |
+
"weights = compute_class_weights(all_terms_combined, term_to_id)\n",
|
365 |
+
"weights = weights.to(accelerator.device)\n",
|
366 |
+
"loss_criterion = torch.nn.BCEWithLogitsLoss(pos_weight=weights)\n",
|
367 |
+
"\n",
|
368 |
+
"# Training loop\n",
|
369 |
+
"for epoch in range(NUM_EPOCHS):\n",
|
370 |
+
" # Training Phase\n",
|
371 |
+
" model.train()\n",
|
372 |
+
" total_train_loss = 0\n",
|
373 |
+
" for batch in train_loader:\n",
|
374 |
+
" optimizer.zero_grad()\n",
|
375 |
+
" input_ids = batch['input_ids'].to(accelerator.device)\n",
|
376 |
+
" attention_mask = batch['attention_mask'].to(accelerator.device)\n",
|
377 |
+
" labels = batch['labels'].to(accelerator.device)\n",
|
378 |
+
"\n",
|
379 |
+
" outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)\n",
|
380 |
+
" logits = outputs.logits\n",
|
381 |
+
" loss = loss_criterion(logits, labels)\n",
|
382 |
+
" accelerator.backward(loss)\n",
|
383 |
+
" optimizer.step()\n",
|
384 |
+
"\n",
|
385 |
+
" total_train_loss += loss.item()\n",
|
386 |
+
"\n",
|
387 |
+
" avg_train_loss = total_train_loss / len(train_loader)\n",
|
388 |
+
"\n",
|
389 |
+
" # Validation Phase\n",
|
390 |
+
" model.eval()\n",
|
391 |
+
" total_val_loss = 0\n",
|
392 |
+
" predictions = []\n",
|
393 |
+
" true_labels = []\n",
|
394 |
+
" with torch.no_grad():\n",
|
395 |
+
" for batch in val_loader:\n",
|
396 |
+
" input_ids = batch['input_ids'].to(accelerator.device)\n",
|
397 |
+
" attention_mask = batch['attention_mask'].to(accelerator.device)\n",
|
398 |
+
" labels = batch['labels'].to(accelerator.device)\n",
|
399 |
+
"\n",
|
400 |
+
" outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)\n",
|
401 |
+
" logits = outputs.logits\n",
|
402 |
+
" loss = loss_criterion(logits, labels)\n",
|
403 |
+
"\n",
|
404 |
+
" total_val_loss += loss.item()\n",
|
405 |
+
" predictions.append(torch.sigmoid(logits).detach())\n",
|
406 |
+
" true_labels.append(labels.detach())\n",
|
407 |
+
"\n",
|
408 |
+
"\n",
|
409 |
+
" avg_val_loss = total_val_loss / len(val_loader)\n",
|
410 |
+
" \n",
|
411 |
+
" predictions_tensor = torch.cat(predictions, dim=0).cpu().numpy()\n",
|
412 |
+
" true_labels_tensor = torch.cat(true_labels, dim=0).cpu().numpy()\n",
|
413 |
+
"\n",
|
414 |
+
" threshold = 0.5\n",
|
415 |
+
" predictions_bin = (predictions_tensor > threshold).astype(int)\n",
|
416 |
+
"\n",
|
417 |
+
" val_f1 = f1_metric(torch.tensor(predictions_tensor).to(device), torch.tensor(true_labels_tensor).to(device))\n",
|
418 |
+
" val_accuracy = accuracy_score(true_labels_tensor.flatten(), predictions_bin.flatten())\n",
|
419 |
+
" val_precision = precision_score(true_labels_tensor.flatten(), predictions_bin.flatten(), average='micro')\n",
|
420 |
+
" val_recall = recall_score(true_labels_tensor.flatten(), predictions_bin.flatten(), average='micro')\n",
|
421 |
+
" val_auc = average_precision_score(true_labels_tensor, predictions_tensor, average='micro')\n",
|
422 |
+
"\n",
|
423 |
+
" print(f\"Epoch {epoch + 1}/{NUM_EPOCHS} - Training Loss: {avg_train_loss:.4f} - Validation Loss: {avg_val_loss:.4f}\")\n",
|
424 |
+
" print(f\"Validation Metrics - Accuracy: {val_accuracy:.4f} - Precision (Micro): {val_precision:.4f} - Recall (Micro): {val_recall:.4f} - AUC: {val_auc:.4f} - F1 Score: {val_f1:.4f}\")\n",
|
425 |
+
"\n",
|
426 |
+
" timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')\n",
|
427 |
+
" # Save model and tokenizer. Note that Accelerator has a save method for models.\n",
|
428 |
+
" model_path = f'./esm2_t6_8M_cafa5_lora_{timestamp}'\n",
|
429 |
+
" model.save_pretrained(model_path)\n",
|
430 |
+
" tokenizer.save_pretrained(model_path)\n",
|
431 |
+
" model.base_model.save_pretrained(model_path)\n",
|
432 |
+
" print(f'Model checkpoint saved to {model_path}')\n"
|
433 |
+
]
|
434 |
+
},
|
435 |
+
{
|
436 |
+
"cell_type": "code",
|
437 |
+
"execution_count": null,
|
438 |
+
"metadata": {},
|
439 |
+
"outputs": [],
|
440 |
+
"source": []
|
441 |
+
}
|
442 |
+
],
|
443 |
+
"metadata": {
|
444 |
+
"kernelspec": {
|
445 |
+
"display_name": "cafa_5",
|
446 |
+
"language": "python",
|
447 |
+
"name": "python3"
|
448 |
+
},
|
449 |
+
"language_info": {
|
450 |
+
"codemirror_mode": {
|
451 |
+
"name": "ipython",
|
452 |
+
"version": 3
|
453 |
+
},
|
454 |
+
"file_extension": ".py",
|
455 |
+
"mimetype": "text/x-python",
|
456 |
+
"name": "python",
|
457 |
+
"nbconvert_exporter": "python",
|
458 |
+
"pygments_lexer": "ipython3",
|
459 |
+
"version": "3.9.17"
|
460 |
+
},
|
461 |
+
"orig_nbformat": 4
|
462 |
+
},
|
463 |
+
"nbformat": 4,
|
464 |
+
"nbformat_minor": 2
|
465 |
+
}
|