Tanusree88 commited on
Commit
21c7368
·
verified ·
1 Parent(s): 3c12f92

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -0
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ import torch
4
+ from torch.utils.data import DataLoader, TensorDataset
5
+ from preprocessing import preprocess_image
6
+ from train import fine_tune_model
7
+
8
+ # Define a function to load and preprocess images
9
+ def load_images(image_paths):
10
+ images = []
11
+ labels = [] # Assuming you have labels, adjust as needed
12
+
13
+ for path in image_paths:
14
+ # Preprocess each image and append to the list
15
+ image_tensor = preprocess_image(path)
16
+ images.append(image_tensor)
17
+
18
+ # Append corresponding label (this is just an example; adjust as needed)
19
+ label = ... # Replace with logic to get the label for the image
20
+ labels.append(label)
21
+
22
+ # Stack images into a single tensor and create a DataLoader
23
+ images_tensor = torch.stack(images) # Stack into a single tensor
24
+ labels_tensor = torch.tensor(labels) # Convert labels to tensor
25
+ dataset = TensorDataset(images_tensor, labels_tensor)
26
+ return DataLoader(dataset, batch_size=16, shuffle=True) # Adjust batch_size as needed
27
+
28
+ # Streamlit UI
29
+ st.title("MRI Image Fine-Tuning with ViT")
30
+
31
+ # Upload images or provide a directory path
32
+ image_directory = st.text_input("Enter the directory of images:")
33
+
34
+ if st.button("Load Images"):
35
+ # Get all image paths from the directory
36
+ if image_directory and os.path.isdir(image_directory):
37
+ image_paths = [os.path.join(image_directory, f) for f in os.listdir(image_directory) if f.endswith(('.nii', '.jpg', '.jpeg'))]
38
+ st.success(f"Loaded {len(image_paths)} images.")
39
+
40
+ # Preprocess images and create DataLoader
41
+ train_loader = load_images(image_paths)
42
+
43
+ # Button to start training
44
+ if st.button("Start Training"):
45
+ final_loss = fine_tune_model(train_loader) # Trigger fine-tuning
46
+ st.write(f"Training complete with final loss: {final_loss}")
47
+ else:
48
+ st.error("Please enter a valid directory.")