MingGatsby commited on
Commit
c4cb2d6
1 Parent(s): 7821fcb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -98
app.py CHANGED
@@ -2,7 +2,7 @@
2
  import os
3
  import io
4
  import torch
5
- # import shutil
6
  import numpy as np
7
  import streamlit as st
8
 
@@ -20,7 +20,6 @@ from monai.transforms import (
20
  EnsureChannelFirst,
21
  AsDiscrete,
22
  Compose,
23
- LoadImage,
24
  RandFlip,
25
  RandRotate,
26
  RandZoom,
@@ -105,7 +104,6 @@ WINDOW_WIDTH_MAX = 3000
105
  # Evaluation Transforms
106
  eval_transforms = Compose(
107
  [
108
- # LoadImage(image_only=True),
109
  AsChannelFirst(),
110
  ScaleIntensityRangePercentiles(lower=20, upper=80, b_min=0.0, b_max=1.0, clip=False, relative=True),
111
  Resize(spatial_size=SPATIAL_SIZE)
@@ -115,7 +113,6 @@ eval_transforms = Compose(
115
  # CAM Transforms
116
  cam_transforms = Compose(
117
  [
118
- # LoadImage(image_only=True),
119
  AsChannelFirst(),
120
  Resize(spatial_size=SPATIAL_SIZE)
121
  ]
@@ -124,7 +121,6 @@ cam_transforms = Compose(
124
  # Original Transforms
125
  original_transforms = Compose(
126
  [
127
- # LoadImage(image_only=True),
128
  AsChannelFirst()
129
  ]
130
  )
@@ -135,18 +131,26 @@ def image_to_bytes(image):
135
  image.save(byte_stream, format='PNG')
136
  return byte_stream.getvalue()
137
 
138
- # if os.path.exists("tempDir"):
139
- # shutil.rmtree(os.path.join("tempDir"))
 
 
 
140
 
141
- # def create_dir(dirname: str):
142
- # if not os.path.exists(dirname):
143
- # os.makedirs(dirname, exist_ok=True)
 
 
 
144
 
145
- # create_dir("CT_tempDir")
146
- # create_dir("MRI_tempDir")
147
 
148
- # # Get the current working directory
149
- # current_directory = os.getcwd()
 
 
150
 
151
  set_determinism(seed=SEED)
152
  torch.manual_seed(SEED)
@@ -187,11 +191,6 @@ CT_WINDOW_WIDTH = st.sidebar.number_input("CT Window Width", min_value=WINDOW_WI
187
 
188
  uploaded_mri_file = st.file_uploader("Upload a candidate MRI DICOM", type=["dcm"])
189
  if uploaded_mri_file is not None:
190
- # To check file details
191
- file_details = {"FileName": uploaded_mri_file.name, "FileType": uploaded_mri_file.type, "FileSize": uploaded_mri_file.size}
192
- st.write(file_details)
193
-
194
- import pydicom
195
  # Read DICOM file into NumPy array
196
  dicom_data = pydicom.dcmread(uploaded_mri_file)
197
  dicom_array = dicom_data.pixel_array
@@ -202,14 +201,13 @@ if uploaded_mri_file is not None:
202
  # Then add a channel dimension
203
  dicom_array = dicom_array[:, :, np.newaxis]
204
 
205
- # Check the shape and dtype of dicom_array
206
- st.write(f"Shape of dicom_array: {dicom_array.shape}")
207
- st.write(f"Data type of dicom_array: {dicom_array.dtype}")
208
 
209
  transformed_array = eval_transforms(dicom_array)
210
 
211
  # Convert to PyTorch tensor and move to device
212
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
213
  image_tensor = transformed_array.clone().detach().unsqueeze(0).to(device)
214
 
215
  # Predict
@@ -226,11 +224,11 @@ if uploaded_mri_file is not None:
226
 
227
  # Load the original DICOM image for download
228
  download_image_tensor = original_transforms(dicom_array).unsqueeze(0).to(device)
229
- download_image = download_image_tensor.squeeze()
230
 
231
  # Transform the download image and apply windowing
232
- transformed_download_image = DICOM_Utils.transform_image_for_display(download_image)
233
- windowed_download_image = DICOM_Utils.apply_windowing(transformed_download_image, MRI_WINDOW_CENTER, MRI_WINDOW_WIDTH)
234
 
235
  # Streamlit button to trigger image download
236
  image_data = image_to_bytes(Image.fromarray(windowed_download_image))
@@ -243,11 +241,11 @@ if uploaded_mri_file is not None:
243
 
244
  # Load the original DICOM image for display
245
  display_image_tensor = cam_transforms(dicom_array).unsqueeze(0).to(device)
246
- display_image = display_image_tensor.squeeze()
247
 
248
  # Transform the image and apply windowing
249
- transformed_image = DICOM_Utils.transform_image_for_display(display_image)
250
- windowed_image = DICOM_Utils.apply_windowing(transformed_image, MRI_WINDOW_CENTER, MRI_WINDOW_WIDTH)
251
  st.image(Image.fromarray(windowed_image), caption="Original MRI Visualization", use_column_width=True)
252
 
253
  # Expand to three channels
@@ -270,72 +268,78 @@ if uploaded_mri_file is not None:
270
  visualization = show_cam_on_image(windowed_image, grayscale_cam, use_rgb=True)
271
  st.image(Image.fromarray(visualization), caption="CAM MRI Visualization", use_column_width=True)
272
 
273
- # uploaded_ct_file = st.file_uploader("Upload a candidate CT DICOM", type=["dcm"])
274
- # if uploaded_ct_file is not None:
275
- # # Save the uploaded file to a temporary location
276
- # ct_temp_path = os.path.join("CT_tempDir", uploaded_ct_file.name)
277
- # with open(ct_temp_path, "wb") as f:
278
- # f.write(uploaded_ct_file.getbuffer())
279
-
280
- # full_ct_temp_path = current_directory +"\\"+ ct_temp_path
281
-
282
- # # Apply evaluation transforms to the DICOM image for model prediction
283
- # image_tensor = eval_transforms(full_ct_temp_path).unsqueeze(0).to(device)
284
-
285
- # # Predict
286
- # with torch.no_grad():
287
- # outputs = ct_model(image_tensor).sigmoid().to("cpu").numpy()
288
- # prob = outputs[0][0]
289
- # CLOTS_CLASSIFICATION = False
290
- # if(prob >= CT_INFERENCE_THRESHOLD):
291
- # CLOTS_CLASSIFICATION=True
292
-
293
- # st.header("CT Classification")
294
- # st.subheader(f"Ischaemic Stroke : {CLOTS_CLASSIFICATION}")
295
- # st.subheader(f"Confidence : {prob * 100:.1f}%")
296
-
297
- # # Load the original DICOM image for download
298
- # download_image_tensor = original_transforms(full_ct_temp_path).unsqueeze(0).to(device)
299
- # download_image = download_image_tensor.squeeze()
300
-
301
- # # Transform the download image and apply windowing
302
- # transformed_download_image = DICOM_Utils.transform_image_for_display(download_image)
303
- # windowed_download_image = DICOM_Utils.apply_windowing(transformed_download_image, CT_WINDOW_CENTER, CT_WINDOW_WIDTH)
304
-
305
- # # Streamlit button to trigger image download
306
- # image_data = image_to_bytes(Image.fromarray(windowed_download_image))
307
- # st.download_button(
308
- # label="Download CT Image",
309
- # data=image_data,
310
- # file_name="downloaded_ct_image.png",
311
- # mime="image/png"
312
- # )
313
-
314
- # # Load the original DICOM image for display
315
- # display_image_tensor = cam_transforms(full_ct_temp_path).unsqueeze(0).to(device)
316
- # display_image = display_image_tensor.squeeze()
317
-
318
- # # Transform the image and apply windowing
319
- # transformed_image = DICOM_Utils.transform_image_for_display(display_image)
320
- # windowed_image = DICOM_Utils.apply_windowing(transformed_image, CT_WINDOW_CENTER, CT_WINDOW_WIDTH)
321
- # st.image(Image.fromarray(windowed_image), caption="Original CT Visualization", use_column_width=True)
322
-
323
- # # Expand to three channels
324
- # windowed_image = np.expand_dims(windowed_image, axis=2)
325
- # windowed_image = np.tile(windowed_image, [1, 1, 3])
326
-
327
- # # Ensure both are of float32 type
328
- # windowed_image = windowed_image.astype(np.float32)
329
-
330
- # # Normalize to [0, 1] range
331
- # windowed_image = np.float32(windowed_image) / 255
332
-
333
- # # Build the CAM (Class Activation Map)
334
- # target_layers = [ct_model.model.norm]
335
- # cam = GradCAM(model=ct_model, target_layers=target_layers, reshape_transform=reshape_transform, use_cuda=True)
336
- # grayscale_cam = cam(input_tensor=image_tensor, targets=[ClassifierOutputTarget(CAM_CLASS_ID)])
337
- # grayscale_cam = grayscale_cam[0, :]
338
-
339
- # # Now you can safely call the show_cam_on_image function
340
- # visualization = show_cam_on_image(windowed_image, grayscale_cam, use_rgb=True)
341
- # st.image(Image.fromarray(visualization), caption="CAM CT Visualization", use_column_width=True)
 
 
 
 
 
 
 
2
  import os
3
  import io
4
  import torch
5
+ import pydicom
6
  import numpy as np
7
  import streamlit as st
8
 
 
20
  EnsureChannelFirst,
21
  AsDiscrete,
22
  Compose,
 
23
  RandFlip,
24
  RandRotate,
25
  RandZoom,
 
104
  # Evaluation Transforms
105
  eval_transforms = Compose(
106
  [
 
107
  AsChannelFirst(),
108
  ScaleIntensityRangePercentiles(lower=20, upper=80, b_min=0.0, b_max=1.0, clip=False, relative=True),
109
  Resize(spatial_size=SPATIAL_SIZE)
 
113
  # CAM Transforms
114
  cam_transforms = Compose(
115
  [
 
116
  AsChannelFirst(),
117
  Resize(spatial_size=SPATIAL_SIZE)
118
  ]
 
121
  # Original Transforms
122
  original_transforms = Compose(
123
  [
 
124
  AsChannelFirst()
125
  ]
126
  )
 
131
  image.save(byte_stream, format='PNG')
132
  return byte_stream.getvalue()
133
 
134
+ # Convert the file size from bytes to megabytes
135
+ def bytes_to_megabytes(file_size_bytes):
136
+ # Convert bytes to MB (1 MB = 1024 * 1024 bytes)
137
+ file_size_megabytes = round(file_size_bytes / (1024 * 1024), 2)
138
+ return str(file_size_megabytes) + " MB" # Rounding to 2 decimal places for readability
139
 
140
+ def meta_tensor_to_numpy(meta_tensor):
141
+ """
142
+ Convert a PyTorch MetaTensor to a NumPy array
143
+ """
144
+ # Ensure the MetaTensor is on the CPU
145
+ meta_tensor = meta_tensor.cpu()
146
 
147
+ # Convert the MetaTensor to a PyTorch tensor
148
+ torch_tensor = meta_tensor.to(dtype=torch.float32)
149
 
150
+ # Convert the PyTorch tensor to a NumPy array
151
+ numpy_array = torch_tensor.detach().numpy()
152
+
153
+ return numpy_array
154
 
155
  set_determinism(seed=SEED)
156
  torch.manual_seed(SEED)
 
191
 
192
  uploaded_mri_file = st.file_uploader("Upload a candidate MRI DICOM", type=["dcm"])
193
  if uploaded_mri_file is not None:
 
 
 
 
 
194
  # Read DICOM file into NumPy array
195
  dicom_data = pydicom.dcmread(uploaded_mri_file)
196
  dicom_array = dicom_data.pixel_array
 
201
  # Then add a channel dimension
202
  dicom_array = dicom_array[:, :, np.newaxis]
203
 
204
+ # To check file details
205
+ file_details = {"File_Name": uploaded_mri_file.name, "File_Type": uploaded_mri_file.type, "File_Size": bytes_to_megabytes(uploaded_mri_file.size), "File_Dimension": str((dicom_array.shape[0],dicom_array.shape[1]))}
206
+ st.write(file_details)
207
 
208
  transformed_array = eval_transforms(dicom_array)
209
 
210
  # Convert to PyTorch tensor and move to device
 
211
  image_tensor = transformed_array.clone().detach().unsqueeze(0).to(device)
212
 
213
  # Predict
 
224
 
225
  # Load the original DICOM image for download
226
  download_image_tensor = original_transforms(dicom_array).unsqueeze(0).to(device)
227
+ download_image_tensor = download_image_tensor.squeeze()
228
 
229
  # Transform the download image and apply windowing
230
+ download_image_numpy = meta_tensor_to_numpy(download_image_tensor)
231
+ windowed_download_image = DICOM_Utils.apply_windowing(download_image_numpy, MRI_WINDOW_CENTER, MRI_WINDOW_WIDTH)
232
 
233
  # Streamlit button to trigger image download
234
  image_data = image_to_bytes(Image.fromarray(windowed_download_image))
 
241
 
242
  # Load the original DICOM image for display
243
  display_image_tensor = cam_transforms(dicom_array).unsqueeze(0).to(device)
244
+ display_image_tensor = display_image_tensor.squeeze()
245
 
246
  # Transform the image and apply windowing
247
+ display_image_numpy = meta_tensor_to_numpy(display_image_tensor)
248
+ windowed_image = DICOM_Utils.apply_windowing(display_image_numpy, MRI_WINDOW_CENTER, MRI_WINDOW_WIDTH)
249
  st.image(Image.fromarray(windowed_image), caption="Original MRI Visualization", use_column_width=True)
250
 
251
  # Expand to three channels
 
268
  visualization = show_cam_on_image(windowed_image, grayscale_cam, use_rgb=True)
269
  st.image(Image.fromarray(visualization), caption="CAM MRI Visualization", use_column_width=True)
270
 
271
+ uploaded_ct_file = st.file_uploader("Upload a candidate CT DICOM", type=["dcm"])
272
+ if uploaded_ct_file is not None:
273
+ # Read DICOM file into NumPy array
274
+ dicom_data = pydicom.dcmread(uploaded_ct_file)
275
+ dicom_array = dicom_data.pixel_array
276
+
277
+ # Convert the data type to float32
278
+ dicom_array = dicom_array.astype(np.float32)
279
+
280
+ # Then add a channel dimension
281
+ dicom_array = dicom_array[:, :, np.newaxis]
282
+
283
+ # To check file details
284
+ file_details = {"File_Name": uploaded_ct_file.name, "File_Type": uploaded_ct_file.type, "File_Size": bytes_to_megabytes(uploaded_ct_file.size), "File_Dimension": str((dicom_array.shape[0],dicom_array.shape[1]))}
285
+ st.write(file_details)
286
+
287
+ transformed_array = eval_transforms(dicom_array)
288
+
289
+ # Predict
290
+ with torch.no_grad():
291
+ outputs = ct_model(image_tensor).sigmoid().to("cpu").numpy()
292
+ prob = outputs[0][0]
293
+ CLOTS_CLASSIFICATION = False
294
+ if(prob >= CT_INFERENCE_THRESHOLD):
295
+ CLOTS_CLASSIFICATION=True
296
+
297
+ st.header("CT Classification")
298
+ st.subheader(f"Ischaemic Stroke : {CLOTS_CLASSIFICATION}")
299
+ st.subheader(f"Confidence : {prob * 100:.1f}%")
300
+
301
+ # Load the original DICOM image for download
302
+ download_image_tensor = original_transforms(dicom_array).unsqueeze(0).to(device)
303
+ download_image_tensor = download_image_tensor.squeeze()
304
+
305
+ # Transform the download image and apply windowing
306
+ download_image_numpy = meta_tensor_to_numpy(download_image_tensor)
307
+ windowed_download_image = DICOM_Utils.apply_windowing(download_image_numpy, CT_WINDOW_CENTER, CT_WINDOW_WIDTH)
308
+
309
+ # Streamlit button to trigger image download
310
+ image_data = image_to_bytes(Image.fromarray(windowed_download_image))
311
+ st.download_button(
312
+ label="Download CT Image",
313
+ data=image_data,
314
+ file_name="downloaded_ct_image.png",
315
+ mime="image/png"
316
+ )
317
+
318
+ # Load the original DICOM image for display
319
+ display_image_tensor = cam_transforms(dicom_array).unsqueeze(0).to(device)
320
+ display_image_tensor = display_image_tensor.squeeze()
321
+
322
+ # Transform the image and apply windowing
323
+ display_image_numpy = meta_tensor_to_numpy(display_image_tensor)
324
+ windowed_image = DICOM_Utils.apply_windowing(display_image_numpy, CT_WINDOW_CENTER, CT_WINDOW_WIDTH)
325
+ st.image(Image.fromarray(windowed_image), caption="Original CT Visualization", use_column_width=True)
326
+
327
+ # Expand to three channels
328
+ windowed_image = np.expand_dims(windowed_image, axis=2)
329
+ windowed_image = np.tile(windowed_image, [1, 1, 3])
330
+
331
+ # Ensure both are of float32 type
332
+ windowed_image = windowed_image.astype(np.float32)
333
+
334
+ # Normalize to [0, 1] range
335
+ windowed_image = np.float32(windowed_image) / 255
336
+
337
+ # Build the CAM (Class Activation Map)
338
+ target_layers = [ct_model.model.norm]
339
+ cam = GradCAM(model=ct_model, target_layers=target_layers, reshape_transform=reshape_transform, use_cuda=USE_CUDA)
340
+ grayscale_cam = cam(input_tensor=image_tensor, targets=[ClassifierOutputTarget(CAM_CLASS_ID)])
341
+ grayscale_cam = grayscale_cam[0, :]
342
+
343
+ # Now you can safely call the show_cam_on_image function
344
+ visualization = show_cam_on_image(windowed_image, grayscale_cam, use_rgb=True)
345
+ st.image(Image.fromarray(visualization), caption="CAM CT Visualization", use_column_width=True)