Spaces:
Sleeping
Sleeping
import os | |
import zipfile | |
import numpy as np | |
import torch | |
from transformers import ViTForImageClassification, AdamW | |
import nibabel as nib | |
from PIL import Image | |
from torch.utils.data import Dataset, DataLoader | |
import streamlit as st | |
import requests | |
import tempfile | |
# Function to download zip files from URL
def download_zip(url, download_path, timeout=60, chunk_size=1 << 20):
    """Download ``url`` and write the response body to ``download_path``.

    The body is streamed to disk in chunks instead of being held in memory
    as a single ``bytes`` object.

    Args:
        url: HTTP(S) URL of the file to fetch.
        download_path: Destination file path; overwritten if it exists.
        timeout: Seconds ``requests`` waits on connect/read before giving
            up. Without an explicit timeout ``requests`` never times out,
            which would freeze the caller on a stalled server.
        chunk_size: Number of bytes read per chunk while streaming.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
            Previously the error page was silently saved as if it were the
            requested file, failing later with a confusing error.
        requests.RequestException: On connection problems or timeouts.
    """
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(download_path, 'wb') as file:
            for chunk in response.iter_content(chunk_size=chunk_size):
                file.write(chunk)
def extract_zip(zip_file, extract_to):
    """Unpack every member of the archive ``zip_file`` into ``extract_to``.

    Args:
        zip_file: Path (or file-like object) of the ``.zip`` archive.
        extract_to: Directory that receives the extracted members.
    """
    archive = zipfile.ZipFile(zip_file, mode='r')
    with archive:
        archive.extractall(path=extract_to)
# Preprocess images
def _nifti_volume_to_tensor(volume, image_size):
    """Convert a NIfTI voxel array into a ``(3, image_size, image_size)`` tensor in [0, 1].

    Singleton axes are squeezed away. A remaining 3-D volume is reduced to
    its middle slice along the last axis. That axis is assumed to be the
    slice axis — TODO confirm against the scans' orientation. The slice is
    min-max scaled to [0, 1], resized bilinearly and repeated to three
    channels, so it has the same layout as the RGB path of
    ``preprocess_image``.

    Args:
        volume: Array-like of voxel values (e.g. from ``get_fdata()``).
        image_size: Output height and width in pixels.

    Raises:
        ValueError: If the array is not 2-D or 3-D after squeezing.
    """
    data = np.squeeze(np.asarray(volume, dtype=np.float32))
    if data.ndim == 3:
        data = data[:, :, data.shape[-1] // 2]
    elif data.ndim != 2:
        raise ValueError(f"Unsupported NIfTI shape: {np.shape(volume)}")
    data = np.nan_to_num(data)
    lo, hi = float(data.min()), float(data.max())
    # Voxel values are floats of arbitrary range rather than 0-255 pixels,
    # so scale by the slice's own range; a constant slice maps to zeros.
    data = (data - lo) / (hi - lo) if hi > lo else np.zeros_like(data)
    slice_tensor = torch.from_numpy(np.ascontiguousarray(data))[None, None]  # (1, 1, H, W)
    resized = torch.nn.functional.interpolate(
        slice_tensor, size=(image_size, image_size), mode='bilinear', align_corners=False
    )
    return resized[0].repeat(3, 1, 1)


def preprocess_image(image_path, image_size=224):
    """Load an image file as a ``(3, image_size, image_size)`` float tensor in [0, 1].

    Supported formats, matched case-insensitively on the file name:
      * NIfTI (``.nii`` / ``.nii.gz``): converted by ``_nifti_volume_to_tensor``.
      * JPEG / PNG: converted to RGB, resized, and scaled from 0-255 to [0, 1].

    Args:
        image_path: Path to the image file.
        image_size: Output height/width in pixels. The default of 224
            matches the ``vit-base-patch16-224`` checkpoint used for training.

    Returns:
        torch.Tensor: Channels-first float32 tensor.

    Raises:
        ValueError: If the file extension is not supported.
    """
    lower_name = os.path.basename(image_path).lower()
    # os.path.splitext only sees the last suffix (".gz" for "x.nii.gz"),
    # so compound extensions are matched with endswith instead.
    if lower_name.endswith(('.nii', '.nii.gz')):
        volume = nib.load(image_path).get_fdata()
        return _nifti_volume_to_tensor(volume, image_size)
    if lower_name.endswith(('.jpg', '.jpeg', '.png')):
        # The context manager closes the file handle PIL keeps open.
        with Image.open(image_path) as img:
            rgb = img.convert('RGB').resize((image_size, image_size))
        pixels = torch.from_numpy(np.array(rgb)).permute(2, 0, 1).float()
        return pixels / 255.0
    ext = os.path.splitext(image_path)[-1].lower()
    raise ValueError(f"Unsupported format: {ext}")
# Prepare dataset
def prepare_dataset(extracted_folder, class_folders=None, root_subdir='neuroniiimages'):
    """Collect image file paths and integer class labels from an extracted dataset.

    Expected layout: ``<extracted_folder>/<root_subdir>/<class folder>/<files>``.

    Args:
        extracted_folder: Directory the dataset archive was extracted into.
        class_folders: Mapping of class sub-folder name to integer label.
            Defaults to the archive's three folders: ``alzheimer_datasets`` -> 0,
            ``parkinson_datasets`` -> 1, ``MSjpg`` -> 2.
        root_subdir: Folder inside ``extracted_folder`` that holds the class folders.

    Returns:
        tuple[list[str], list[int]]: Parallel lists of file paths and labels,
        ordered by class (mapping order), then by file name.

    Raises:
        FileNotFoundError: If a class folder does not exist.
    """
    if class_folders is None:
        class_folders = {
            'alzheimer_datasets': 0,
            'parkinson_datasets': 1,
            'MSjpg': 2,
        }
    image_exts = ('.nii', '.jpg', '.jpeg')
    image_paths = []
    labels = []
    for disease_folder, label in class_folders.items():
        folder_path = os.path.join(extracted_folder, root_subdir, disease_folder)
        # os.listdir order is filesystem-dependent; sort so runs are reproducible.
        for img_file in sorted(os.listdir(folder_path)):
            full_path = os.path.join(folder_path, img_file)
            # Match case-insensitively so ".JPG"/".NII" files are not silently
            # dropped, and skip directories that merely look like image names.
            if img_file.lower().endswith(image_exts) and os.path.isfile(full_path):
                image_paths.append(full_path)
                labels.append(label)
    return image_paths, labels
class CustomImageDataset(Dataset):
    """Map-style dataset pairing image file paths with class labels.

    Images are loaded lazily: each ``__getitem__`` call runs
    ``preprocess_image`` on the requested path and returns
    ``(image_tensor, label)``.
    """

    def __init__(self, image_paths, labels):
        # Parallel sequences: labels[i] belongs to image_paths[i].
        self.image_paths, self.labels = image_paths, labels

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        return preprocess_image(self.image_paths[idx]), self.labels[idx]
# Training function
def fine_tune_model(train_loader, num_epochs=10, lr=1e-4,
                    model_name='google/vit-base-patch16-224-in21k', num_labels=3,
                    model=None, device=None):
    """Fine-tune an image classifier and return the final epoch's mean loss.

    Args:
        train_loader: Sized iterable (e.g. a ``DataLoader``) yielding
            ``(images, labels)`` batches; ``images`` is passed to the model
            as ``pixel_values`` and the model output must expose ``.logits``.
        num_epochs: Number of full passes over ``train_loader``.
        lr: AdamW learning rate.
        model_name: Checkpoint loaded via ``ViTForImageClassification`` when
            ``model`` is not supplied.
        num_labels: Size of the classification head of a freshly loaded model.
        model: Optional pre-built model to train in place. The caller keeps
            the reference and therefore the trained weights; a model loaded
            internally is discarded when this function returns.
        device: Device (or device string) to train on; defaults to CUDA when
            available, otherwise CPU.

    Returns:
        float: Mean per-batch loss over the last epoch.

    Raises:
        ValueError: If ``train_loader`` has no batches or ``num_epochs`` < 1
            (the average loss would otherwise be a division by zero or unset).
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader yields no batches")
    if num_epochs < 1:
        raise ValueError("num_epochs must be at least 1")

    if model is None:
        model = ViTForImageClassification.from_pretrained(model_name, num_labels=num_labels)
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.train()

    # transformers.AdamW is deprecated in favour of torch.optim.AdamW; eps and
    # weight_decay are set to the transformers defaults so updates stay close.
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, eps=1e-6, weight_decay=0.0)
    criterion = torch.nn.CrossEntropyLoss()

    epoch_loss = 0.0
    for _ in range(num_epochs):
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            logits = model(pixel_values=images).logits
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        epoch_loss = running_loss / num_batches
    return epoch_loss
# Streamlit UI for fine-tuning: on "Start Training", download the dataset
# archive, extract it, build the dataset and train the model.
st.title("Fine-tune ViT on MRI/CT Scans for MS & Neurodegenerative Diseases")
zip_url = "https://huggingface.co/spaces/Tanusree88/ViT-MRI-FineTuning/resolve/main/neuroniiimages.zip"
if st.button("Start Training"):
    extraction_dir = "extracted_files"
    os.makedirs(extraction_dir, exist_ok=True)
    # Only reserve a temp path here. The handle is closed before the file is
    # written by name, because re-opening an open NamedTemporaryFile fails on
    # Windows.
    with tempfile.NamedTemporaryFile(suffix='.zip', delete=False) as tmp_file:
        zip_path = tmp_file.name
    try:
        with st.spinner("Downloading and extracting dataset..."):
            download_zip(zip_url, zip_path)
            extract_zip(zip_path, extraction_dir)
    except (requests.RequestException, zipfile.BadZipFile) as exc:
        st.error(f"Could not download or extract the dataset: {exc}")
        st.stop()
    finally:
        # delete=False means nothing else removes the archive.
        os.remove(zip_path)

    image_paths, labels = prepare_dataset(extraction_dir)
    if not image_paths:
        st.error("No images found in the specified directory. Please check the folder structure.")
    else:
        dataset = CustomImageDataset(image_paths, labels)
        train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
        with st.spinner("Fine-tuning model..."):
            final_loss = fine_tune_model(train_loader)
        st.write(f"Training Complete with Final Loss: {final_loss}")