Fine-tuned Vision Transformer for Alzheimer's Detection

This repository hosts a Vision Transformer (ViT) model fine-tuned on the OASIS MRI dataset to classify brain MRI images by the stage of Alzheimer's disease. The model assigns each image to one of four classes: demented, very mild demented, mild demented, and non-demented.

Model Description

The Vision Transformer has been adapted to medical image analysis. Its self-attention layers relate every image patch to every other patch, which lets it capture patterns spread across the whole image. Here it has been fine-tuned to classify MRI images into stages of Alzheimer's disease, illustrating how the architecture can be applied to medical diagnostic tasks.

Dataset

The OASIS MRI dataset consists of 80,000 brain MRI images from 461 patients. The scans, originally stored in NIfTI (.nii) format, were converted to JPEG for model training. The images are labelled with the following stages of Alzheimer's disease:

  • Non-Demented
  • Very Mild Demented
  • Mild Demented
  • Demented

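The conversion standardized the image format so that every image can be read by common deep learning image-loading pipelines. In the model's label mapping (id2label in the usage example), the four classes are named "Non Demented", "Very mild Dementia", "Mild Dementia", and "Moderate Dementia"; the last corresponds to "Demented" above.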

Preprocessing Techniques

During preprocessing:

  • MRI scans were converted from NIfTI to JPEG to simplify handling and reduce storage requirements.
  • Each image was resized to 128x128 pixels so that all inputs share the same dimensions.
  • Pixel values were scaled to the [0, 1] range to facilitate model training.
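The preprocessing steps above can be sketched as follows. This is a minimal illustration under stated assumptions, not the script used to build the dataset: it assumes the nibabel package for reading NIfTI files, takes every slice along the last axis of the volume, and min-max scales each slice to 8-bit grayscale before saving it as JPEG. The [0, 1] scaling is applied when a JPEG is loaded back. The function names slice_to_image, nifti_to_jpegs and load_and_normalize are hypothetical.

```python
from pathlib import Path

import numpy as np
from PIL import Image


def slice_to_image(slice_2d: np.ndarray, size: int = 128) -> Image.Image:
    """Min-max scale a 2-D slice to 8-bit grayscale and resize it to size x size pixels."""
    data = slice_2d.astype(np.float32)
    lo, hi = float(data.min()), float(data.max())
    if hi > lo:
        data = (data - lo) / (hi - lo)
    else:
        data = np.zeros_like(data)  # constant slice: avoid division by zero
    # A 2-D uint8 array becomes a single-channel ("L") PIL image
    return Image.fromarray((data * 255).round().astype(np.uint8)).resize((size, size))


def nifti_to_jpegs(nifti_path: str, out_dir: str, size: int = 128) -> int:
    """Write every slice along the last axis of a NIfTI volume as a JPEG; return the count."""
    import nibabel as nib  # assumption: nibabel is used to read the .nii files

    # Drop singleton dimensions, e.g. a trailing time axis of length 1
    volume = np.squeeze(np.asarray(nib.load(nifti_path).get_fdata()))
    if volume.ndim != 3:
        raise ValueError(f"expected a 3-D volume, got shape {volume.shape}")
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    stem = Path(nifti_path).name.split(".")[0]
    for i in range(volume.shape[-1]):
        slice_to_image(volume[..., i], size).save(out / f"{stem}_slice{i:03d}.jpg")
    return volume.shape[-1]


def load_and_normalize(jpeg_path: str) -> np.ndarray:
    """Load a JPEG slice as float32 with pixel values scaled to [0, 1]."""
    return np.asarray(Image.open(jpeg_path).convert("L"), dtype=np.float32) / 255.0
```

Min-max scaling per slice is one common choice for mapping MRI intensities, which have no fixed range, onto 8-bit pixels; other intensity windows would work with the same structure.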

How to Use This Model

You can load the model with ViTForImageClassification from the transformers library and classify a single image as follows:

```python
import torch
import matplotlib.pyplot as plt
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from transformers import ViTForImageClassification

# Mapping from class index to class name used by this model
id2label = {
    0: "Mild Dementia",
    1: "Moderate Dementia",
    2: "Non Demented",
    3: "Very mild Dementia",
}

# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model
model = ViTForImageClassification.from_pretrained("fawadkhan/ViT_FineTuned_on_ImagesOASIS")
model.to(device)
model.eval()

# Define the image path
image_path = "your_image_path.jpg"  # replace with the path to your MRI image
image = Image.open(image_path).convert("RGB")

# Define the transformations
image_size = model.config.image_size  # input resolution the model expects
transform = Compose([
    Resize((image_size, image_size)),
    ToTensor(),
    # Standard ImageNet normalization; use the same statistics as during fine-tuning
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Preprocess the image
input_tensor = transform(image).unsqueeze(0)  # add a batch dimension: (1, 3, H, W)
input_tensor = input_tensor.to(device)

# Predict
with torch.no_grad():
    outputs = model(input_tensor)
    _, predicted = torch.max(outputs.logits, 1)

# Retrieve the class name
predicted_class = id2label[predicted[0].item()]
print("Predicted class:", predicted_class)

# Plot the image and the prediction
plt.imshow(image)
plt.title(f"Predicted class: {predicted_class}")
plt.axis("off")  # turn off axis numbers and ticks
plt.show()
```
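The usage example reports only the highest-scoring class. To see how confident the model is across all four classes, the logits can be converted to probabilities with a softmax. The sketch below assumes logits is a tensor of shape (batch, 4), such as outputs.logits from the usage example, and uses the same id2label mapping; the helper name top_probabilities is hypothetical.

```python
import torch

id2label = {0: "Mild Dementia", 1: "Moderate Dementia", 2: "Non Demented", 3: "Very mild Dementia"}


def top_probabilities(logits: torch.Tensor, id2label: dict[int, str]) -> list[tuple[str, float]]:
    """Return (label, probability) pairs for the first image in the batch, most likely first."""
    probs = torch.softmax(logits[0], dim=-1)  # probabilities over the 4 classes sum to 1
    order = torch.argsort(probs, descending=True)
    return [(id2label[int(i)], float(probs[i])) for i in order]


# Usage with dummy logits; in practice pass outputs.logits from the model
dummy_logits = torch.tensor([[0.2, -1.0, 3.0, 0.5]])
for label, p in top_probabilities(dummy_logits, id2label):
    print(f"{label}: {p:.3f}")
```

The arg-max class is always the first entry, so this gives the same prediction as torch.max while also exposing how close the runner-up classes are.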

Training Procedure

The model was trained with the AdamW optimizer at a learning rate of 5e-5 for 10 epochs, a schedule chosen to balance accuracy against the risk of overfitting.
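A minimal PyTorch sketch of this setup follows. It is not the original training script: only the optimizer (AdamW), the learning rate (5e-5) and the epoch count (10) come from the description above, while the batch size, data loading and the tiny randomly initialized ViT in the usage example are assumptions for illustration. The helper name train is hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from transformers import ViTConfig, ViTForImageClassification


def train(model, loader: DataLoader, num_epochs: int = 10, lr: float = 5e-5,
          device: str = "cpu") -> list[float]:
    """Fine-tune with AdamW and return the mean training loss of each epoch."""
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    epoch_losses = []
    for _ in range(num_epochs):
        model.train()
        total, batches = 0.0, 0
        for pixel_values, labels in loader:
            pixel_values, labels = pixel_values.to(device), labels.to(device)
            # Passing integer labels makes the model compute cross-entropy loss internally
            outputs = model(pixel_values=pixel_values, labels=labels)
            optimizer.zero_grad()
            outputs.loss.backward()
            optimizer.step()
            total += outputs.loss.item()
            batches += 1
        epoch_losses.append(total / max(batches, 1))
    return epoch_losses


if __name__ == "__main__":
    # Hypothetical tiny ViT and random data, only to show the loop runs end to end
    config = ViTConfig(image_size=32, patch_size=8, hidden_size=32, num_hidden_layers=2,
                       num_attention_heads=2, intermediate_size=64, num_labels=4)
    model = ViTForImageClassification(config)
    data = TensorDataset(torch.randn(16, 3, 32, 32), torch.randint(0, 4, (16,)))
    losses = train(model, DataLoader(data, batch_size=4, shuffle=True), num_epochs=10)
    print([round(loss, 3) for loss in losses])
```

To fine-tune the real model, the tiny config would be replaced by a pretrained checkpoint loaded with ViTForImageClassification.from_pretrained and the random tensors by the preprocessed MRI images.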

Evaluation Results

On a validation set, the model achieved an accuracy of 99% in classifying the four stages of Alzheimer's disease from MRI scans.
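Accuracy here is the fraction of validation images whose highest-scoring class matches the true label. A minimal sketch of that computation, assuming a PyTorch DataLoader that yields (pixel_values, labels) batches and a model whose output has a logits attribute (as ViTForImageClassification's does); evaluate_accuracy is a hypothetical helper, and the split behind the reported figure is not described here.

```python
import torch
from torch.utils.data import DataLoader


@torch.no_grad()
def evaluate_accuracy(model, loader: DataLoader, device: str = "cpu") -> float:
    """Return the fraction of examples whose arg-max prediction equals the label."""
    model.to(device)
    model.eval()
    correct, total = 0, 0
    for pixel_values, labels in loader:
        logits = model(pixel_values=pixel_values.to(device)).logits
        preds = logits.argmax(dim=-1)
        correct += (preds == labels.to(device)).sum().item()
        total += labels.numel()
    return correct / total if total else 0.0
```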

Model Size

The model has 85.8M parameters, stored as F32 tensors in Safetensors format.