MingGatsby commited on
Commit
1b8a13f
1 Parent(s): b93204d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -139
app.py CHANGED
@@ -2,7 +2,7 @@
2
  import os
3
  import io
4
  import torch
5
- import tempfile
6
  import numpy as np
7
  import streamlit as st
8
 
@@ -135,25 +135,36 @@ def image_to_bytes(image):
135
  image.save(byte_stream, format='PNG')
136
  return byte_stream.getvalue()
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  set_determinism(seed=SEED)
139
  torch.manual_seed(SEED)
140
 
141
  # Parameters
142
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
143
- ct_root_dir = tempfile.mkdtemp() if CT_MODEL_DIRECTORY is None else CT_MODEL_DIRECTORY
144
- mri_root_dir = tempfile.mkdtemp() if MRI_MODEL_DIRECTORY is None else MRI_MODEL_DIRECTORY
145
 
146
  def load_model(root_dir, model_name, model_file_name):
147
  if CUSTOM_MODEL_FLAG:
148
  model = Build_Custom_Model(model_name, NUM_CLASSES, pretrained=False).to(device)
149
  else:
150
  model = SEResNet50(spatial_dims=2, in_channels=1, num_classes=NUM_CLASSES).to(device)
151
- model.load_state_dict(torch.load(os.path.join(root_dir, model_file_name), map_location=device))
152
  model.eval()
153
  return model
154
 
155
- ct_model = load_model(ct_root_dir, CT_MODEL_NAME, CT_MODEL_FILE_NAME)
156
- mri_model = load_model(mri_root_dir, MRI_MODEL_NAME, MRI_MODEL_FILE_NAME)
157
  if LIST_MODEL_MODULES:
158
  for ct_name, _ in ct_model.named_modules():
159
  print(ct_name)
@@ -166,145 +177,147 @@ st.title("Analyze")
166
 
167
  # Use Streamlit's number_input to adjust WINDOW_CENTER and WINDOW_WIDTH
168
  st.sidebar.header("Windowing Parameters for DICOM")
169
- CT_WINDOW_CENTER = st.sidebar.number_input("CT Window Center", min_value=WINDOW_CENTER_MIN, max_value=WINDOW_CENTER_MAX, value=DEFAULT_CT_WINDOW_CENTER, step=1)
170
- CT_WINDOW_WIDTH = st.sidebar.number_input("CT Window Width", min_value=WINDOW_WIDTH_MIN, max_value=WINDOW_WIDTH_MAX, value=DEFAULT_CT_WINDOW_WIDTH, step=1)
171
  MRI_WINDOW_CENTER = st.sidebar.number_input("MRI Window Center", min_value=WINDOW_CENTER_MIN, max_value=WINDOW_CENTER_MAX, value=DEFAULT_MRI_WINDOW_CENTER, step=1)
172
  MRI_WINDOW_WIDTH = st.sidebar.number_input("MRI Window Width", min_value=WINDOW_WIDTH_MIN, max_value=WINDOW_WIDTH_MAX, value=DEFAULT_MRI_WINDOW_WIDTH, step=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
  uploaded_ct_file = st.file_uploader("Upload a candidate CT DICOM", type=["dcm"])
175
  if uploaded_ct_file is not None:
176
  # Save the uploaded file to a temporary location
177
- with tempfile.NamedTemporaryFile(delete=False, suffix=".dcm") as temp_file:
178
- temp_file.write(uploaded_ct_file.getvalue())
 
179
 
180
- print(tempfile.name)
181
 
182
  # Apply evaluation transforms to the DICOM image for model prediction
183
- image_tensor = eval_transforms(temp_file.name).unsqueeze(0).to(device)
184
-
185
- print(image_tensor)
186
-
187
- # # Predict
188
- # with torch.no_grad():
189
- # outputs = ct_model(image_tensor).sigmoid().to("cpu").numpy()
190
- # prob = outputs[0][0]
191
- # CLOTS_CLASSIFICATION = False
192
- # if(prob >= CT_INFERENCE_THRESHOLD):
193
- # CLOTS_CLASSIFICATION=True
194
-
195
- # st.header("CT Classification")
196
- # st.subheader(f"Ischaemic Stroke : {CLOTS_CLASSIFICATION}")
197
- # st.subheader(f"Confidence : {prob * 100:.1f}%")
198
-
199
- # # Load the original DICOM image for download
200
- # download_image_tensor = original_transforms(temp_file.name).unsqueeze(0).to(device)
201
- # download_image = download_image_tensor.squeeze()
202
-
203
- # # Transform the download image and apply windowing
204
- # transformed_download_image = DICOM_Utils.transform_image_for_display(download_image)
205
- # windowed_download_image = DICOM_Utils.apply_windowing(transformed_download_image, CT_WINDOW_CENTER, CT_WINDOW_WIDTH)
206
-
207
- # # Streamlit button to trigger image download
208
- # image_data = image_to_bytes(Image.fromarray(windowed_download_image))
209
- # st.download_button(
210
- # label="Download CT Image",
211
- # data=image_data,
212
- # file_name="downloaded_ct_image.png",
213
- # mime="image/png"
214
- # )
215
-
216
- # # Load the original DICOM image for display
217
- # display_image_tensor = cam_transforms(temp_file.name).unsqueeze(0).to(device)
218
- # display_image = display_image_tensor.squeeze()
219
-
220
- # # Transform the image and apply windowing
221
- # transformed_image = DICOM_Utils.transform_image_for_display(display_image)
222
- # windowed_image = DICOM_Utils.apply_windowing(transformed_image, CT_WINDOW_CENTER, CT_WINDOW_WIDTH)
223
- # st.image(Image.fromarray(windowed_image), caption="Original CT Visualization", use_column_width=True)
224
-
225
- # # Expand to three channels
226
- # windowed_image = np.expand_dims(windowed_image, axis=2)
227
- # windowed_image = np.tile(windowed_image, [1, 1, 3])
228
-
229
- # # Ensure both are of float32 type
230
- # windowed_image = windowed_image.astype(np.float32)
231
-
232
- # # Normalize to [0, 1] range
233
- # windowed_image = np.float32(windowed_image) / 255
234
-
235
- # # Build the CAM (Class Activation Map)
236
- # target_layers = [ct_model.model.norm]
237
- # cam = GradCAM(model=ct_model, target_layers=target_layers, reshape_transform=reshape_transform, use_cuda=True)
238
- # grayscale_cam = cam(input_tensor=image_tensor, targets=[ClassifierOutputTarget(CAM_CLASS_ID)])
239
- # grayscale_cam = grayscale_cam[0, :]
240
-
241
- # # Now you can safely call the show_cam_on_image function
242
- # visualization = show_cam_on_image(windowed_image, grayscale_cam, use_rgb=True)
243
- # st.image(Image.fromarray(visualization), caption="CAM CT Visualization", use_column_width=True)
244
-
245
- # uploaded_mri_file = st.file_uploader("Upload a candidate MRI DICOM", type=["dcm"])
246
- # if uploaded_mri_file is not None:
247
- # # Save the uploaded file to a temporary location
248
- # with tempfile.NamedTemporaryFile(delete=False, suffix=".dcm") as temp_file:
249
- # temp_file.write(uploaded_mri_file.getvalue())
250
-
251
- # # Apply evaluation transforms to the DICOM image for model prediction
252
- # image_tensor = eval_transforms(temp_file.name).unsqueeze(0).to(device)
253
-
254
- # # Predict
255
- # with torch.no_grad():
256
- # outputs = mri_model(image_tensor).sigmoid().to("cpu").numpy()
257
- # prob = outputs[0][0]
258
- # CLOTS_CLASSIFICATION = False
259
- # if(prob >= MRI_INFERENCE_THRESHOLD):
260
- # CLOTS_CLASSIFICATION=True
261
-
262
- # st.header("MRI Classification")
263
- # st.subheader(f"Ischaemic Stroke : {CLOTS_CLASSIFICATION}")
264
- # st.subheader(f"Confidence : {prob * 100:.1f}%")
265
-
266
- # # Load the original DICOM image for download
267
- # download_image_tensor = original_transforms(temp_file.name).unsqueeze(0).to(device)
268
- # download_image = download_image_tensor.squeeze()
269
-
270
- # # Transform the download image and apply windowing
271
- # transformed_download_image = DICOM_Utils.transform_image_for_display(download_image)
272
- # windowed_download_image = DICOM_Utils.apply_windowing(transformed_download_image, MRI_WINDOW_CENTER, MRI_WINDOW_WIDTH)
273
-
274
- # # Streamlit button to trigger image download
275
- # image_data = image_to_bytes(Image.fromarray(windowed_download_image))
276
- # st.download_button(
277
- # label="Download MRI Image",
278
- # data=image_data,
279
- # file_name="downloaded_mri_image.png",
280
- # mime="image/png"
281
- # )
282
-
283
- # # Load the original DICOM image for display
284
- # display_image_tensor = cam_transforms(temp_file.name).unsqueeze(0).to(device)
285
- # display_image = display_image_tensor.squeeze()
286
-
287
- # # Transform the image and apply windowing
288
- # transformed_image = DICOM_Utils.transform_image_for_display(display_image)
289
- # windowed_image = DICOM_Utils.apply_windowing(transformed_image, MRI_WINDOW_CENTER, MRI_WINDOW_WIDTH)
290
- # st.image(Image.fromarray(windowed_image), caption="Original MRI Visualization", use_column_width=True)
291
-
292
- # # Expand to three channels
293
- # windowed_image = np.expand_dims(windowed_image, axis=2)
294
- # windowed_image = np.tile(windowed_image, [1, 1, 3])
295
-
296
- # # Ensure both are of float32 type
297
- # windowed_image = windowed_image.astype(np.float32)
298
-
299
- # # Normalize to [0, 1] range
300
- # windowed_image = np.float32(windowed_image) / 255
301
-
302
- # # Build the CAM (Class Activation Map)
303
- # target_layers = [mri_model.model.norm]
304
- # cam = GradCAM(model=mri_model, target_layers=target_layers, reshape_transform=reshape_transform, use_cuda=True)
305
- # grayscale_cam = cam(input_tensor=image_tensor, targets=[ClassifierOutputTarget(CAM_CLASS_ID)])
306
- # grayscale_cam = grayscale_cam[0, :]
307
-
308
- # # Now you can safely call the show_cam_on_image function
309
- # visualization = show_cam_on_image(windowed_image, grayscale_cam, use_rgb=True)
310
- # st.image(Image.fromarray(visualization), caption="CAM MRI Visualization", use_column_width=True)
 
2
  import os
3
  import io
4
  import torch
5
+ # import shutil
6
  import numpy as np
7
  import streamlit as st
8
 
 
135
  image.save(byte_stream, format='PNG')
136
  return byte_stream.getvalue()
137
 
138
+ # if os.path.exists("tempDir"):
139
+ # shutil.rmtree(os.path.join("tempDir"))
140
+
141
+ def create_dir(dirname: str):
142
+ if not os.path.exists(dirname):
143
+ os.makedirs(dirname, exist_ok=True)
144
+
145
+ create_dir("CT_tempDir")
146
+ create_dir("MRI_tempDir")
147
+
148
+ # Get the current working directory
149
+ current_directory = os.getcwd()
150
+
151
  set_determinism(seed=SEED)
152
  torch.manual_seed(SEED)
153
 
154
  # Parameters
155
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
156
 
157
  def load_model(root_dir, model_name, model_file_name):
158
  if CUSTOM_MODEL_FLAG:
159
  model = Build_Custom_Model(model_name, NUM_CLASSES, pretrained=False).to(device)
160
  else:
161
  model = SEResNet50(spatial_dims=2, in_channels=1, num_classes=NUM_CLASSES).to(device)
162
+ model.load_state_dict(torch.load(os.path.join(root_dir, model_file_name), map_location=device))
163
  model.eval()
164
  return model
165
 
166
+ ct_model = load_model(CT_MODEL_DIRECTORY, CT_MODEL_NAME, CT_MODEL_FILE_NAME)
167
+ mri_model = load_model(MRI_MODEL_DIRECTORY, MRI_MODEL_NAME, MRI_MODEL_FILE_NAME)
168
  if LIST_MODEL_MODULES:
169
  for ct_name, _ in ct_model.named_modules():
170
  print(ct_name)
 
177
 
178
  # Use Streamlit's number_input to adjust WINDOW_CENTER and WINDOW_WIDTH
179
  st.sidebar.header("Windowing Parameters for DICOM")
 
 
180
  MRI_WINDOW_CENTER = st.sidebar.number_input("MRI Window Center", min_value=WINDOW_CENTER_MIN, max_value=WINDOW_CENTER_MAX, value=DEFAULT_MRI_WINDOW_CENTER, step=1)
181
  MRI_WINDOW_WIDTH = st.sidebar.number_input("MRI Window Width", min_value=WINDOW_WIDTH_MIN, max_value=WINDOW_WIDTH_MAX, value=DEFAULT_MRI_WINDOW_WIDTH, step=1)
182
+ CT_WINDOW_CENTER = st.sidebar.number_input("CT Window Center", min_value=WINDOW_CENTER_MIN, max_value=WINDOW_CENTER_MAX, value=DEFAULT_CT_WINDOW_CENTER, step=1)
183
+ CT_WINDOW_WIDTH = st.sidebar.number_input("CT Window Width", min_value=WINDOW_WIDTH_MIN, max_value=WINDOW_WIDTH_MAX, value=DEFAULT_CT_WINDOW_WIDTH, step=1)
184
+
185
+ uploaded_mri_file = st.file_uploader("Upload a candidate MRI DICOM", type=["dcm"])
186
+ if uploaded_mri_file is not None:
187
+ # Save the uploaded file to a temporary location
188
+ mri_temp_path = os.path.join("MRI_tempDir", uploaded_mri_file.name)
189
+ with open(mri_temp_path, "wb") as f:
190
+ f.write(uploaded_mri_file.getbuffer())
191
+
192
+ full_mri_temp_path = current_directory +"\\"+ mri_temp_path
193
+
194
+ # Apply evaluation transforms to the DICOM image for model prediction
195
+ image_tensor = eval_transforms(full_mri_temp_path).unsqueeze(0).to(device)
196
+
197
+ # Predict
198
+ with torch.no_grad():
199
+ outputs = mri_model(image_tensor).sigmoid().to("cpu").numpy()
200
+ prob = outputs[0][0]
201
+ CLOTS_CLASSIFICATION = False
202
+ if(prob >= MRI_INFERENCE_THRESHOLD):
203
+ CLOTS_CLASSIFICATION=True
204
+
205
+ st.header("MRI Classification")
206
+ st.subheader(f"Ischaemic Stroke : {CLOTS_CLASSIFICATION}")
207
+ st.subheader(f"Confidence : {prob * 100:.1f}%")
208
+
209
+ # Load the original DICOM image for download
210
+ download_image_tensor = original_transforms(full_mri_temp_path).unsqueeze(0).to(device)
211
+ download_image = download_image_tensor.squeeze()
212
+
213
+ # Transform the download image and apply windowing
214
+ transformed_download_image = DICOM_Utils.transform_image_for_display(download_image)
215
+ windowed_download_image = DICOM_Utils.apply_windowing(transformed_download_image, MRI_WINDOW_CENTER, MRI_WINDOW_WIDTH)
216
+
217
+ # Streamlit button to trigger image download
218
+ image_data = image_to_bytes(Image.fromarray(windowed_download_image))
219
+ st.download_button(
220
+ label="Download MRI Image",
221
+ data=image_data,
222
+ file_name="downloaded_mri_image.png",
223
+ mime="image/png"
224
+ )
225
+
226
+ # Load the original DICOM image for display
227
+ display_image_tensor = cam_transforms(full_mri_temp_path).unsqueeze(0).to(device)
228
+ display_image = display_image_tensor.squeeze()
229
+
230
+ # Transform the image and apply windowing
231
+ transformed_image = DICOM_Utils.transform_image_for_display(display_image)
232
+ windowed_image = DICOM_Utils.apply_windowing(transformed_image, MRI_WINDOW_CENTER, MRI_WINDOW_WIDTH)
233
+ st.image(Image.fromarray(windowed_image), caption="Original MRI Visualization", use_column_width=True)
234
+
235
+ # Expand to three channels
236
+ windowed_image = np.expand_dims(windowed_image, axis=2)
237
+ windowed_image = np.tile(windowed_image, [1, 1, 3])
238
+
239
+ # Ensure both are of float32 type
240
+ windowed_image = windowed_image.astype(np.float32)
241
+
242
+ # Normalize to [0, 1] range
243
+ windowed_image = np.float32(windowed_image) / 255
244
+
245
+ # Build the CAM (Class Activation Map)
246
+ target_layers = [mri_model.model.norm]
247
+ cam = GradCAM(model=mri_model, target_layers=target_layers, reshape_transform=reshape_transform, use_cuda=True)
248
+ grayscale_cam = cam(input_tensor=image_tensor, targets=[ClassifierOutputTarget(CAM_CLASS_ID)])
249
+ grayscale_cam = grayscale_cam[0, :]
250
+
251
+ # Now you can safely call the show_cam_on_image function
252
+ visualization = show_cam_on_image(windowed_image, grayscale_cam, use_rgb=True)
253
+ st.image(Image.fromarray(visualization), caption="CAM MRI Visualization", use_column_width=True)
254
 
255
  uploaded_ct_file = st.file_uploader("Upload a candidate CT DICOM", type=["dcm"])
256
  if uploaded_ct_file is not None:
257
  # Save the uploaded file to a temporary location
258
+ ct_temp_path = os.path.join("CT_tempDir", uploaded_ct_file.name)
259
+ with open(ct_temp_path, "wb") as f:
260
+ f.write(uploaded_ct_file.getbuffer())
261
 
262
+ full_ct_temp_path = current_directory +"\\"+ ct_temp_path
263
 
264
  # Apply evaluation transforms to the DICOM image for model prediction
265
+ image_tensor = eval_transforms(full_ct_temp_path).unsqueeze(0).to(device)
266
+
267
+ # Predict
268
+ with torch.no_grad():
269
+ outputs = ct_model(image_tensor).sigmoid().to("cpu").numpy()
270
+ prob = outputs[0][0]
271
+ CLOTS_CLASSIFICATION = False
272
+ if(prob >= CT_INFERENCE_THRESHOLD):
273
+ CLOTS_CLASSIFICATION=True
274
+
275
+ st.header("CT Classification")
276
+ st.subheader(f"Ischaemic Stroke : {CLOTS_CLASSIFICATION}")
277
+ st.subheader(f"Confidence : {prob * 100:.1f}%")
278
+
279
+ # Load the original DICOM image for download
280
+ download_image_tensor = original_transforms(full_ct_temp_path).unsqueeze(0).to(device)
281
+ download_image = download_image_tensor.squeeze()
282
+
283
+ # Transform the download image and apply windowing
284
+ transformed_download_image = DICOM_Utils.transform_image_for_display(download_image)
285
+ windowed_download_image = DICOM_Utils.apply_windowing(transformed_download_image, CT_WINDOW_CENTER, CT_WINDOW_WIDTH)
286
+
287
+ # Streamlit button to trigger image download
288
+ image_data = image_to_bytes(Image.fromarray(windowed_download_image))
289
+ st.download_button(
290
+ label="Download CT Image",
291
+ data=image_data,
292
+ file_name="downloaded_ct_image.png",
293
+ mime="image/png"
294
+ )
295
+
296
+ # Load the original DICOM image for display
297
+ display_image_tensor = cam_transforms(full_ct_temp_path).unsqueeze(0).to(device)
298
+ display_image = display_image_tensor.squeeze()
299
+
300
+ # Transform the image and apply windowing
301
+ transformed_image = DICOM_Utils.transform_image_for_display(display_image)
302
+ windowed_image = DICOM_Utils.apply_windowing(transformed_image, CT_WINDOW_CENTER, CT_WINDOW_WIDTH)
303
+ st.image(Image.fromarray(windowed_image), caption="Original CT Visualization", use_column_width=True)
304
+
305
+ # Expand to three channels
306
+ windowed_image = np.expand_dims(windowed_image, axis=2)
307
+ windowed_image = np.tile(windowed_image, [1, 1, 3])
308
+
309
+ # Ensure both are of float32 type
310
+ windowed_image = windowed_image.astype(np.float32)
311
+
312
+ # Normalize to [0, 1] range
313
+ windowed_image = np.float32(windowed_image) / 255
314
+
315
+ # Build the CAM (Class Activation Map)
316
+ target_layers = [ct_model.model.norm]
317
+ cam = GradCAM(model=ct_model, target_layers=target_layers, reshape_transform=reshape_transform, use_cuda=True)
318
+ grayscale_cam = cam(input_tensor=image_tensor, targets=[ClassifierOutputTarget(CAM_CLASS_ID)])
319
+ grayscale_cam = grayscale_cam[0, :]
320
+
321
+ # Now you can safely call the show_cam_on_image function
322
+ visualization = show_cam_on_image(windowed_image, grayscale_cam, use_rgb=True)
323
+ st.image(Image.fromarray(visualization), caption="CAM CT Visualization", use_column_width=True)