Tanusree88 commited on
Commit
c8604b9
·
verified ·
1 Parent(s): c167ae6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +134 -2
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import zipfile
 
3
  import numpy as np
4
  import torch
5
  from transformers import ViTForImageClassification, AdamW
@@ -8,6 +9,12 @@ from PIL import Image
8
  from torch.utils.data import Dataset, DataLoader
9
  import streamlit as st
10
 
 
 
 
 
 
 
11
  # Function to extract zip file
12
  def extract_zip(zip_file, extract_to):
13
  with zipfile.ZipFile(zip_file, 'r') as zip_ref:
@@ -86,14 +93,115 @@ def fine_tune_model(train_loader):
86
  # Streamlit UI for Fine-tuning
87
  st.title("Fine-tune ViT on MRI/CT Scans for MS & Neurodegenerative Diseases")
88
 
89
- zip_file = "https://huggingface.co/spaces/Tanusree88/ViT-MRI-FineTuning/resolve/main/neuroniiimages.zip"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  if st.button("Start Training"):
92
  extraction_dir = "extracted_files"
 
93
  os.makedirs(extraction_dir, exist_ok=True)
94
 
 
 
 
 
95
  # Extract the zip file
96
- extract_zip(zip_file, extraction_dir)
 
97
 
98
  # Prepare dataset
99
  image_paths, labels = prepare_dataset(extraction_dir)
@@ -101,8 +209,32 @@ if st.button("Start Training"):
101
  train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
102
 
103
  # Fine-tune the model
 
104
  final_loss = fine_tune_model(train_loader)
105
  st.write(f"Training Complete with Final Loss: {final_loss}")
 
 
 
 
 
 
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
 
 
1
  import os
2
  import zipfile
3
+ import requests
4
  import numpy as np
5
  import torch
6
  from transformers import ViTForImageClassification, AdamW
 
9
  from torch.utils.data import Dataset, DataLoader
10
  import streamlit as st
11
 
12
+ # Function to download the zip file from the URL
13
+ def download_zip(url, save_path):
14
+ response = requests.get(url)
15
+ with open(save_path, 'wb') as f:
16
+ f.write(response.content)
17
+
18
  # Function to extract zip file
19
  def extract_zip(zip_file, extract_to):
20
  with zipfile.ZipFile(zip_file, 'r') as zip_ref:
 
93
  # Streamlit UI for Fine-tuning
94
  st.title("Fine-tune ViT on MRI/CT Scans for MS & Neurodegenerative Diseases")
95
 
96
+ zip_file_url = "import os
97
+ import zipfile
98
+ import requests
99
+ import numpy as np
100
+ import torch
101
+ from transformers import ViTForImageClassification, AdamW
102
+ import nibabel as nib
103
+ from PIL import Image
104
+ from torch.utils.data import Dataset, DataLoader
105
+ import streamlit as st
106
+
107
+ # Function to download the zip file from the URL
108
+ def download_zip(url, save_path):
109
+ response = requests.get(url)
110
+ with open(save_path, 'wb') as f:
111
+ f.write(response.content)
112
+
113
+ # Function to extract zip file
114
+ def extract_zip(zip_file, extract_to):
115
+ with zipfile.ZipFile(zip_file, 'r') as zip_ref:
116
+ zip_ref.extractall(extract_to)
117
+
118
+ # Preprocess images
119
+ def preprocess_image(image_path):
120
+ ext = os.path.splitext(image_path)[-1].lower()
121
+
122
+ if ext in ['.nii', '.nii.gz']:
123
+ nii_image = nib.load(image_path)
124
+ image_data = nii_image.get_fdata()
125
+ image_tensor = torch.tensor(image_data).float()
126
+ if len(image_tensor.shape) == 3:
127
+ image_tensor = image_tensor.unsqueeze(0)
128
+
129
+ elif ext in ['.jpg', '.jpeg']:
130
+ img = Image.open(image_path).convert('RGB').resize((224, 224))
131
+ img_np = np.array(img)
132
+ image_tensor = torch.tensor(img_np).permute(2, 0, 1).float()
133
+
134
+ else:
135
+ raise ValueError(f"Unsupported format: {ext}")
136
+
137
+ image_tensor /= 255.0 # Normalize to [0, 1]
138
+ return image_tensor
139
+
140
+ # Prepare dataset
141
+ def prepare_dataset(extracted_folder):
142
+ image_paths = []
143
+ labels = []
144
+ for disease_folder in ['alzheimers_dataset', 'parkinsons_dataset', 'ms']:
145
+ folder_path = os.path.join(extracted_folder, disease_folder)
146
+ label = {'alzheimers_dataset': 0, 'parkinsons_dataset': 1, 'ms': 2}[disease_folder]
147
+ for img_file in os.listdir(folder_path):
148
+ if img_file.endswith(('.nii', '.jpg', '.jpeg')):
149
+ image_paths.append(os.path.join(folder_path, img_file))
150
+ labels.append(label)
151
+ return image_paths, labels
152
+
153
+ # Custom Dataset class
154
+ class CustomImageDataset(Dataset):
155
+ def __init__(self, image_paths, labels):
156
+ self.image_paths = image_paths
157
+ self.labels = labels
158
+
159
+ def __len__(self):
160
+ return len(self.image_paths)
161
+
162
+ def __getitem__(self, idx):
163
+ image = preprocess_image(self.image_paths[idx])
164
+ label = self.labels[idx]
165
+ return image, label
166
+
167
+ # Training function
168
+ def fine_tune_model(train_loader):
169
+ model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=3)
170
+ model.train()
171
+ optimizer = AdamW(model.parameters(), lr=1e-4)
172
+ criterion = torch.nn.CrossEntropyLoss()
173
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
174
+ model.to(device)
175
+
176
+ for epoch in range(10):
177
+ running_loss = 0.0
178
+ for images, labels in train_loader:
179
+ images, labels = images.to(device), labels.to(device)
180
+ optimizer.zero_grad()
181
+ outputs = model(pixel_values=images).logits
182
+ loss = criterion(outputs, labels)
183
+ loss.backward()
184
+ optimizer.step()
185
+ running_loss += loss.item()
186
+ return running_loss / len(train_loader)
187
+
188
+ # Streamlit UI for Fine-tuning
189
+ st.title("Fine-tune ViT on MRI/CT Scans for MS & Neurodegenerative Diseases")
190
+
191
+ zip_file_url = "https://huggingface.co/spaces/Tanusree88/ViT-MRI-FineTuning/resolve/main/archive%20(5).zip"
192
 
193
  if st.button("Start Training"):
194
  extraction_dir = "extracted_files"
195
+ zip_file_path = "archive_5.zip"
196
  os.makedirs(extraction_dir, exist_ok=True)
197
 
198
+ # Download the zip file
199
+ st.write("Downloading the zip file...")
200
+ download_zip(zip_file_url, zip_file_path)
201
+
202
  # Extract the zip file
203
+ st.write("Extracting files...")
204
+ extract_zip(zip_file_path, extraction_dir)
205
 
206
  # Prepare dataset
207
  image_paths, labels = prepare_dataset(extraction_dir)
 
209
  train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
210
 
211
  # Fine-tune the model
212
+ st.write("Fine-tuning the model...")
213
  final_loss = fine_tune_model(train_loader)
214
  st.write(f"Training Complete with Final Loss: {final_loss}")
215
+ "
216
+
217
+ if st.button("Start Training"):
218
+ extraction_dir = "extracted_files"
219
+ zip_file_path = "archive_5.zip"
220
+ os.makedirs(extraction_dir, exist_ok=True)
221
 
222
+ # Download the zip file
223
+ st.write("Downloading the zip file...")
224
+ download_zip(zip_file_url, zip_file_path)
225
+
226
+ # Extract the zip file
227
+ st.write("Extracting files...")
228
+ extract_zip(zip_file_path, extraction_dir)
229
+
230
+ # Prepare dataset
231
+ image_paths, labels = prepare_dataset(extraction_dir)
232
+ dataset = CustomImageDataset(image_paths, labels)
233
+ train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
234
+
235
+ # Fine-tune the model
236
+ st.write("Fine-tuning the model...")
237
+ final_loss = fine_tune_model(train_loader)
238
+ st.write(f"Training Complete with Final Loss: {final_loss}")
239
 
240