Fine-tuned Vision Transformer for Alzheimer's Detection
This repository hosts a Vision Transformer (ViT) model fine-tuned on the OASIS MRI dataset to classify brain MRI images by the progression of Alzheimer's disease. The model assigns each image to one of four classes: Non Demented, Very mild Dementia, Mild Dementia, and Moderate Dementia.
Model Description
The Vision Transformer is adapted here to medical image analysis. Its self-attention mechanism relates every image patch to every other patch, which lets it capture patterns spread across the whole scan. It has been fine-tuned to classify MRI images into stages of Alzheimer's disease, illustrating how the architecture can be applied to medical diagnostics.
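Concretely, a ViT does not process pixels one by one. It cuts the image into fixed-size patches, flattens each patch into a vector, and runs self-attention over that sequence of patch tokens. The minimal NumPy sketch below shows only the patch-splitting step. It assumes 16x16 patches on 224x224 RGB inputs, the common ViT-Base setting. The card does not state this checkpoint's patch size, and `patchify` is a hypothetical helper, not a `transformers` function.

```python
import numpy as np


def patchify(image: np.ndarray, patch_size: int = 16) -> np.ndarray:
    """Split an (H, W, C) image into flattened, non-overlapping patches.

    Returns shape (num_patches, patch_size * patch_size * C): the token
    sequence a ViT embeds before self-attention (patch size is an assumption).
    """
    h, w, c = image.shape
    if h % patch_size or w % patch_size:
        raise ValueError("image height and width must be divisible by patch_size")
    # (H, W, C) -> (H/p, p, W/p, p, C)
    patches = image.reshape(h // patch_size, patch_size, w // patch_size, patch_size, c)
    # Group the two grid axes together: (H/p, W/p, p, p, C)
    patches = patches.transpose(0, 2, 1, 3, 4)
    # Flatten each patch into one vector, in raster order over the patch grid
    return patches.reshape(-1, patch_size * patch_size * c)


tokens = patchify(np.zeros((224, 224, 3), dtype=np.float32))
print(tokens.shape)  # (196, 768)
```

In the full model, each flattened patch is then linearly projected. A learned position embedding is added and a classification token is prepended before the transformer layers.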
Dataset
The OASIS MRI dataset used here contains 80,000 brain MRI images from 461 patients. The scans were originally stored in NIfTI (.nii) format and were converted to JPEG for model training. Each image is labeled with one of the following stages of Alzheimer's disease:
- Non Demented
- Very mild Dementia
- Mild Dementia
- Moderate Dementia
The conversion standardized the image format so that every image can be loaded by standard deep learning data pipelines.
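Below is a minimal sketch of the NIfTI-to-JPEG step. It assumes each 3D scan is split into 2D slices along its last axis and that intensities are min-max scaled per volume to 8 bits. The card does not say which axis or intensity mapping was used. `volume_to_jpegs` is a hypothetical helper, and loading the volume with nibabel is also an assumption.

```python
from pathlib import Path

import numpy as np
from PIL import Image


def volume_to_jpegs(volume: np.ndarray, out_dir: str, prefix: str, axis: int = 2) -> list[Path]:
    """Save every 2D slice of a 3D MRI volume along `axis` as an 8-bit grayscale JPEG.

    `volume` could come from e.g. `nibabel.load("scan.nii").get_fdata()`
    (nibabel and the slicing axis are assumptions, not the author's pipeline).
    """
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    # Min-max scale the whole volume to 0..255 so each slice fits in 8-bit JPEG
    vmin, vmax = float(volume.min()), float(volume.max())
    scale = 255.0 / (vmax - vmin) if vmax > vmin else 0.0
    scaled = np.clip((volume - vmin) * scale, 0, 255).astype(np.uint8)
    paths = []
    for i in range(scaled.shape[axis]):
        slice_2d = np.take(scaled, i, axis=axis)  # 2D uint8 array -> Pillow mode "L"
        path = out / f"{prefix}_slice{i:03d}.jpg"
        Image.fromarray(slice_2d).save(path, format="JPEG")
        paths.append(path)
    return paths
```

Slicing each scan this way is one plausible reason why 461 patients yield tens of thousands of 2D images.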
Preprocessing Techniques
During preprocessing:
- MRI scans were converted from NIfTI format to JPEG to simplify handling and reduce storage requirements.
- Each image was resized to 128x128 pixels for uniformity across the dataset.
- Pixel values were normalized to a [0, 1] scale to facilitate model training.
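The resizing and normalization steps can be sketched for a single slice as follows. This is an illustration, not the author's pipeline. The hypothetical `preprocess_slice` helper assumes the slice is expanded to 3 RGB channels, matching the inference example's input. It also uses Pillow's default resampling filter. The card specifies neither choice.

```python
import numpy as np
from PIL import Image


def preprocess_slice(image: Image.Image, size: int = 128) -> np.ndarray:
    """Resize an MRI slice to size x size and scale pixel values to [0, 1]."""
    # Assumption: replicate grayscale into 3 channels, as the ViT expects RGB input
    resized = image.convert("RGB").resize((size, size))  # Pillow's default filter
    # uint8 values 0..255 -> float32 values in [0, 1]
    return np.asarray(resized, dtype=np.float32) / 255.0
```

The inference example in How to Use This Model instead resizes to 224x224 and applies ImageNet mean/std normalization. Use whichever preprocessing the checkpoint was actually trained with.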
How to Use This Model
You can load the model with `ViTForImageClassification` and classify a single image as follows:
```python
import torch
import matplotlib.pyplot as plt
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from transformers import ViTForImageClassification

# Map class indices to label names
id2label = {0: "Mild Dementia", 1: "Moderate Dementia", 2: "Non Demented", 3: "Very mild Dementia"}

# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model
model = ViTForImageClassification.from_pretrained("fawadkhan/ViT_FineTuned_on_ImagesOASIS")
model.to(device)
model.eval()

# Load the image (replace with the path to your own image)
image_path = "your_image_path.jpg"
image = Image.open(image_path).convert("RGB")

# Define the transformations
transform = Compose([
    Resize((224, 224)),  # or the original input size of your model
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # standard ImageNet normalization
])

# Preprocess the image: add a batch dimension and move it to the device
input_tensor = transform(image).unsqueeze(0).to(device)

# Predict
with torch.no_grad():
    outputs = model(input_tensor)
    _, predicted = torch.max(outputs.logits, 1)

# Retrieve the class name
predicted_class = id2label[predicted[0].item()]
print("Predicted class:", predicted_class)

# Plot the image and the prediction
plt.imshow(image)
plt.title(f"Predicted class: {predicted_class}")
plt.axis("off")  # hide axis numbers and ticks
plt.show()
```
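The example above reports only the top class. To see the model's confidence across all four classes, pass the logits through a softmax. The sketch below assumes the `outputs` and `id2label` objects from that example. `class_probabilities` is a hypothetical helper.

```python
import torch


def class_probabilities(logits: torch.Tensor, id2label: dict[int, str]) -> list[tuple[str, float]]:
    """Turn one image's logits (shape [1, num_classes]) into (label, probability) pairs, most likely first."""
    probs = torch.softmax(logits, dim=-1)[0]  # probabilities over classes sum to 1
    order = torch.argsort(probs, descending=True)
    return [(id2label[i], probs[i].item()) for i in order.tolist()]
```

For example, `class_probabilities(outputs.logits, id2label)[0]` gives the predicted label together with its probability.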
Training Procedure
The model was trained with the AdamW optimizer at a learning rate of 5e-5 for 10 epochs, a setting chosen to balance accuracy against the risk of overfitting.
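Below is a minimal training-loop sketch that matches the stated hyperparameters (AdamW, learning rate 5e-5, 10 epochs). Everything else is an assumption. The hypothetical `fine_tune` function expects a `DataLoader` that yields `(pixel_values, labels)` batches. It uses PyTorch's default AdamW weight decay and applies no learning-rate schedule or early stopping. The card documents neither the batch size nor these settings.

```python
import torch
from torch.utils.data import DataLoader
from transformers import ViTForImageClassification


def fine_tune(model: ViTForImageClassification, train_loader: DataLoader,
              epochs: int = 10, lr: float = 5e-5, device: str = "cpu") -> list[float]:
    """Fine-tune with AdamW and return the mean training loss of each epoch."""
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)  # default weight decay (assumption)
    epoch_losses = []
    for _ in range(epochs):
        model.train()
        total, batches = 0.0, 0
        for pixel_values, labels in train_loader:
            pixel_values, labels = pixel_values.to(device), labels.to(device)
            # Passing labels makes the model compute cross-entropy loss internally
            outputs = model(pixel_values=pixel_values, labels=labels)
            optimizer.zero_grad()
            outputs.loss.backward()
            optimizer.step()
            total += outputs.loss.item()
            batches += 1
        epoch_losses.append(total / max(batches, 1))
    return epoch_losses
```

One plausible starting point is a pretrained checkpoint such as `google/vit-base-patch16-224-in21k`, loaded with `ViTForImageClassification.from_pretrained(..., num_labels=4)`. The card does not say which base checkpoint was used.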
Evaluation Results
On a validation set, the model achieved 99% accuracy in distinguishing the four stages of Alzheimer's disease from MRI scans.
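Accuracy here is the fraction of validation images whose highest-scoring class matches the true label. The minimal sketch below assumes a `DataLoader` that yields `(pixel_values, labels)` batches. It also assumes a Hugging Face-style model whose output has a `.logits` field. `evaluate_accuracy` is a hypothetical helper.

```python
import torch
from torch.utils.data import DataLoader


@torch.no_grad()
def evaluate_accuracy(model: torch.nn.Module, val_loader: DataLoader, device: str = "cpu") -> float:
    """Return the fraction of validation images whose argmax prediction equals the label."""
    model.to(device)
    model.eval()
    correct, total = 0, 0
    for pixel_values, labels in val_loader:
        logits = model(pixel_values=pixel_values.to(device)).logits
        preds = logits.argmax(dim=-1)  # index of the highest-scoring class per image
        correct += (preds == labels.to(device)).sum().item()
        total += labels.numel()
    return correct / total if total else 0.0
```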