{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Data loaded successfully!\n", "Number of classes: 32\n", "Class names: ['Adrenocortical_carcinoma', 'Bladder_Urothelial_Carcinoma', 'Brain_Lower_Grade_Glioma', 'Breast_invasive_carcinoma', 'Cervical_squamous_cell_carcinoma_and_endocervical_adenocarcinoma', 'Cholangiocarcinoma', 'Colon_adenocarcinoma', 'Esophageal_carcinoma', 'Glioblastoma_multiforme', 'Head_and_Neck_squamous_cell_carcinoma', 'Kidney_Chromophobe', 'Kidney_renal_clear_cell_carcinoma', 'Kidney_renal_papillary_cell_carcinoma', 'Liver_hepatocellular_carcinoma', 'Lung_adenocarcinoma', 'Lung_squamous_cell_carcinoma', 'Lymphoid_Neoplasm_Diffuse_Large_B-cell_Lymphoma', 'Mesothelioma', 'Ovarian_serous_cystadenocarcinoma', 'Pancreatic_adenocarcinoma', 'Pheochromocytoma_and_Paraganglioma', 'Prostate_adenocarcinoma', 'Rectum_adenocarcinoma', 'Sarcoma', 'Skin_Cutaneous_Melanoma', 'Stomach_adenocarcinoma', 'Testicular_Germ_Cell_Tumors', 'Thymoma', 'Thyroid_carcinoma', 'Uterine_Carcinosarcoma', 'Uterine_Corpus_Endometrial_Carcinoma', 'Uveal_Melanoma']\n", "ViTForCancerClassification(\n", " (vit): VisionTransformer(\n", " (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))\n", " (encoder): Encoder(\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (layers): Sequential(\n", " (encoder_layer_0): EncoderBlock(\n", " (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (self_attention): MultiheadAttention(\n", " (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n", " )\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (mlp): MLPBlock(\n", " (0): Linear(in_features=768, out_features=3072, bias=True)\n", " (1): GELU(approximate='none')\n", " (2): Dropout(p=0.0, inplace=False)\n", " (3): Linear(in_features=3072, out_features=768, bias=True)\n", " (4): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (encoder_layer_1): EncoderBlock(\n", " (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (self_attention): MultiheadAttention(\n", " (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n", " )\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (mlp): MLPBlock(\n", " (0): Linear(in_features=768, out_features=3072, bias=True)\n", " (1): GELU(approximate='none')\n", " (2): Dropout(p=0.0, inplace=False)\n", " (3): Linear(in_features=3072, out_features=768, bias=True)\n", " (4): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (encoder_layer_2): EncoderBlock(\n", " (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (self_attention): MultiheadAttention(\n", " (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n", " )\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (mlp): MLPBlock(\n", " (0): Linear(in_features=768, out_features=3072, bias=True)\n", " (1): GELU(approximate='none')\n", " (2): Dropout(p=0.0, inplace=False)\n", " (3): Linear(in_features=3072, out_features=768, bias=True)\n", " (4): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (encoder_layer_3): EncoderBlock(\n", " (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (self_attention): MultiheadAttention(\n", " (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n", " )\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (mlp): MLPBlock(\n", " (0): Linear(in_features=768, out_features=3072, bias=True)\n", " (1): GELU(approximate='none')\n", " (2): Dropout(p=0.0, inplace=False)\n", " (3): Linear(in_features=3072, out_features=768, bias=True)\n", " (4): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (encoder_layer_4): EncoderBlock(\n", " (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (self_attention): MultiheadAttention(\n", " (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n", " )\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (mlp): MLPBlock(\n", " (0): Linear(in_features=768, out_features=3072, bias=True)\n", " (1): GELU(approximate='none')\n", " (2): Dropout(p=0.0, inplace=False)\n", " (3): Linear(in_features=3072, out_features=768, bias=True)\n", " (4): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (encoder_layer_5): EncoderBlock(\n", " (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (self_attention): MultiheadAttention(\n", " (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n", " )\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (mlp): MLPBlock(\n", " (0): Linear(in_features=768, out_features=3072, bias=True)\n", " (1): GELU(approximate='none')\n", " (2): Dropout(p=0.0, inplace=False)\n", " (3): Linear(in_features=3072, out_features=768, bias=True)\n", " (4): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (encoder_layer_6): EncoderBlock(\n", " (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (self_attention): MultiheadAttention(\n", " (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n", " )\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (mlp): MLPBlock(\n", " (0): Linear(in_features=768, out_features=3072, bias=True)\n", " (1): GELU(approximate='none')\n", " (2): Dropout(p=0.0, inplace=False)\n", " (3): Linear(in_features=3072, out_features=768, bias=True)\n", " (4): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (encoder_layer_7): EncoderBlock(\n", " (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (self_attention): MultiheadAttention(\n", " (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n", " )\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (mlp): MLPBlock(\n", " (0): Linear(in_features=768, out_features=3072, bias=True)\n", " (1): GELU(approximate='none')\n", " (2): Dropout(p=0.0, inplace=False)\n", " (3): Linear(in_features=3072, out_features=768, bias=True)\n", " (4): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (encoder_layer_8): EncoderBlock(\n", " (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (self_attention): MultiheadAttention(\n", " (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n", " )\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (mlp): MLPBlock(\n", " (0): Linear(in_features=768, out_features=3072, bias=True)\n", " (1): GELU(approximate='none')\n", " (2): Dropout(p=0.0, inplace=False)\n", " (3): Linear(in_features=3072, out_features=768, bias=True)\n", " (4): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (encoder_layer_9): EncoderBlock(\n", " (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (self_attention): MultiheadAttention(\n", " (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n", " )\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (mlp): MLPBlock(\n", " (0): Linear(in_features=768, out_features=3072, bias=True)\n", " (1): GELU(approximate='none')\n", " (2): Dropout(p=0.0, inplace=False)\n", " (3): Linear(in_features=3072, out_features=768, bias=True)\n", " (4): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (encoder_layer_10): EncoderBlock(\n", " (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (self_attention): MultiheadAttention(\n", " (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n", " )\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (mlp): MLPBlock(\n", " (0): Linear(in_features=768, out_features=3072, bias=True)\n", " (1): GELU(approximate='none')\n", " (2): Dropout(p=0.0, inplace=False)\n", " (3): Linear(in_features=3072, out_features=768, bias=True)\n", " (4): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (encoder_layer_11): EncoderBlock(\n", " (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (self_attention): MultiheadAttention(\n", " (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n", " )\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (mlp): MLPBlock(\n", " (0): Linear(in_features=768, out_features=3072, bias=True)\n", " (1): GELU(approximate='none')\n", " (2): Dropout(p=0.0, inplace=False)\n", " (3): Linear(in_features=3072, out_features=768, bias=True)\n", " (4): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " )\n", " (ln): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " )\n", " (heads): Sequential(\n", " (head): Linear(in_features=768, out_features=32, bias=True)\n", " )\n", " )\n", ")\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler\n", "import torchvision\n", "from torchvision import datasets, transforms\n", "from torch.utils.data import Subset\n", "import numpy as np\n", "import os\n", "import pickle\n", "from tqdm.auto import tqdm\n", "from pathlib import Path\n", "from torchvision.models import vit_b_16, ViT_B_16_Weights\n", "\n", "os.environ['CUDA_LAUNCH_BLOCKING'] = '1'\n", "\n", "# Paths to save the dataloaders and class information\n", "save_path = \"saved_objects\"\n", "class_info_path = os.path.join(save_path, 'class_info.pkl')\n", "train_dataloader_path = os.path.join(save_path, 'train_dataloader.pkl')\n", "test_dataloader_path = os.path.join(save_path, 'test_dataloader.pkl')\n", "\n", "# Create directory if not exists\n", "os.makedirs(save_path, exist_ok=True)\n", "\n", "# Function to load saved objects\n", "def load_saved_data():\n", " if os.path.exists(class_info_path) and os.path.exists(train_dataloader_path) and os.path.exists(test_dataloader_path):\n", " with open(class_info_path, 'rb') as f:\n", " class_info = pickle.load(f)\n", " total_samples = class_info['total_samples']\n", " class_weights = class_info['class_weights']\n", " sample_weights = class_info['sample_weights']\n", "\n", " with open(train_dataloader_path, 'rb') as f:\n", " train_dataloader = pickle.load(f)\n", "\n", " with open(test_dataloader_path, 'rb') as f:\n", " test_dataloader = pickle.load(f)\n", "\n", " print(\"Data loaded successfully!\")\n", " return total_samples, class_weights, sample_weights, train_dataloader, test_dataloader\n", " else:\n", " return None, None, None, None, None\n", "\n", "# Function to save objects\n", "def save_data(total_samples, class_weights, sample_weights, train_dataloader, test_dataloader):\n", " with open(class_info_path, 'wb') as f:\n", " pickle.dump({\n", " 'total_samples': total_samples,\n", " 'class_weights': class_weights,\n", " 'sample_weights': sample_weights\n", " }, f)\n", "\n", " with open(train_dataloader_path, 'wb') as f:\n", " pickle.dump(train_dataloader, f)\n", "\n", " with open(test_dataloader_path, 'wb') as f:\n", " pickle.dump(test_dataloader, f)\n", "\n", " print(\"Data saved successfully!\")\n", "\n", "# Define the ViT model\n", "class ViTForCancerClassification(nn.Module):\n", " def __init__(self, num_classes):\n", " super(ViTForCancerClassification, self).__init__()\n", " self.vit = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)\n", " \n", " # Get the input features of the classifier\n", " in_features = self.vit.heads.head.in_features # Access the head layer specifically\n", " \n", " # Replace the head with a new classification layer\n", " self.vit.heads.head = nn.Linear(in_features, num_classes)\n", " \n", " def forward(self, x):\n", " return self.vit(x)\n", "\n", "# Function to get attention weights\n", "def get_attention_weights(model, x):\n", " with torch.no_grad():\n", " outputs = model.vit._process_input(x)\n", " outputs = model.vit.encoder(outputs)\n", " return model.vit.encoder.layers[-1].self_attention.attention_weights\n", "\n", "# Try to load saved data\n", "total_samples, class_weights, sample_weights, train_dataloader, test_dataloader = load_saved_data()\n", "\n", "# If the data is not available, run preprocessing\n", "if total_samples is None:\n", " print(\"No saved data found. Running data preprocessing...\")\n", "\n", " # Data loading and preprocessing\n", " data_path = Path('TCGA')\n", " transform = transforms.Compose([\n", " transforms.Resize((224, 224)), # ViT typically expects 224x224 input\n", " transforms.ToTensor(),\n", " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n", " ])\n", "\n", " full_dataset = datasets.ImageFolder(root=data_path, transform=transform)\n", " valid_indices = [i for i, (_, label) in enumerate(full_dataset.samples)]\n", " dataset = Subset(full_dataset, valid_indices)\n", "\n", " class_names = [name for name, idx in full_dataset.class_to_idx.items()]\n", " class_to_idx = {name: idx for name, idx in full_dataset.class_to_idx.items()}\n", " print(class_names, class_to_idx)\n", "\n", " # Calculate class weights\n", " class_counts = [0] * len(class_names)\n", " for _, label in dataset:\n", " class_counts[label] += 1\n", " total_samples = sum(class_counts)\n", " class_weights = [total_samples / (len(class_names) * count) for count in class_counts]\n", " sample_weights = [class_weights[label] for _, label in dataset]\n", "\n", " # Create WeightedRandomSampler\n", " sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)\n", "\n", " # Create data loaders\n", " BATCH_SIZE = 128\n", " train_dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler)\n", " test_dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)\n", "\n", " # Save the processed data for future use\n", " save_data(total_samples, class_weights, sample_weights, train_dataloader, test_dataloader)\n", "\n", "class_names = ['Adrenocortical_carcinoma', 'Bladder_Urothelial_Carcinoma', 'Brain_Lower_Grade_Glioma', 'Breast_invasive_carcinoma', 'Cervical_squamous_cell_carcinoma_and_endocervical_adenocarcinoma', 'Cholangiocarcinoma', 'Colon_adenocarcinoma', 'Esophageal_carcinoma', 'Glioblastoma_multiforme', 'Head_and_Neck_squamous_cell_carcinoma', 'Kidney_Chromophobe', 'Kidney_renal_clear_cell_carcinoma', 'Kidney_renal_papillary_cell_carcinoma', 'Liver_hepatocellular_carcinoma', 'Lung_adenocarcinoma', 'Lung_squamous_cell_carcinoma', 'Lymphoid_Neoplasm_Diffuse_Large_B-cell_Lymphoma', 'Mesothelioma', 'Ovarian_serous_cystadenocarcinoma', 'Pancreatic_adenocarcinoma', 'Pheochromocytoma_and_Paraganglioma', 'Prostate_adenocarcinoma', 'Rectum_adenocarcinoma', 'Sarcoma', 'Skin_Cutaneous_Melanoma', 'Stomach_adenocarcinoma', 'Testicular_Germ_Cell_Tumors', 'Thymoma', 'Thyroid_carcinoma', 'Uterine_Carcinosarcoma', 'Uterine_Corpus_Endometrial_Carcinoma', 'Uveal_Melanoma']\n", "print(f\"Number of classes: {len(class_names)}\")\n", "print(f\"Class names: {class_names}\")\n", "\n", "# Model setup\n", "num_classes = len(class_names)\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "model = ViTForCancerClassification(num_classes).to(device)\n", "print(model)\n", "\n", "# Training setup\n", "torch.manual_seed(42)\n", "EPOCHS = 20\n", "class_weights_tensor = torch.FloatTensor(class_weights).to(device)\n", "loss_fn = nn.CrossEntropyLoss(weight=class_weights_tensor)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", "\n", "results = {\n", " 'train_loss': [], \n", " 'train_acc': [],\n", " 'test_loss': [],\n", " 'test_acc': []\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "# Define the checkpoint file (change to the correct path if necessary)\n", "checkpoint_path = 'vit_cancer_model_state_dict_X.pth' # Replace 'X' with the last saved epoch number\n", "\n", "# Load the saved model if it exists\n", "if os.path.exists(checkpoint_path):\n", " print(f\"Loading model from {checkpoint_path}\")\n", " model.load_state_dict(torch.load(checkpoint_path))\n", " start_epoch = int(checkpoint_path.split('_')[-1].split('.')[0]) + 1\n", "else:\n", " print(\"No checkpoint found, starting training from scratch.\")\n", " start_epoch = 0\n", "\n", "# Resume training\n", "for epoch in range(start_epoch, EPOCHS):\n", " print(f\"Epoch {epoch+1}/{EPOCHS}\")\n", " train_loss, train_acc = 0, 0\n", " model.train()\n", " for batch, (X, y) in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):\n", " X, y = X.to(device), y.to(device)\n", " y_logits = model(X)\n", " y_pred_class = torch.argmax(torch.softmax(y_logits, dim=1), dim=1)\n", " loss = loss_fn(y_logits, y)\n", " train_acc += (y_pred_class == y).sum().item() / len(y)\n", " train_loss += loss.item()\n", " \n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", " \n", " train_loss /= len(train_dataloader)\n", " train_acc /= len(train_dataloader)\n", " \n", " results['train_loss'].append(train_loss)\n", " results['train_acc'].append(train_acc)\n", " \n", " model.eval()\n", " test_loss, test_acc = 0, 0\n", " with torch.inference_mode():\n", " for batch, (X, y) in tqdm(enumerate(test_dataloader), total=len(test_dataloader)):\n", " X, y = X.to(device), y.to(device)\n", " \n", " test_logits = model(X)\n", " test_pred_labels = test_logits.argmax(dim=1)\n", " loss = loss_fn(test_logits, y)\n", " test_acc += (test_pred_labels == y).sum().item() / len(y)\n", " test_loss += loss.item()\n", " \n", " test_loss /= len(test_dataloader)\n", " test_acc /= len(test_dataloader)\n", " print(f'Training loss: {train_loss:.5f} acc: {train_acc:.5f} | Testing loss: {test_loss:.5f} acc: {test_acc:.5f}')\n", " \n", " results['test_loss'].append(test_loss)\n", " results['test_acc'].append(test_acc)\n", " \n", " # Save the model checkpoint after every epoch\n", " torch.save(model.state_dict(), f'vit_cancer_model_state_dict_{epoch}.pth')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 2 }