Tanusree88 committed on
Commit
98e44a5
·
verified ·
1 Parent(s): aa4d61b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -41
app.py CHANGED
@@ -7,39 +7,27 @@ import nibabel as nib
7
  from PIL import Image
8
  from torch.utils.data import Dataset, DataLoader
9
  import streamlit as st
10
- import requests
11
 
12
- # 1. Function to download and extract zip files
13
- def download_and_extract_zip(zip_url, extract_to):
14
- local_zip_path = os.path.join(extract_to, os.path.basename(zip_url))
15
- os.makedirs(extract_to, exist_ok=True)
16
-
17
- # Download the zip file
18
- with requests.get(zip_url, stream=True) as r:
19
- r.raise_for_status()
20
- with open(local_zip_path, 'wb') as f:
21
- for chunk in r.iter_content(chunk_size=8192):
22
- f.write(chunk)
23
-
24
- # Extract the zip file
25
- with zipfile.ZipFile(local_zip_path, 'r') as zip_ref:
26
  zip_ref.extractall(extract_to)
27
 
28
- # 2. Preprocess images
29
  def preprocess_image(image_path):
30
  ext = os.path.splitext(image_path)[-1].lower()
31
 
32
- if ext == '.nii' or ext == '.nii.gz':
33
  nii_image = nib.load(image_path)
34
  image_data = nii_image.get_fdata()
35
  image_tensor = torch.tensor(image_data).float()
36
  if len(image_tensor.shape) == 3:
37
- image_tensor = image_tensor.unsqueeze(0) # Add channel dimension for MRI
38
 
39
  elif ext in ['.jpg', '.jpeg']:
40
  img = Image.open(image_path).convert('RGB').resize((224, 224))
41
  img_np = np.array(img)
42
- image_tensor = torch.tensor(img_np).permute(2, 0, 1).float() # Change to [C, H, W]
43
 
44
  else:
45
  raise ValueError(f"Unsupported format: {ext}")
@@ -47,21 +35,20 @@ def preprocess_image(image_path):
47
  image_tensor /= 255.0 # Normalize to [0, 1]
48
  return image_tensor
49
 
50
- # 3. Label images
51
  def prepare_dataset(extracted_folder):
52
  image_paths = []
53
  labels = []
54
  for disease_folder in ['alzheimers', 'parkinsons', 'ms']:
55
  folder_path = os.path.join(extracted_folder, disease_folder)
56
  label = {'alzheimers': 0, 'parkinsons': 1, 'ms': 2}[disease_folder]
57
- if os.path.exists(folder_path):
58
- for img_file in os.listdir(folder_path):
59
- if img_file.endswith(('.nii', '.jpg', '.jpeg')):
60
- image_paths.append(os.path.join(folder_path, img_file))
61
- labels.append(label)
62
  return image_paths, labels
63
 
64
- # 4. Custom Dataset
65
  class CustomImageDataset(Dataset):
66
  def __init__(self, image_paths, labels):
67
  self.image_paths = image_paths
@@ -75,7 +62,7 @@ class CustomImageDataset(Dataset):
75
  label = self.labels[idx]
76
  return image, label
77
 
78
- # 5. Training function
79
  def fine_tune_model(train_loader):
80
  model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=3)
81
  model.train()
@@ -96,32 +83,28 @@ def fine_tune_model(train_loader):
96
  running_loss += loss.item()
97
  return running_loss / len(train_loader)
98
 
99
- # Streamlit UI
100
- st.title("Fine-tune ViT on MRI Scans")
101
 
102
- # Input for zip file URLs
103
- zip_file_1 = st.text_input("Enter URL for the 1st zip file:", "https://huggingface.co/spaces/Tanusree88/ViT-MRI-FineTuning/resolve/main/archive%20(5).zip")
104
- zip_file_2 = st.text_input("Enter URL for the 2nd zip file:", "https://huggingface.co/spaces/Tanusree88/ViT-MRI-FineTuning/resolve/main/MS.zip")
105
 
106
  if st.button("Start Training"):
107
- # Define an extraction directory
108
- extraction_dir = 'extracted_files'
109
  os.makedirs(extraction_dir, exist_ok=True)
110
 
111
- # Download and extract both zip files
112
- st.write("Downloading and extracting files...")
113
- download_and_extract_zip(zip_file_1, extraction_dir)
114
- download_and_extract_zip(zip_file_2, extraction_dir)
115
- st.write("Extraction complete.")
116
 
117
  # Prepare dataset
118
- st.write("Preparing dataset...")
119
  image_paths, labels = prepare_dataset(extraction_dir)
120
  dataset = CustomImageDataset(image_paths, labels)
121
  train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
122
 
123
  # Fine-tune the model
124
- st.write("Training model...")
125
  final_loss = fine_tune_model(train_loader)
126
  st.write(f"Training Complete with Final Loss: {final_loss}")
127
 
 
 
 
7
  from PIL import Image
8
  from torch.utils.data import Dataset, DataLoader
9
  import streamlit as st
 
10
 
11
# Function to extract zip files
def extract_zip(zip_file, extract_to):
    """Extract a zip archive into ``extract_to``.

    Args:
        zip_file: Path to a local ``.zip`` file (or a file-like object), or
            an ``http://`` / ``https://`` URL pointing to one. URLs are
            downloaded to a temporary file first, because
            ``zipfile.ZipFile`` can only open local files and file-like
            objects, not remote locations.
        extract_to: Destination directory; created if it does not exist.

    Raises:
        urllib.error.URLError: If a URL cannot be fetched.
        zipfile.BadZipFile: If the data is not a valid zip archive.
        FileNotFoundError: If a local ``zip_file`` path does not exist.
    """
    import shutil
    import tempfile
    import urllib.request

    os.makedirs(extract_to, exist_ok=True)

    is_url = isinstance(zip_file, str) and zip_file.lower().startswith(
        ("http://", "https://")
    )
    if not is_url:
        with zipfile.ZipFile(zip_file, 'r') as zip_ref:
            zip_ref.extractall(extract_to)
        return

    # Stream the download to disk instead of holding the archive in memory;
    # the temporary file is removed automatically when the block exits.
    with tempfile.TemporaryFile() as downloaded:
        with urllib.request.urlopen(zip_file) as response:
            shutil.copyfileobj(response, downloaded)
        downloaded.seek(0)
        with zipfile.ZipFile(downloaded, 'r') as zip_ref:
            zip_ref.extractall(extract_to)
15
 
16
# Preprocess images
def preprocess_image(image_path):
    """Load an image file and return it as a float tensor divided by 255.

    Supported inputs:
      * NIfTI volumes (``.nii`` / ``.nii.gz``), loaded with nibabel. A
        3-D volume gets a leading dimension added via ``unsqueeze(0)``.
      * JPEG images (``.jpg`` / ``.jpeg``), converted to RGB, resized to
        224x224 and permuted from [H, W, C] to [C, H, W].

    Args:
        image_path: Path to the image file; the extension is matched
            case-insensitively.

    Returns:
        A ``torch.float32`` tensor with every value divided by 255.

    Raises:
        ValueError: If the file extension is not one of the supported ones.
    """
    lowered_path = os.fspath(image_path).lower()
    # os.path.splitext only yields the last suffix ('.gz'), so the compound
    # '.nii.gz' extension has to be detected explicitly; otherwise gzipped
    # NIfTI files are rejected as unsupported.
    if lowered_path.endswith('.nii.gz'):
        ext = '.nii.gz'
    else:
        ext = os.path.splitext(lowered_path)[-1]

    if ext in ('.nii', '.nii.gz'):
        nii_image = nib.load(image_path)
        image_data = nii_image.get_fdata()
        image_tensor = torch.tensor(image_data).float()
        if len(image_tensor.shape) == 3:
            image_tensor = image_tensor.unsqueeze(0)

    elif ext in ('.jpg', '.jpeg'):
        img = Image.open(image_path).convert('RGB').resize((224, 224))
        img_np = np.array(img)
        image_tensor = torch.tensor(img_np).permute(2, 0, 1).float()

    else:
        raise ValueError(f"Unsupported format: {ext}")

    # Maps 8-bit image data to [0, 1].
    # NOTE(review): NIfTI intensities are not necessarily in 0-255, so the
    # result may fall outside [0, 1] for those files - confirm intended scaling.
    image_tensor /= 255.0
    return image_tensor
37
 
38
# Prepare dataset
def prepare_dataset(extracted_folder):
    """Collect image paths and integer class labels from per-disease folders.

    Looks for the sub-folders ``alzheimers``, ``parkinsons`` and ``ms``
    directly under ``extracted_folder`` and gathers every file in them whose
    name ends in ``.nii``, ``.jpg`` or ``.jpeg`` (matched case-insensitively).
    Sub-folders that do not exist are skipped, so a partial extraction
    yields a smaller dataset rather than a ``FileNotFoundError``.

    Args:
        extracted_folder: Directory that contains the per-disease folders.

    Returns:
        A tuple ``(image_paths, labels)`` of two parallel lists. Labels are
        0 for ``alzheimers``, 1 for ``parkinsons`` and 2 for ``ms``. Files
        within a folder are listed in sorted order.
    """
    class_labels = {'alzheimers': 0, 'parkinsons': 1, 'ms': 2}
    image_extensions = ('.nii', '.jpg', '.jpeg')

    image_paths = []
    labels = []
    for disease_folder, label in class_labels.items():
        folder_path = os.path.join(extracted_folder, disease_folder)
        if not os.path.isdir(folder_path):
            continue
        # os.listdir returns entries in arbitrary order; sort for
        # reproducible results across runs and platforms.
        for img_file in sorted(os.listdir(folder_path)):
            if img_file.lower().endswith(image_extensions):
                image_paths.append(os.path.join(folder_path, img_file))
                labels.append(label)
    return image_paths, labels
50
 
51
+ # Custom Dataset class
52
  class CustomImageDataset(Dataset):
53
  def __init__(self, image_paths, labels):
54
  self.image_paths = image_paths
 
62
  label = self.labels[idx]
63
  return image, label
64
 
65
+ # Training function
66
  def fine_tune_model(train_loader):
67
  model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=3)
68
  model.train()
 
83
  running_loss += loss.item()
84
  return running_loss / len(train_loader)
85
 
86
# Streamlit UI for Fine-tuning
st.title("Fine-tune ViT on MRI/CT Scans for MS & Neurodegenerative Diseases")

# Archives that hold the training images.
# NOTE(review): these are remote URLs, while zipfile.ZipFile only opens local
# files or file-like objects - confirm extract_zip fetches them before opening.
zip_file_1 = "https://huggingface.co/spaces/Tanusree88/ViT-MRI-FineTuning/resolve/main/archive%20(5).zip"
zip_file_2 = "https://huggingface.co/spaces/Tanusree88/ViT-MRI-FineTuning/resolve/main/MS.zip"

if st.button("Start Training"):
    extraction_dir = "extracted_files"
    os.makedirs(extraction_dir, exist_ok=True)

    # Extract both zip files
    extract_zip(zip_file_1, extraction_dir)
    extract_zip(zip_file_2, extraction_dir)

    # Prepare dataset
    image_paths, labels = prepare_dataset(extraction_dir)
    if not image_paths:
        # Nothing to train on: report it in the UI instead of letting the
        # data loader or the loss averaging fail with an opaque error.
        st.error(
            f"No training images found under '{extraction_dir}'. Expected "
            "'alzheimers', 'parkinsons' and/or 'ms' folders containing "
            ".nii, .jpg or .jpeg files."
        )
        st.stop()

    dataset = CustomImageDataset(image_paths, labels)
    train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Fine-tune the model
    final_loss = fine_tune_model(train_loader)
    st.write(f"Training Complete with Final Loss: {final_loss}")
108
 
109
+
110
+