Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -7,10 +7,22 @@ import nibabel as nib
|
|
7 |
from PIL import Image
|
8 |
from torch.utils.data import Dataset, DataLoader
|
9 |
import streamlit as st
|
|
|
10 |
|
11 |
-
# 1. Function to extract zip files
|
12 |
-
def
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
zip_ref.extractall(extract_to)
|
15 |
|
16 |
# 2. Preprocess images
|
@@ -22,12 +34,12 @@ def preprocess_image(image_path):
|
|
22 |
image_data = nii_image.get_fdata()
|
23 |
image_tensor = torch.tensor(image_data).float()
|
24 |
if len(image_tensor.shape) == 3:
|
25 |
-
image_tensor = image_tensor.unsqueeze(0)
|
26 |
|
27 |
elif ext in ['.jpg', '.jpeg']:
|
28 |
img = Image.open(image_path).convert('RGB').resize((224, 224))
|
29 |
img_np = np.array(img)
|
30 |
-
image_tensor = torch.tensor(img_np).permute(2, 0, 1).float()
|
31 |
|
32 |
else:
|
33 |
raise ValueError(f"Unsupported format: {ext}")
|
@@ -42,10 +54,11 @@ def prepare_dataset(extracted_folder):
|
|
42 |
for disease_folder in ['alzheimers', 'parkinsons', 'ms']:
|
43 |
folder_path = os.path.join(extracted_folder, disease_folder)
|
44 |
label = {'alzheimers': 0, 'parkinsons': 1, 'ms': 2}[disease_folder]
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
49 |
return image_paths, labels
|
50 |
|
51 |
# 4. Custom Dataset
|
@@ -86,25 +99,29 @@ def fine_tune_model(train_loader):
|
|
86 |
# Streamlit UI
|
87 |
st.title("Fine-tune ViT on MRI Scans")
|
88 |
|
89 |
-
# Input for zip file
|
90 |
-
zip_file_1 = st.text_input("https://huggingface.co/spaces/Tanusree88/ViT-MRI-FineTuning/resolve/main/archive%20(5).zip")
|
91 |
-
zip_file_2 = st.text_input("https://huggingface.co/spaces/Tanusree88/ViT-MRI-FineTuning/resolve/main/MS.zip")
|
92 |
|
93 |
if st.button("Start Training"):
|
94 |
# Define an extraction directory
|
95 |
-
extraction_dir = '
|
96 |
os.makedirs(extraction_dir, exist_ok=True)
|
97 |
|
98 |
-
#
|
99 |
-
|
100 |
-
|
|
|
|
|
101 |
|
102 |
# Prepare dataset
|
|
|
103 |
image_paths, labels = prepare_dataset(extraction_dir)
|
104 |
dataset = CustomImageDataset(image_paths, labels)
|
105 |
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
|
106 |
|
107 |
# Fine-tune the model
|
|
|
108 |
final_loss = fine_tune_model(train_loader)
|
109 |
st.write(f"Training Complete with Final Loss: {final_loss}")
|
110 |
|
|
|
7 |
from PIL import Image
|
8 |
from torch.utils.data import Dataset, DataLoader
|
9 |
import streamlit as st
|
10 |
+
import requests
|
11 |
|
12 |
+
# 1. Function to download and extract zip files
|
13 |
+
def download_and_extract_zip(zip_url, extract_to):
|
14 |
+
local_zip_path = os.path.join(extract_to, os.path.basename(zip_url))
|
15 |
+
os.makedirs(extract_to, exist_ok=True)
|
16 |
+
|
17 |
+
# Download the zip file
|
18 |
+
with requests.get(zip_url, stream=True) as r:
|
19 |
+
r.raise_for_status()
|
20 |
+
with open(local_zip_path, 'wb') as f:
|
21 |
+
for chunk in r.iter_content(chunk_size=8192):
|
22 |
+
f.write(chunk)
|
23 |
+
|
24 |
+
# Extract the zip file
|
25 |
+
with zipfile.ZipFile(local_zip_path, 'r') as zip_ref:
|
26 |
zip_ref.extractall(extract_to)
|
27 |
|
28 |
# 2. Preprocess images
|
|
|
34 |
image_data = nii_image.get_fdata()
|
35 |
image_tensor = torch.tensor(image_data).float()
|
36 |
if len(image_tensor.shape) == 3:
|
37 |
+
image_tensor = image_tensor.unsqueeze(0) # Add channel dimension for MRI
|
38 |
|
39 |
elif ext in ['.jpg', '.jpeg']:
|
40 |
img = Image.open(image_path).convert('RGB').resize((224, 224))
|
41 |
img_np = np.array(img)
|
42 |
+
image_tensor = torch.tensor(img_np).permute(2, 0, 1).float() # Change to [C, H, W]
|
43 |
|
44 |
else:
|
45 |
raise ValueError(f"Unsupported format: {ext}")
|
|
|
54 |
for disease_folder in ['alzheimers', 'parkinsons', 'ms']:
|
55 |
folder_path = os.path.join(extracted_folder, disease_folder)
|
56 |
label = {'alzheimers': 0, 'parkinsons': 1, 'ms': 2}[disease_folder]
|
57 |
+
if os.path.exists(folder_path):
|
58 |
+
for img_file in os.listdir(folder_path):
|
59 |
+
if img_file.endswith(('.nii', '.jpg', '.jpeg')):
|
60 |
+
image_paths.append(os.path.join(folder_path, img_file))
|
61 |
+
labels.append(label)
|
62 |
return image_paths, labels
|
63 |
|
64 |
# 4. Custom Dataset
|
|
|
99 |
# Streamlit UI
|
100 |
st.title("Fine-tune ViT on MRI Scans")
|
101 |
|
102 |
+
# Input for zip file URLs
|
103 |
+
zip_file_1 = st.text_input("Enter URL for the 1st zip file:", "https://huggingface.co/spaces/Tanusree88/ViT-MRI-FineTuning/resolve/main/archive%20(5).zip")
|
104 |
+
zip_file_2 = st.text_input("Enter URL for the 2nd zip file:", "https://huggingface.co/spaces/Tanusree88/ViT-MRI-FineTuning/resolve/main/MS.zip")
|
105 |
|
106 |
if st.button("Start Training"):
|
107 |
# Define an extraction directory
|
108 |
+
extraction_dir = 'extracted_files'
|
109 |
os.makedirs(extraction_dir, exist_ok=True)
|
110 |
|
111 |
+
# Download and extract both zip files
|
112 |
+
st.write("Downloading and extracting files...")
|
113 |
+
download_and_extract_zip(zip_file_1, extraction_dir)
|
114 |
+
download_and_extract_zip(zip_file_2, extraction_dir)
|
115 |
+
st.write("Extraction complete.")
|
116 |
|
117 |
# Prepare dataset
|
118 |
+
st.write("Preparing dataset...")
|
119 |
image_paths, labels = prepare_dataset(extraction_dir)
|
120 |
dataset = CustomImageDataset(image_paths, labels)
|
121 |
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
|
122 |
|
123 |
# Fine-tune the model
|
124 |
+
st.write("Training model...")
|
125 |
final_loss = fine_tune_model(train_loader)
|
126 |
st.write(f"Training Complete with Final Loss: {final_loss}")
|
127 |
|