Tanusree88 commited on
Commit
c21163c
·
verified ·
1 Parent(s): b2e756b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -11
app.py CHANGED
@@ -3,7 +3,6 @@ import zipfile
3
  import numpy as np
4
  import torch
5
  from transformers import ViTForImageClassification, AdamW
6
- import nibabel as nib
7
  from PIL import Image
8
  from torch.utils.data import Dataset, DataLoader
9
  import streamlit as st
@@ -17,13 +16,14 @@ def extract_zip(zip_file, extract_to):
17
  def preprocess_image(image_path):
18
  ext = os.path.splitext(image_path)[-1].lower()
19
 
20
- if ext in ['.nii', '.nii.gz']:
21
- nii_image = nib.load(image_path)
22
- image_data = nii_image.get_fdata()
23
  image_tensor = torch.tensor(image_data).float()
24
- if len(image_tensor.shape) == 3:
25
- image_tensor = image_tensor.unsqueeze(0)
26
-
 
 
27
  elif ext in ['.jpg', '.jpeg']:
28
  img = Image.open(image_path).convert('RGB').resize((224, 224))
29
  img_np = np.array(img)
@@ -52,16 +52,18 @@ def prepare_dataset(extracted_folder):
52
  # Check if the subfolder exists
53
  if not os.path.exists(folder_path):
54
  print(f"Folder not found: {folder_path}")
55
- continue # Skip this folder if it's not foun
 
56
  label = {'alzheimers_dataset': 0, 'parkinsons_dataset': 1, 'MSjpg': 2}[disease_folder]
57
 
58
  for img_file in os.listdir(folder_path):
59
- if img_file.endswith(('.nii', '.jpg', '.jpeg')):
60
  image_paths.append(os.path.join(folder_path, img_file))
61
  labels.append(label)
62
  else:
63
- print(f"Unsuported file:{img_file}")
64
- print(f"Total images loaded :{len(image_paths)}")
 
65
  return image_paths, labels
66
 
67
  # Custom Dataset class
 
3
  import numpy as np
4
  import torch
5
  from transformers import ViTForImageClassification, AdamW
 
6
  from PIL import Image
7
  from torch.utils.data import Dataset, DataLoader
8
  import streamlit as st
 
16
  def preprocess_image(image_path):
17
  ext = os.path.splitext(image_path)[-1].lower()
18
 
19
+ if ext in ['.npy']:
20
+ image_data = np.load(image_path)
 
21
  image_tensor = torch.tensor(image_data).float()
22
+ if len(image_tensor.shape) == 2: # If the image is 2D (grayscale)
23
+ image_tensor = image_tensor.unsqueeze(0) # Add channel dimension
24
+ elif len(image_tensor.shape) == 3: # If the image is 3D (height, width, channels)
25
+ image_tensor = image_tensor.permute(2, 0, 1).float() # Change to (C, H, W)
26
+
27
  elif ext in ['.jpg', '.jpeg']:
28
  img = Image.open(image_path).convert('RGB').resize((224, 224))
29
  img_np = np.array(img)
 
52
  # Check if the subfolder exists
53
  if not os.path.exists(folder_path):
54
  print(f"Folder not found: {folder_path}")
55
+ continue # Skip this folder if it's not found
56
+
57
  label = {'alzheimers_dataset': 0, 'parkinsons_dataset': 1, 'MSjpg': 2}[disease_folder]
58
 
59
  for img_file in os.listdir(folder_path):
60
+ if img_file.endswith(('.npy', '.jpg', '.jpeg')):
61
  image_paths.append(os.path.join(folder_path, img_file))
62
  labels.append(label)
63
  else:
64
+ print(f"Unsupported file: {img_file}")
65
+
66
+ print(f"Total images loaded: {len(image_paths)}")
67
  return image_paths, labels
68
 
69
  # Custom Dataset class