In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import Subset
import numpy as np
import os
import pickle
from tqdm.auto import tqdm
from pathlib import Path
from torchvision.models import vit_b_16, ViT_B_16_Weights

# Make CUDA kernel launches synchronous so a device-side error is reported at
# the call that caused it. This is a debugging aid.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Directory that caches the pickled dataloaders and class statistics
save_path = "saved_objects"
os.makedirs(save_path, exist_ok=True)

# Individual cache files, all located inside `save_path`
class_info_path, train_dataloader_path, test_dataloader_path = (
    os.path.join(save_path, cache_file)
    for cache_file in ("class_info.pkl", "train_dataloader.pkl", "test_dataloader.pkl")
)

# Function to load saved objects
def load_saved_data():
    """Restore the cached class statistics and dataloaders, if all are present.

    Reads the three pickle files named by the module-level `class_info_path`,
    `train_dataloader_path` and `test_dataloader_path`.

    Returns:
        A 5-tuple `(total_samples, class_weights, sample_weights,
        train_dataloader, test_dataloader)`. When any of the three cache
        files is missing, every element of the tuple is `None`.

    NOTE(review): `pickle.load` can execute arbitrary code; only point this at
    cache files this notebook wrote itself.
    """
    cache_files = (class_info_path, train_dataloader_path, test_dataloader_path)
    if not all(os.path.exists(cache_file) for cache_file in cache_files):
        return None, None, None, None, None

    def _unpickle(path):
        # Read one pickled object from `path`.
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    class_info = _unpickle(class_info_path)
    total_samples = class_info['total_samples']
    class_weights = class_info['class_weights']
    sample_weights = class_info['sample_weights']

    train_dataloader = _unpickle(train_dataloader_path)
    test_dataloader = _unpickle(test_dataloader_path)

    print("Data loaded successfully!")
    return total_samples, class_weights, sample_weights, train_dataloader, test_dataloader

# Function to save objects
def save_data(total_samples, class_weights, sample_weights, train_dataloader, test_dataloader):
    """Pickle the class statistics and both dataloaders to the cache files.

    Args:
        total_samples: Stored under the 'total_samples' key of the class-info file.
        class_weights: Stored under the 'class_weights' key of the class-info file.
        sample_weights: Stored under the 'sample_weights' key of the class-info file.
        train_dataloader: Object pickled to `train_dataloader_path`.
        test_dataloader: Object pickled to `test_dataloader_path`.
    """
    class_info = {
        'total_samples': total_samples,
        'class_weights': class_weights,
        'sample_weights': sample_weights,
    }
    # (destination, object) pairs, written in this order
    artifacts = (
        (class_info_path, class_info),
        (train_dataloader_path, train_dataloader),
        (test_dataloader_path, test_dataloader),
    )
    for target_path, payload in artifacts:
        with open(target_path, 'wb') as handle:
            pickle.dump(payload, handle)

    print("Data saved successfully!")

# Define the ViT model
class ViTForCancerClassification(nn.Module):
    """torchvision ViT-B/16 whose classification head is swapped for a new one.

    The backbone is created with torchvision's default pretrained weights; only
    the final `heads.head` layer is replaced, by a freshly initialised
    `nn.Linear` with `num_classes` outputs.
    """

    def __init__(self, num_classes):
        """Build the backbone and attach a `num_classes`-way linear head."""
        super().__init__()
        backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)

        # Reuse the width of the original head for the replacement layer
        original_head = backbone.heads.head
        backbone.heads.head = nn.Linear(original_head.in_features, num_classes)

        self.vit = backbone

    def forward(self, x):
        """Run `x` through the ViT and return its output (class logits)."""
        logits = self.vit(x)
        return logits

# Function to get attention weights
def get_attention_weights(model, x):
    """Return the self-attention map of the last encoder block for a batch.

    torchvision's encoder blocks call their attention module with
    `need_weights=False` and do not store the weights anywhere, so the map is
    recomputed here: the forward pass is replayed up to the last block, and
    that block's attention is then invoked with `need_weights=True`.

    Args:
        model: Object exposing a torchvision-style `VisionTransformer` as
            `model.vit` (e.g. `ViTForCancerClassification`).
        x: Image batch accepted by `model.vit._process_input`.

    Returns:
        Attention weights of the last block, averaged over heads, with shape
        `(batch, tokens, tokens)` where `tokens` is the number of patches plus
        one for the class token.

    The function does not change the train/eval mode of `model`; call
    `model.eval()` first if dropout must be disabled.
    """
    vit = model.vit
    encoder = vit.encoder
    with torch.no_grad():
        # Patchify, then prepend the class token exactly as VisionTransformer.forward
        # does; the positional embedding is sized for patches + class token.
        tokens = vit._process_input(x)
        class_token = vit.class_token.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat([class_token, tokens], dim=1)

        hidden = encoder.dropout(tokens + encoder.pos_embedding)

        # Every block except the last runs unchanged
        for block in encoder.layers[:-1]:
            hidden = block(hidden)

        # Re-run only the attention of the last block, this time asking for weights
        last_block = encoder.layers[-1]
        normed = last_block.ln_1(hidden)
        _, attention = last_block.self_attention(normed, normed, normed, need_weights=True)
    return attention

# Try to load saved data
total_samples, class_weights, sample_weights, train_dataloader, test_dataloader = load_saved_data()

# If the data is not available, run preprocessing
if total_samples is None:
    print("No saved data found. Running data preprocessing...")

    # Data loading and preprocessing
    data_path = Path('TCGA')
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # ViT typically expects 224x224 input
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    full_dataset = datasets.ImageFolder(root=data_path, transform=transform)
    # Every sample is kept; the Subset exists so a filter can be added here later
    valid_indices = list(range(len(full_dataset.samples)))
    dataset = Subset(full_dataset, valid_indices)

    class_names = list(full_dataset.class_to_idx)
    class_to_idx = dict(full_dataset.class_to_idx)
    print(class_names, class_to_idx)

    # Read labels from the dataset index instead of iterating `dataset`:
    # iterating would load and transform every image just to obtain its label.
    labels = [full_dataset.targets[i] for i in valid_indices]

    # Calculate class weights (inverse class frequency, mean-normalised)
    class_counts = [0] * len(class_names)
    for label in labels:
        class_counts[label] += 1
    total_samples = sum(class_counts)
    class_weights = [total_samples / (len(class_names) * count) for count in class_counts]
    sample_weights = [class_weights[label] for label in labels]

    # Create WeightedRandomSampler
    sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)

    # Create data loaders
    # NOTE(review): both loaders wrap the same `dataset`, so the "test" metrics
    # are measured on the training images. Confirm whether a held-out split
    # is intended before trusting test accuracy.
    BATCH_SIZE = 128
    train_dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler)
    test_dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Save the processed data for future use
    save_data(total_samples, class_weights, sample_weights, train_dataloader, test_dataloader)

# Class names are hard-coded so they are available when the cache is used
# (the cache files do not store them).
class_names = [
    'Adrenocortical_carcinoma',
    'Bladder_Urothelial_Carcinoma',
    'Brain_Lower_Grade_Glioma',
    'Breast_invasive_carcinoma',
    'Cervical_squamous_cell_carcinoma_and_endocervical_adenocarcinoma',
    'Cholangiocarcinoma',
    'Colon_adenocarcinoma',
    'Esophageal_carcinoma',
    'Glioblastoma_multiforme',
    'Head_and_Neck_squamous_cell_carcinoma',
    'Kidney_Chromophobe',
    'Kidney_renal_clear_cell_carcinoma',
    'Kidney_renal_papillary_cell_carcinoma',
    'Liver_hepatocellular_carcinoma',
    'Lung_adenocarcinoma',
    'Lung_squamous_cell_carcinoma',
    'Lymphoid_Neoplasm_Diffuse_Large_B-cell_Lymphoma',
    'Mesothelioma',
    'Ovarian_serous_cystadenocarcinoma',
    'Pancreatic_adenocarcinoma',
    'Pheochromocytoma_and_Paraganglioma',
    'Prostate_adenocarcinoma',
    'Rectum_adenocarcinoma',
    'Sarcoma',
    'Skin_Cutaneous_Melanoma',
    'Stomach_adenocarcinoma',
    'Testicular_Germ_Cell_Tumors',
    'Thymoma',
    'Thyroid_carcinoma',
    'Uterine_Carcinosarcoma',
    'Uterine_Corpus_Endometrial_Carcinoma',
    'Uveal_Melanoma',
]
print(f"Number of classes: {len(class_names)}")
print(f"Class names: {class_names}")

# Model setup
num_classes = len(class_names)
# Fail early if the weights do not line up with the class list;
# CrossEntropyLoss would otherwise only complain on the first batch.
assert len(class_weights) == num_classes, (
    f"class_weights has {len(class_weights)} entries but there are {num_classes} class names"
)

# Seed before the model is built so the new classification head is
# initialised reproducibly.
SEED = 42
torch.manual_seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ViTForCancerClassification(num_classes).to(device)
print(model)

# Training setup
EPOCHS = 20
# NOTE(review): class imbalance is compensated twice, by the weighted sampler
# in the train loader and by these loss weights. Confirm this is intended.
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32, device=device)
loss_fn = nn.CrossEntropyLoss(weight=class_weights_tensor)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Per-epoch metric history, appended to by the training loop
results = {
    'train_loss': [],
    'train_acc': [],
    'test_loss': [],
    'test_acc': []
}

Data loaded successfully!
Number of classes: 32
Class names: ['Adrenocortical_carcinoma', 'Bladder_Urothelial_Carcinoma', 'Brain_Lower_Grade_Glioma', 'Breast_invasive_carcinoma', 'Cervical_squamous_cell_carcinoma_and_endocervical_adenocarcinoma', 'Cholangiocarcinoma', 'Colon_adenocarcinoma', 'Esophageal_carcinoma', 'Glioblastoma_multiforme', 'Head_and_Neck_squamous_cell_carcinoma', 'Kidney_Chromophobe', 'Kidney_renal_clear_cell_carcinoma', 'Kidney_renal_papillary_cell_carcinoma', 'Liver_hepatocellular_carcinoma', 'Lung_adenocarcinoma', 'Lung_squamous_cell_carcinoma', 'Lymphoid_Neoplasm_Diffuse_Large_B-cell_Lymphoma', 'Mesothelioma', 'Ovarian_serous_cystadenocarcinoma', 'Pancreatic_adenocarcinoma', 'Pheochromocytoma_and_Paraganglioma', 'Prostate_adenocarcinoma', 'Rectum_adenocarcinoma', 'Sarcoma', 'Skin_Cutaneous_Melanoma', 'Stomach_adenocarcinoma', 'Testicular_Germ_Cell_Tumors', 'Thymoma', 'Thyroid_carcinoma', 'Uterine_Carcinosarcoma', 'Uterine_Corpus_Endometrial_Carcinoma', 'Uveal_Mel

In [None]:
import torch

# Checkpoints are written once per epoch as '<CHECKPOINT_PREFIX><epoch>.pth'
CHECKPOINT_PREFIX = 'vit_cancer_model_state_dict_'


def find_latest_checkpoint(directory='.', prefix=CHECKPOINT_PREFIX):
    """Locate the checkpoint with the highest epoch number in `directory`.

    Only files named `<prefix><non-negative integer>.pth` are considered.

    Args:
        directory: Folder to scan.
        prefix: File-name prefix that precedes the epoch number.

    Returns:
        `(path, epoch)` of the newest checkpoint, or `(None, -1)` when no
        matching file exists.
    """
    extension = '.pth'
    latest_path, latest_epoch = None, -1
    for filename in os.listdir(directory):
        if not (filename.startswith(prefix) and filename.endswith(extension)):
            continue
        epoch_text = filename[len(prefix):-len(extension)]
        # Skip names such as '..._X.pth' whose suffix is not a number
        if epoch_text.isdecimal() and int(epoch_text) > latest_epoch:
            latest_path, latest_epoch = os.path.join(directory, filename), int(epoch_text)
    return latest_path, latest_epoch


# Load the most recent checkpoint if one exists
checkpoint_path, last_epoch = find_latest_checkpoint()

if checkpoint_path is not None:
    print(f"Loading model from {checkpoint_path}")
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    start_epoch = last_epoch + 1
else:
    print("No checkpoint found, starting training from scratch.")
    start_epoch = 0

# Resume training
# Only the model weights are checkpointed: the optimizer state and `results`
# history start fresh after a resume.
for epoch in range(start_epoch, EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")

    # ---- training pass ----
    model.train()
    train_loss, train_correct, train_seen = 0.0, 0, 0
    for X, y in tqdm(train_dataloader, total=len(train_dataloader)):
        X, y = X.to(device), y.to(device)
        y_logits = model(X)
        loss = loss_fn(y_logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        # argmax over logits equals argmax over softmax probabilities
        train_correct += (y_logits.argmax(dim=1) == y).sum().item()
        train_seen += len(y)

    train_loss /= max(len(train_dataloader), 1)
    # Accuracy over samples, so a short final batch is not over-weighted
    train_acc = train_correct / max(train_seen, 1)

    results['train_loss'].append(train_loss)
    results['train_acc'].append(train_acc)

    # ---- evaluation pass ----
    model.eval()
    test_loss, test_correct, test_seen = 0.0, 0, 0
    with torch.inference_mode():
        for X, y in tqdm(test_dataloader, total=len(test_dataloader)):
            X, y = X.to(device), y.to(device)

            test_logits = model(X)
            test_loss += loss_fn(test_logits, y).item()
            test_correct += (test_logits.argmax(dim=1) == y).sum().item()
            test_seen += len(y)

    test_loss /= max(len(test_dataloader), 1)
    test_acc = test_correct / max(test_seen, 1)
    print(f'Training loss: {train_loss:.5f} acc: {train_acc:.5f} | Testing loss: {test_loss:.5f} acc: {test_acc:.5f}')

    results['test_loss'].append(test_loss)
    results['test_acc'].append(test_acc)

    # Save the model checkpoint after every epoch
    torch.save(model.state_dict(), f'{CHECKPOINT_PREFIX}{epoch}.pth')