Tanusree88 commited on
Commit
bb90c09
·
verified ·
1 Parent(s): 6ca8ca8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -23
app.py CHANGED
@@ -2,22 +2,24 @@ import os
2
  import zipfile
3
  import numpy as np
4
  import torch
5
- import requests
6
  from transformers import ViTForImageClassification, AdamW
7
  import nibabel as nib
8
  from PIL import Image
9
  from torch.utils.data import Dataset, DataLoader
10
  import streamlit as st
11
- from io import BytesIO
 
12
 
13
- # Function to download and extract zip file from URL
14
- def extract_zip_from_url(url, extract_to):
15
  response = requests.get(url)
16
- if response.status_code == 200:
17
- with zipfile.ZipFile(BytesIO(response.content)) as zip_ref:
18
- zip_ref.extractall(extract_to)
19
- else:
20
- raise ValueError(f"Unable to download zip file: {url}")
 
 
21
 
22
  # Preprocess images
23
  def preprocess_image(image_path):
@@ -45,13 +47,21 @@ def preprocess_image(image_path):
45
  def prepare_dataset(extracted_folder):
46
  image_paths = []
47
  labels = []
48
- for disease_folder in ['alzheimers_dataset', 'parkinsons_dataset', 'MSjpg']:
49
- folder_path = os.path.join(extracted_folder, disease_folder)
50
- label = {'alzheimers_dataset': 0, 'parkinsons_dataset': 1, 'MSjpg': 2}[disease_folder]
 
 
 
 
 
 
 
51
  for img_file in os.listdir(folder_path):
52
  if img_file.endswith(('.nii', '.jpg', '.jpeg')):
53
  image_paths.append(os.path.join(folder_path, img_file))
54
  labels.append(label)
 
55
  return image_paths, labels
56
 
57
  # Custom Dataset class
@@ -92,23 +102,29 @@ def fine_tune_model(train_loader):
92
  # Streamlit UI for Fine-tuning
93
  st.title("Fine-tune ViT on MRI/CT Scans for MS & Neurodegenerative Diseases")
94
 
95
- # Input zip file URL
96
- zip_file_url = st.text_input("https://huggingface.co/spaces/Tanusree88/ViT-MRI-FineTuning/resolve/main/neuroniiimages.zip")
97
  if st.button("Start Training"):
98
- extraction_dir = "https://huggingface.co/spaces/Tanusree88/ViT-MRI-FineTuning/resolve/main/extracttedfiles"
99
  os.makedirs(extraction_dir, exist_ok=True)
100
 
101
- # Extract the zip file from URL
102
- extract_zip_from_url(zip_file_url, extraction_dir)
 
 
 
 
103
 
104
  # Prepare dataset
105
  image_paths, labels = prepare_dataset(extraction_dir)
106
  dataset = CustomImageDataset(image_paths, labels)
107
- train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
108
-
109
- # Fine-tune the model
110
- final_loss = fine_tune_model(train_loader)
111
- st.write(f"Training Complete with Final Loss: {final_loss}")
112
-
113
 
 
 
 
114
 
 
2
  import zipfile
3
  import numpy as np
4
  import torch
 
5
  from transformers import ViTForImageClassification, AdamW
6
  import nibabel as nib
7
  from PIL import Image
8
  from torch.utils.data import Dataset, DataLoader
9
  import streamlit as st
10
+ import requests
11
+ import tempfile
12
 
13
+ # Function to download zip files from URL
14
+ def download_zip(url, download_path):
15
  response = requests.get(url)
16
+ with open(download_path, 'wb') as file:
17
+ file.write(response.content)
18
+
19
+ # Function to extract zip files
20
+ def extract_zip(zip_file, extract_to):
21
+ with zipfile.ZipFile(zip_file, 'r') as zip_ref:
22
+ zip_ref.extractall(extract_to)
23
 
24
  # Preprocess images
25
  def preprocess_image(image_path):
 
47
  def prepare_dataset(extracted_folder):
48
  image_paths = []
49
  labels = []
50
+
51
+ # Define the paths for each disease dataset
52
+ datasets = {
53
+ 'alzheimer_datasets': 0,
54
+ 'parkinson_datasets': 1,
55
+ 'MSjpg': 2
56
+ }
57
+
58
+ for disease_folder, label in datasets.items():
59
+ folder_path = os.path.join(extracted_folder, 'neuroniiimages', disease_folder)
60
  for img_file in os.listdir(folder_path):
61
  if img_file.endswith(('.nii', '.jpg', '.jpeg')):
62
  image_paths.append(os.path.join(folder_path, img_file))
63
  labels.append(label)
64
+
65
  return image_paths, labels
66
 
67
  # Custom Dataset class
 
102
  # Streamlit UI for Fine-tuning
103
  st.title("Fine-tune ViT on MRI/CT Scans for MS & Neurodegenerative Diseases")
104
 
105
+ zip_url = "https://huggingface.co/spaces/Tanusree88/ViT-MRI-FineTuning/resolve/main/neuroniiimages.zip"
106
+
107
  if st.button("Start Training"):
108
+ extraction_dir = "extracted_files"
109
  os.makedirs(extraction_dir, exist_ok=True)
110
 
111
+ # Download the zip file to a temporary file
112
+ with tempfile.NamedTemporaryFile(suffix='.zip', delete=False) as tmp_file:
113
+ download_zip(zip_url, tmp_file.name)
114
+
115
+ # Extract the zip file
116
+ extract_zip(tmp_file.name, extraction_dir)
117
 
118
  # Prepare dataset
119
  image_paths, labels = prepare_dataset(extraction_dir)
120
  dataset = CustomImageDataset(image_paths, labels)
121
+
122
+ if len(image_paths) == 0:
123
+ st.error("No images found in the specified directory. Please check the folder structure.")
124
+ else:
125
+ train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
 
126
 
127
+ # Fine-tune the model
128
+ final_loss = fine_tune_model(train_loader)
129
+ st.write(f"Training Complete with Final Loss: {final_loss}")
130