aoxo committed
Commit 00f3169
1 Parent(s): 4f0c54b

Training script

Files changed (1)
  1. histology_vit.ipynb +451 -0
histology_vit.ipynb ADDED
@@ -0,0 +1,451 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": 1,
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "Data loaded successfully!\n",
+       "Number of classes: 32\n",
+       "Class names: ['Adrenocortical_carcinoma', 'Bladder_Urothelial_Carcinoma', 'Brain_Lower_Grade_Glioma', 'Breast_invasive_carcinoma', 'Cervical_squamous_cell_carcinoma_and_endocervical_adenocarcinoma', 'Cholangiocarcinoma', 'Colon_adenocarcinoma', 'Esophageal_carcinoma', 'Glioblastoma_multiforme', 'Head_and_Neck_squamous_cell_carcinoma', 'Kidney_Chromophobe', 'Kidney_renal_clear_cell_carcinoma', 'Kidney_renal_papillary_cell_carcinoma', 'Liver_hepatocellular_carcinoma', 'Lung_adenocarcinoma', 'Lung_squamous_cell_carcinoma', 'Lymphoid_Neoplasm_Diffuse_Large_B-cell_Lymphoma', 'Mesothelioma', 'Ovarian_serous_cystadenocarcinoma', 'Pancreatic_adenocarcinoma', 'Pheochromocytoma_and_Paraganglioma', 'Prostate_adenocarcinoma', 'Rectum_adenocarcinoma', 'Sarcoma', 'Skin_Cutaneous_Melanoma', 'Stomach_adenocarcinoma', 'Testicular_Germ_Cell_Tumors', 'Thymoma', 'Thyroid_carcinoma', 'Uterine_Carcinosarcoma', 'Uterine_Corpus_Endometrial_Carcinoma', 'Uveal_Melanoma']\n",
+       "ViTForCancerClassification(\n",
+       "  (vit): VisionTransformer(\n",
+       "    (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))\n",
+       "    (encoder): Encoder(\n",
+       "      (dropout): Dropout(p=0.0, inplace=False)\n",
+       "      (layers): Sequential(\n",
+       "        (encoder_layer_0): EncoderBlock(\n",
+       "          (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+       "          (self_attention): MultiheadAttention(\n",
+       "            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n",
+       "          )\n",
+       "          (dropout): Dropout(p=0.0, inplace=False)\n",
+       "          (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+       "          (mlp): MLPBlock(\n",
+       "            (0): Linear(in_features=768, out_features=3072, bias=True)\n",
+       "            (1): GELU(approximate='none')\n",
+       "            (2): Dropout(p=0.0, inplace=False)\n",
+       "            (3): Linear(in_features=3072, out_features=768, bias=True)\n",
+       "            (4): Dropout(p=0.0, inplace=False)\n",
+       "          )\n",
+       "        )\n",
+       "        (encoder_layer_1): EncoderBlock(\n",
+       "          (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+       "          (self_attention): MultiheadAttention(\n",
+       "            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n",
+       "          )\n",
+       "          (dropout): Dropout(p=0.0, inplace=False)\n",
+       "          (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+       "          (mlp): MLPBlock(\n",
+       "            (0): Linear(in_features=768, out_features=3072, bias=True)\n",
+       "            (1): GELU(approximate='none')\n",
+       "            (2): Dropout(p=0.0, inplace=False)\n",
+       "            (3): Linear(in_features=3072, out_features=768, bias=True)\n",
+       "            (4): Dropout(p=0.0, inplace=False)\n",
+       "          )\n",
+       "        )\n",
+       "        (encoder_layer_2): EncoderBlock(\n",
+       "          (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+       "          (self_attention): MultiheadAttention(\n",
+       "            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n",
+       "          )\n",
+       "          (dropout): Dropout(p=0.0, inplace=False)\n",
+       "          (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+       "          (mlp): MLPBlock(\n",
+       "            (0): Linear(in_features=768, out_features=3072, bias=True)\n",
+       "            (1): GELU(approximate='none')\n",
+       "            (2): Dropout(p=0.0, inplace=False)\n",
+       "            (3): Linear(in_features=3072, out_features=768, bias=True)\n",
+       "            (4): Dropout(p=0.0, inplace=False)\n",
+       "          )\n",
+       "        )\n",
+       "        (encoder_layer_3): EncoderBlock(\n",
+       "          (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+       "          (self_attention): MultiheadAttention(\n",
+       "            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n",
+       "          )\n",
+       "          (dropout): Dropout(p=0.0, inplace=False)\n",
+       "          (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+       "          (mlp): MLPBlock(\n",
+       "            (0): Linear(in_features=768, out_features=3072, bias=True)\n",
+       "            (1): GELU(approximate='none')\n",
+       "            (2): Dropout(p=0.0, inplace=False)\n",
+       "            (3): Linear(in_features=3072, out_features=768, bias=True)\n",
+       "            (4): Dropout(p=0.0, inplace=False)\n",
+       "          )\n",
+       "        )\n",
+       "        (encoder_layer_4): EncoderBlock(\n",
+       "          (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+       "          (self_attention): MultiheadAttention(\n",
+       "            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n",
+       "          )\n",
+       "          (dropout): Dropout(p=0.0, inplace=False)\n",
+       "          (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+       "          (mlp): MLPBlock(\n",
+       "            (0): Linear(in_features=768, out_features=3072, bias=True)\n",
+       "            (1): GELU(approximate='none')\n",
+       "            (2): Dropout(p=0.0, inplace=False)\n",
+       "            (3): Linear(in_features=3072, out_features=768, bias=True)\n",
+       "            (4): Dropout(p=0.0, inplace=False)\n",
+       "          )\n",
+       "        )\n",
+       "        (encoder_layer_5): EncoderBlock(\n",
+       "          (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+       "          (self_attention): MultiheadAttention(\n",
+       "            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n",
+       "          )\n",
+       "          (dropout): Dropout(p=0.0, inplace=False)\n",
+       "          (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+       "          (mlp): MLPBlock(\n",
+       "            (0): Linear(in_features=768, out_features=3072, bias=True)\n",
+       "            (1): GELU(approximate='none')\n",
+       "            (2): Dropout(p=0.0, inplace=False)\n",
+       "            (3): Linear(in_features=3072, out_features=768, bias=True)\n",
+       "            (4): Dropout(p=0.0, inplace=False)\n",
+       "          )\n",
+       "        )\n",
+       "        (encoder_layer_6): EncoderBlock(\n",
+       "          (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+       "          (self_attention): MultiheadAttention(\n",
+       "            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n",
+       "          )\n",
+       "          (dropout): Dropout(p=0.0, inplace=False)\n",
+       "          (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+       "          (mlp): MLPBlock(\n",
+       "            (0): Linear(in_features=768, out_features=3072, bias=True)\n",
+       "            (1): GELU(approximate='none')\n",
+       "            (2): Dropout(p=0.0, inplace=False)\n",
+       "            (3): Linear(in_features=3072, out_features=768, bias=True)\n",
+       "            (4): Dropout(p=0.0, inplace=False)\n",
+       "          )\n",
+       "        )\n",
+       "        (encoder_layer_7): EncoderBlock(\n",
+       "          (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+       "          (self_attention): MultiheadAttention(\n",
+       "            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n",
+       "          )\n",
+       "          (dropout): Dropout(p=0.0, inplace=False)\n",
+       "          (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+       "          (mlp): MLPBlock(\n",
+       "            (0): Linear(in_features=768, out_features=3072, bias=True)\n",
+       "            (1): GELU(approximate='none')\n",
+       "            (2): Dropout(p=0.0, inplace=False)\n",
+       "            (3): Linear(in_features=3072, out_features=768, bias=True)\n",
+       "            (4): Dropout(p=0.0, inplace=False)\n",
+       "          )\n",
+       "        )\n",
+       "        (encoder_layer_8): EncoderBlock(\n",
+       "          (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+       "          (self_attention): MultiheadAttention(\n",
+       "            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n",
+       "          )\n",
+       "          (dropout): Dropout(p=0.0, inplace=False)\n",
+       "          (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+       "          (mlp): MLPBlock(\n",
+       "            (0): Linear(in_features=768, out_features=3072, bias=True)\n",
+       "            (1): GELU(approximate='none')\n",
+       "            (2): Dropout(p=0.0, inplace=False)\n",
+       "            (3): Linear(in_features=3072, out_features=768, bias=True)\n",
+       "            (4): Dropout(p=0.0, inplace=False)\n",
+       "          )\n",
+       "        )\n",
+       "        (encoder_layer_9): EncoderBlock(\n",
+       "          (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+       "          (self_attention): MultiheadAttention(\n",
+       "            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n",
+       "          )\n",
+       "          (dropout): Dropout(p=0.0, inplace=False)\n",
+       "          (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+       "          (mlp): MLPBlock(\n",
+       "            (0): Linear(in_features=768, out_features=3072, bias=True)\n",
+       "            (1): GELU(approximate='none')\n",
+       "            (2): Dropout(p=0.0, inplace=False)\n",
+       "            (3): Linear(in_features=3072, out_features=768, bias=True)\n",
+       "            (4): Dropout(p=0.0, inplace=False)\n",
+       "          )\n",
+       "        )\n",
+       "        (encoder_layer_10): EncoderBlock(\n",
+       "          (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+       "          (self_attention): MultiheadAttention(\n",
+       "            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n",
+       "          )\n",
+       "          (dropout): Dropout(p=0.0, inplace=False)\n",
+       "          (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+       "          (mlp): MLPBlock(\n",
+       "            (0): Linear(in_features=768, out_features=3072, bias=True)\n",
+       "            (1): GELU(approximate='none')\n",
+       "            (2): Dropout(p=0.0, inplace=False)\n",
+       "            (3): Linear(in_features=3072, out_features=768, bias=True)\n",
+       "            (4): Dropout(p=0.0, inplace=False)\n",
+       "          )\n",
+       "        )\n",
+       "        (encoder_layer_11): EncoderBlock(\n",
+       "          (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+       "          (self_attention): MultiheadAttention(\n",
+       "            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n",
+       "          )\n",
+       "          (dropout): Dropout(p=0.0, inplace=False)\n",
+       "          (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+       "          (mlp): MLPBlock(\n",
+       "            (0): Linear(in_features=768, out_features=3072, bias=True)\n",
+       "            (1): GELU(approximate='none')\n",
+       "            (2): Dropout(p=0.0, inplace=False)\n",
+       "            (3): Linear(in_features=3072, out_features=768, bias=True)\n",
+       "            (4): Dropout(p=0.0, inplace=False)\n",
+       "          )\n",
+       "        )\n",
+       "      )\n",
+       "      (ln): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+       "    )\n",
+       "    (heads): Sequential(\n",
+       "      (head): Linear(in_features=768, out_features=32, bias=True)\n",
+       "    )\n",
+       "  )\n",
+       ")\n"
+      ]
+     }
+    ],
+    "source": [
+     "import torch\n",
+     "import torch.nn as nn\n",
+     "from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler\n",
+     "import torchvision\n",
+     "from torchvision import datasets, transforms\n",
+     "from torch.utils.data import Subset\n",
+     "import numpy as np\n",
+     "import os\n",
+     "import pickle\n",
+     "from tqdm.auto import tqdm\n",
+     "from pathlib import Path\n",
+     "from torchvision.models import vit_b_16, ViT_B_16_Weights\n",
+     "\n",
+     "os.environ['CUDA_LAUNCH_BLOCKING'] = '1'\n",
+     "\n",
+     "# Paths to save the dataloaders and class information\n",
+     "save_path = \"saved_objects\"\n",
+     "class_info_path = os.path.join(save_path, 'class_info.pkl')\n",
+     "train_dataloader_path = os.path.join(save_path, 'train_dataloader.pkl')\n",
+     "test_dataloader_path = os.path.join(save_path, 'test_dataloader.pkl')\n",
+     "\n",
+     "# Create the directory if it does not exist\n",
+     "os.makedirs(save_path, exist_ok=True)\n",
+     "\n",
+     "# Function to load saved objects\n",
+     "def load_saved_data():\n",
+     "    if os.path.exists(class_info_path) and os.path.exists(train_dataloader_path) and os.path.exists(test_dataloader_path):\n",
+     "        with open(class_info_path, 'rb') as f:\n",
+     "            class_info = pickle.load(f)\n",
+     "        total_samples = class_info['total_samples']\n",
+     "        class_weights = class_info['class_weights']\n",
+     "        sample_weights = class_info['sample_weights']\n",
+     "\n",
+     "        with open(train_dataloader_path, 'rb') as f:\n",
+     "            train_dataloader = pickle.load(f)\n",
+     "\n",
+     "        with open(test_dataloader_path, 'rb') as f:\n",
+     "            test_dataloader = pickle.load(f)\n",
+     "\n",
+     "        print(\"Data loaded successfully!\")\n",
+     "        return total_samples, class_weights, sample_weights, train_dataloader, test_dataloader\n",
+     "    else:\n",
+     "        return None, None, None, None, None\n",
+     "\n",
+     "# Function to save objects\n",
+     "def save_data(total_samples, class_weights, sample_weights, train_dataloader, test_dataloader):\n",
+     "    with open(class_info_path, 'wb') as f:\n",
+     "        pickle.dump({\n",
+     "            'total_samples': total_samples,\n",
+     "            'class_weights': class_weights,\n",
+     "            'sample_weights': sample_weights\n",
+     "        }, f)\n",
+     "\n",
+     "    with open(train_dataloader_path, 'wb') as f:\n",
+     "        pickle.dump(train_dataloader, f)\n",
+     "\n",
+     "    with open(test_dataloader_path, 'wb') as f:\n",
+     "        pickle.dump(test_dataloader, f)\n",
+     "\n",
+     "    print(\"Data saved successfully!\")\n",
+     "\n",
+     "# Define the ViT model\n",
+     "class ViTForCancerClassification(nn.Module):\n",
+     "    def __init__(self, num_classes):\n",
+     "        super(ViTForCancerClassification, self).__init__()\n",
+     "        self.vit = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)\n",
+     "\n",
+     "        # Get the input features of the classifier\n",
+     "        in_features = self.vit.heads.head.in_features  # Access the head layer specifically\n",
+     "\n",
+     "        # Replace the head with a new classification layer\n",
+     "        self.vit.heads.head = nn.Linear(in_features, num_classes)\n",
+     "\n",
+     "    def forward(self, x):\n",
+     "        return self.vit(x)\n",
+     "\n",
+     "# Function to get attention weights of the last encoder block.\n",
+     "# torchvision's EncoderBlock calls MultiheadAttention with need_weights=False\n",
+     "# and never stores the weights, so recompute them manually.\n",
+     "def get_attention_weights(model, x):\n",
+     "    vit = model.vit\n",
+     "    with torch.no_grad():\n",
+     "        tokens = vit._process_input(x)\n",
+     "        tokens = torch.cat([vit.class_token.expand(tokens.shape[0], -1, -1), tokens], dim=1)\n",
+     "        tokens = vit.encoder.dropout(tokens + vit.encoder.pos_embedding)\n",
+     "        for layer in vit.encoder.layers[:-1]:\n",
+     "            tokens = layer(tokens)\n",
+     "        q = vit.encoder.layers[-1].ln_1(tokens)\n",
+     "        _, attn = vit.encoder.layers[-1].self_attention(q, q, q, need_weights=True)\n",
+     "    return attn\n",
+     "\n",
+     "# Try to load saved data\n",
+     "total_samples, class_weights, sample_weights, train_dataloader, test_dataloader = load_saved_data()\n",
+     "\n",
+     "# If the data is not available, run preprocessing\n",
+     "if total_samples is None:\n",
+     "    print(\"No saved data found. Running data preprocessing...\")\n",
+     "\n",
+     "    # Data loading and preprocessing\n",
+     "    data_path = Path('TCGA')\n",
+     "    transform = transforms.Compose([\n",
+     "        transforms.Resize((224, 224)),  # ViT typically expects 224x224 input\n",
+     "        transforms.ToTensor(),\n",
+     "        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
+     "    ])\n",
+     "\n",
+     "    full_dataset = datasets.ImageFolder(root=data_path, transform=transform)\n",
+     "    # No filtering is applied yet: every sample is kept\n",
+     "    valid_indices = list(range(len(full_dataset.samples)))\n",
+     "    dataset = Subset(full_dataset, valid_indices)\n",
+     "\n",
+     "    class_names = full_dataset.classes\n",
+     "    class_to_idx = full_dataset.class_to_idx\n",
+     "    print(class_names, class_to_idx)\n",
+     "\n",
+     "    # Calculate inverse-frequency class weights from the stored targets\n",
+     "    # (iterating the dataset would decode every image just to read its label)\n",
+     "    labels = [full_dataset.targets[i] for i in valid_indices]\n",
+     "    class_counts = [0] * len(class_names)\n",
+     "    for label in labels:\n",
+     "        class_counts[label] += 1\n",
+     "    total_samples = sum(class_counts)\n",
+     "    class_weights = [total_samples / (len(class_names) * count) for count in class_counts]\n",
+     "    sample_weights = [class_weights[label] for label in labels]\n",
+     "\n",
+     "    # Create WeightedRandomSampler\n",
+     "    sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)\n",
+     "\n",
+     "    # Create data loaders\n",
+     "    BATCH_SIZE = 128\n",
+     "    train_dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler)\n",
+     "    test_dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)\n",
+     "\n",
+     "    # Save the processed data for future use\n",
+     "    save_data(total_samples, class_weights, sample_weights, train_dataloader, test_dataloader)\n",
+     "\n",
+     "# Hard-coded because the class names are not stored in the pickled\n",
+     "# class_info, so they must be available when data is loaded from disk\n",
+     "class_names = ['Adrenocortical_carcinoma', 'Bladder_Urothelial_Carcinoma', 'Brain_Lower_Grade_Glioma', 'Breast_invasive_carcinoma', 'Cervical_squamous_cell_carcinoma_and_endocervical_adenocarcinoma', 'Cholangiocarcinoma', 'Colon_adenocarcinoma', 'Esophageal_carcinoma', 'Glioblastoma_multiforme', 'Head_and_Neck_squamous_cell_carcinoma', 'Kidney_Chromophobe', 'Kidney_renal_clear_cell_carcinoma', 'Kidney_renal_papillary_cell_carcinoma', 'Liver_hepatocellular_carcinoma', 'Lung_adenocarcinoma', 'Lung_squamous_cell_carcinoma', 'Lymphoid_Neoplasm_Diffuse_Large_B-cell_Lymphoma', 'Mesothelioma', 'Ovarian_serous_cystadenocarcinoma', 'Pancreatic_adenocarcinoma', 'Pheochromocytoma_and_Paraganglioma', 'Prostate_adenocarcinoma', 'Rectum_adenocarcinoma', 'Sarcoma', 'Skin_Cutaneous_Melanoma', 'Stomach_adenocarcinoma', 'Testicular_Germ_Cell_Tumors', 'Thymoma', 'Thyroid_carcinoma', 'Uterine_Carcinosarcoma', 'Uterine_Corpus_Endometrial_Carcinoma', 'Uveal_Melanoma']\n",
+     "print(f\"Number of classes: {len(class_names)}\")\n",
+     "print(f\"Class names: {class_names}\")\n",
+     "\n",
+     "# Model setup\n",
+     "num_classes = len(class_names)\n",
+     "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+     "model = ViTForCancerClassification(num_classes).to(device)\n",
+     "print(model)\n",
+     "\n",
+     "# Training setup\n",
+     "torch.manual_seed(42)\n",
+     "EPOCHS = 20\n",
+     "class_weights_tensor = torch.FloatTensor(class_weights).to(device)\n",
+     "loss_fn = nn.CrossEntropyLoss(weight=class_weights_tensor)\n",
+     "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
+     "\n",
+     "results = {\n",
+     "    'train_loss': [],\n",
+     "    'train_acc': [],\n",
+     "    'test_loss': [],\n",
+     "    'test_acc': []\n",
+     "}"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import torch\n",
+     "\n",
+     "# Define the checkpoint file (change to the correct path if necessary)\n",
+     "checkpoint_path = 'vit_cancer_model_state_dict_X.pth'  # Replace 'X' with the last saved epoch number\n",
+     "\n",
+     "# Load the saved model if it exists\n",
+     "if os.path.exists(checkpoint_path):\n",
+     "    print(f\"Loading model from {checkpoint_path}\")\n",
+     "    model.load_state_dict(torch.load(checkpoint_path, map_location=device))\n",
+     "    start_epoch = int(checkpoint_path.split('_')[-1].split('.')[0]) + 1\n",
+     "else:\n",
+     "    print(\"No checkpoint found, starting training from scratch.\")\n",
+     "    start_epoch = 0\n",
+     "\n",
+     "# Resume training\n",
+     "for epoch in range(start_epoch, EPOCHS):\n",
+     "    print(f\"Epoch {epoch+1}/{EPOCHS}\")\n",
+     "    train_loss, train_acc = 0, 0\n",
+     "    model.train()\n",
+     "    for batch, (X, y) in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):\n",
+     "        X, y = X.to(device), y.to(device)\n",
+     "        y_logits = model(X)\n",
+     "        y_pred_class = torch.argmax(torch.softmax(y_logits, dim=1), dim=1)\n",
+     "        loss = loss_fn(y_logits, y)\n",
+     "        train_acc += (y_pred_class == y).sum().item() / len(y)\n",
+     "        train_loss += loss.item()\n",
+     "\n",
+     "        optimizer.zero_grad()\n",
+     "        loss.backward()\n",
+     "        optimizer.step()\n",
+     "\n",
+     "    train_loss /= len(train_dataloader)\n",
+     "    train_acc /= len(train_dataloader)\n",
+     "\n",
+     "    results['train_loss'].append(train_loss)\n",
+     "    results['train_acc'].append(train_acc)\n",
+     "\n",
+     "    model.eval()\n",
+     "    test_loss, test_acc = 0, 0\n",
+     "    with torch.inference_mode():\n",
+     "        for batch, (X, y) in tqdm(enumerate(test_dataloader), total=len(test_dataloader)):\n",
+     "            X, y = X.to(device), y.to(device)\n",
+     "\n",
+     "            test_logits = model(X)\n",
+     "            test_pred_labels = test_logits.argmax(dim=1)\n",
+     "            loss = loss_fn(test_logits, y)\n",
+     "            test_acc += (test_pred_labels == y).sum().item() / len(y)\n",
+     "            test_loss += loss.item()\n",
+     "\n",
+     "    test_loss /= len(test_dataloader)\n",
+     "    test_acc /= len(test_dataloader)\n",
+     "    print(f'Training loss: {train_loss:.5f} acc: {train_acc:.5f} | Testing loss: {test_loss:.5f} acc: {test_acc:.5f}')\n",
+     "\n",
+     "    results['test_loss'].append(test_loss)\n",
+     "    results['test_acc'].append(test_acc)\n",
+     "\n",
+     "    # Save the model checkpoint after every epoch\n",
+     "    torch.save(model.state_dict(), f'vit_cancer_model_state_dict_{epoch}.pth')\n",
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "Python 3",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.12.3"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }