MingGatsby commited on
Commit
dcb3d5e
1 Parent(s): 2a9cfd5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -132
app.py CHANGED
@@ -172,135 +172,135 @@ MRI_WINDOW_CENTER = st.sidebar.number_input("MRI Window Center", min_value=WINDO
172
  MRI_WINDOW_WIDTH = st.sidebar.number_input("MRI Window Width", min_value=WINDOW_WIDTH_MIN, max_value=WINDOW_WIDTH_MAX, value=DEFAULT_MRI_WINDOW_WIDTH, step=1)
173
 
174
  uploaded_ct_file = st.file_uploader("Upload a candidate CT DICOM", type=["dcm"])
175
- if uploaded_ct_file is not None:
176
- # Save the uploaded file to a temporary location
177
- with tempfile.NamedTemporaryFile(delete=False, suffix=".dcm") as temp_file:
178
- temp_file.write(uploaded_ct_file.getvalue())
179
-
180
- # Apply evaluation transforms to the DICOM image for model prediction
181
- image_tensor = eval_transforms(temp_file.name).unsqueeze(0).to(device)
182
-
183
- # Predict
184
- with torch.no_grad():
185
- outputs = ct_model(image_tensor).sigmoid().to("cpu").numpy()
186
- prob = outputs[0][0]
187
- CLOTS_CLASSIFICATION = False
188
- if(prob >= CT_INFERENCE_THRESHOLD):
189
- CLOTS_CLASSIFICATION=True
190
-
191
- st.header("CT Classification")
192
- st.subheader(f"Ischaemic Stroke : {CLOTS_CLASSIFICATION}")
193
- st.subheader(f"Confidence : {prob * 100:.1f}%")
194
-
195
- # Load the original DICOM image for download
196
- download_image_tensor = original_transforms(temp_file.name).unsqueeze(0).to(device)
197
- download_image = download_image_tensor.squeeze()
198
-
199
- # Transform the download image and apply windowing
200
- transformed_download_image = DICOM_Utils.transform_image_for_display(download_image)
201
- windowed_download_image = DICOM_Utils.apply_windowing(transformed_download_image, CT_WINDOW_CENTER, CT_WINDOW_WIDTH)
202
-
203
- # Streamlit button to trigger image download
204
- image_data = image_to_bytes(Image.fromarray(windowed_download_image))
205
- st.download_button(
206
- label="Download CT Image",
207
- data=image_data,
208
- file_name="downloaded_ct_image.png",
209
- mime="image/png"
210
- )
211
-
212
- # Load the original DICOM image for display
213
- display_image_tensor = cam_transforms(temp_file.name).unsqueeze(0).to(device)
214
- display_image = display_image_tensor.squeeze()
215
-
216
- # Transform the image and apply windowing
217
- transformed_image = DICOM_Utils.transform_image_for_display(display_image)
218
- windowed_image = DICOM_Utils.apply_windowing(transformed_image, CT_WINDOW_CENTER, CT_WINDOW_WIDTH)
219
- st.image(Image.fromarray(windowed_image), caption="Original CT Visualization", use_column_width=True)
220
-
221
- # Expand to three channels
222
- windowed_image = np.expand_dims(windowed_image, axis=2)
223
- windowed_image = np.tile(windowed_image, [1, 1, 3])
224
-
225
- # Ensure both are of float32 type
226
- windowed_image = windowed_image.astype(np.float32)
227
-
228
- # Normalize to [0, 1] range
229
- windowed_image = np.float32(windowed_image) / 255
230
-
231
- # Build the CAM (Class Activation Map)
232
- target_layers = [ct_model.model.norm]
233
- cam = GradCAM(model=ct_model, target_layers=target_layers, reshape_transform=reshape_transform, use_cuda=True)
234
- grayscale_cam = cam(input_tensor=image_tensor, targets=[ClassifierOutputTarget(CAM_CLASS_ID)])
235
- grayscale_cam = grayscale_cam[0, :]
236
-
237
- # Now you can safely call the show_cam_on_image function
238
- visualization = show_cam_on_image(windowed_image, grayscale_cam, use_rgb=True)
239
- st.image(Image.fromarray(visualization), caption="CAM CT Visualization", use_column_width=True)
240
-
241
- uploaded_mri_file = st.file_uploader("Upload a candidate MRI DICOM", type=["dcm"])
242
- if uploaded_mri_file is not None:
243
- # Save the uploaded file to a temporary location
244
- with tempfile.NamedTemporaryFile(delete=False, suffix=".dcm") as temp_file:
245
- temp_file.write(uploaded_mri_file.getvalue())
246
-
247
- # Apply evaluation transforms to the DICOM image for model prediction
248
- image_tensor = eval_transforms(temp_file.name).unsqueeze(0).to(device)
249
-
250
- # Predict
251
- with torch.no_grad():
252
- outputs = mri_model(image_tensor).sigmoid().to("cpu").numpy()
253
- prob = outputs[0][0]
254
- CLOTS_CLASSIFICATION = False
255
- if(prob >= MRI_INFERENCE_THRESHOLD):
256
- CLOTS_CLASSIFICATION=True
257
-
258
- st.header("MRI Classification")
259
- st.subheader(f"Ischaemic Stroke : {CLOTS_CLASSIFICATION}")
260
- st.subheader(f"Confidence : {prob * 100:.1f}%")
261
-
262
- # Load the original DICOM image for download
263
- download_image_tensor = original_transforms(temp_file.name).unsqueeze(0).to(device)
264
- download_image = download_image_tensor.squeeze()
265
-
266
- # Transform the download image and apply windowing
267
- transformed_download_image = DICOM_Utils.transform_image_for_display(download_image)
268
- windowed_download_image = DICOM_Utils.apply_windowing(transformed_download_image, MRI_WINDOW_CENTER, MRI_WINDOW_WIDTH)
269
-
270
- # Streamlit button to trigger image download
271
- image_data = image_to_bytes(Image.fromarray(windowed_download_image))
272
- st.download_button(
273
- label="Download MRI Image",
274
- data=image_data,
275
- file_name="downloaded_mri_image.png",
276
- mime="image/png"
277
- )
278
-
279
- # Load the original DICOM image for display
280
- display_image_tensor = cam_transforms(temp_file.name).unsqueeze(0).to(device)
281
- display_image = display_image_tensor.squeeze()
282
-
283
- # Transform the image and apply windowing
284
- transformed_image = DICOM_Utils.transform_image_for_display(display_image)
285
- windowed_image = DICOM_Utils.apply_windowing(transformed_image, MRI_WINDOW_CENTER, MRI_WINDOW_WIDTH)
286
- st.image(Image.fromarray(windowed_image), caption="Original MRI Visualization", use_column_width=True)
287
-
288
- # Expand to three channels
289
- windowed_image = np.expand_dims(windowed_image, axis=2)
290
- windowed_image = np.tile(windowed_image, [1, 1, 3])
291
-
292
- # Ensure both are of float32 type
293
- windowed_image = windowed_image.astype(np.float32)
294
-
295
- # Normalize to [0, 1] range
296
- windowed_image = np.float32(windowed_image) / 255
297
-
298
- # Build the CAM (Class Activation Map)
299
- target_layers = [mri_model.model.norm]
300
- cam = GradCAM(model=mri_model, target_layers=target_layers, reshape_transform=reshape_transform, use_cuda=True)
301
- grayscale_cam = cam(input_tensor=image_tensor, targets=[ClassifierOutputTarget(CAM_CLASS_ID)])
302
- grayscale_cam = grayscale_cam[0, :]
303
-
304
- # Now you can safely call the show_cam_on_image function
305
- visualization = show_cam_on_image(windowed_image, grayscale_cam, use_rgb=True)
306
- st.image(Image.fromarray(visualization), caption="CAM MRI Visualization", use_column_width=True)
 
172
  MRI_WINDOW_WIDTH = st.sidebar.number_input("MRI Window Width", min_value=WINDOW_WIDTH_MIN, max_value=WINDOW_WIDTH_MAX, value=DEFAULT_MRI_WINDOW_WIDTH, step=1)
173
 
174
  uploaded_ct_file = st.file_uploader("Upload a candidate CT DICOM", type=["dcm"])
175
+ # if uploaded_ct_file is not None:
176
+ # # Save the uploaded file to a temporary location
177
+ # with tempfile.NamedTemporaryFile(delete=False, suffix=".dcm") as temp_file:
178
+ # temp_file.write(uploaded_ct_file.getvalue())
179
+
180
+ # # Apply evaluation transforms to the DICOM image for model prediction
181
+ # image_tensor = eval_transforms(temp_file.name).unsqueeze(0).to(device)
182
+
183
+ # # Predict
184
+ # with torch.no_grad():
185
+ # outputs = ct_model(image_tensor).sigmoid().to("cpu").numpy()
186
+ # prob = outputs[0][0]
187
+ # CLOTS_CLASSIFICATION = False
188
+ # if(prob >= CT_INFERENCE_THRESHOLD):
189
+ # CLOTS_CLASSIFICATION=True
190
+
191
+ # st.header("CT Classification")
192
+ # st.subheader(f"Ischaemic Stroke : {CLOTS_CLASSIFICATION}")
193
+ # st.subheader(f"Confidence : {prob * 100:.1f}%")
194
+
195
+ # # Load the original DICOM image for download
196
+ # download_image_tensor = original_transforms(temp_file.name).unsqueeze(0).to(device)
197
+ # download_image = download_image_tensor.squeeze()
198
+
199
+ # # Transform the download image and apply windowing
200
+ # transformed_download_image = DICOM_Utils.transform_image_for_display(download_image)
201
+ # windowed_download_image = DICOM_Utils.apply_windowing(transformed_download_image, CT_WINDOW_CENTER, CT_WINDOW_WIDTH)
202
+
203
+ # # Streamlit button to trigger image download
204
+ # image_data = image_to_bytes(Image.fromarray(windowed_download_image))
205
+ # st.download_button(
206
+ # label="Download CT Image",
207
+ # data=image_data,
208
+ # file_name="downloaded_ct_image.png",
209
+ # mime="image/png"
210
+ # )
211
+
212
+ # # Load the original DICOM image for display
213
+ # display_image_tensor = cam_transforms(temp_file.name).unsqueeze(0).to(device)
214
+ # display_image = display_image_tensor.squeeze()
215
+
216
+ # # Transform the image and apply windowing
217
+ # transformed_image = DICOM_Utils.transform_image_for_display(display_image)
218
+ # windowed_image = DICOM_Utils.apply_windowing(transformed_image, CT_WINDOW_CENTER, CT_WINDOW_WIDTH)
219
+ # st.image(Image.fromarray(windowed_image), caption="Original CT Visualization", use_column_width=True)
220
+
221
+ # # Expand to three channels
222
+ # windowed_image = np.expand_dims(windowed_image, axis=2)
223
+ # windowed_image = np.tile(windowed_image, [1, 1, 3])
224
+
225
+ # # Ensure both are of float32 type
226
+ # windowed_image = windowed_image.astype(np.float32)
227
+
228
+ # # Normalize to [0, 1] range
229
+ # windowed_image = np.float32(windowed_image) / 255
230
+
231
+ # # Build the CAM (Class Activation Map)
232
+ # target_layers = [ct_model.model.norm]
233
+ # cam = GradCAM(model=ct_model, target_layers=target_layers, reshape_transform=reshape_transform, use_cuda=True)
234
+ # grayscale_cam = cam(input_tensor=image_tensor, targets=[ClassifierOutputTarget(CAM_CLASS_ID)])
235
+ # grayscale_cam = grayscale_cam[0, :]
236
+
237
+ # # Now you can safely call the show_cam_on_image function
238
+ # visualization = show_cam_on_image(windowed_image, grayscale_cam, use_rgb=True)
239
+ # st.image(Image.fromarray(visualization), caption="CAM CT Visualization", use_column_width=True)
240
+
241
+ # uploaded_mri_file = st.file_uploader("Upload a candidate MRI DICOM", type=["dcm"])
242
+ # if uploaded_mri_file is not None:
243
+ # # Save the uploaded file to a temporary location
244
+ # with tempfile.NamedTemporaryFile(delete=False, suffix=".dcm") as temp_file:
245
+ # temp_file.write(uploaded_mri_file.getvalue())
246
+
247
+ # # Apply evaluation transforms to the DICOM image for model prediction
248
+ # image_tensor = eval_transforms(temp_file.name).unsqueeze(0).to(device)
249
+
250
+ # # Predict
251
+ # with torch.no_grad():
252
+ # outputs = mri_model(image_tensor).sigmoid().to("cpu").numpy()
253
+ # prob = outputs[0][0]
254
+ # CLOTS_CLASSIFICATION = False
255
+ # if(prob >= MRI_INFERENCE_THRESHOLD):
256
+ # CLOTS_CLASSIFICATION=True
257
+
258
+ # st.header("MRI Classification")
259
+ # st.subheader(f"Ischaemic Stroke : {CLOTS_CLASSIFICATION}")
260
+ # st.subheader(f"Confidence : {prob * 100:.1f}%")
261
+
262
+ # # Load the original DICOM image for download
263
+ # download_image_tensor = original_transforms(temp_file.name).unsqueeze(0).to(device)
264
+ # download_image = download_image_tensor.squeeze()
265
+
266
+ # # Transform the download image and apply windowing
267
+ # transformed_download_image = DICOM_Utils.transform_image_for_display(download_image)
268
+ # windowed_download_image = DICOM_Utils.apply_windowing(transformed_download_image, MRI_WINDOW_CENTER, MRI_WINDOW_WIDTH)
269
+
270
+ # # Streamlit button to trigger image download
271
+ # image_data = image_to_bytes(Image.fromarray(windowed_download_image))
272
+ # st.download_button(
273
+ # label="Download MRI Image",
274
+ # data=image_data,
275
+ # file_name="downloaded_mri_image.png",
276
+ # mime="image/png"
277
+ # )
278
+
279
+ # # Load the original DICOM image for display
280
+ # display_image_tensor = cam_transforms(temp_file.name).unsqueeze(0).to(device)
281
+ # display_image = display_image_tensor.squeeze()
282
+
283
+ # # Transform the image and apply windowing
284
+ # transformed_image = DICOM_Utils.transform_image_for_display(display_image)
285
+ # windowed_image = DICOM_Utils.apply_windowing(transformed_image, MRI_WINDOW_CENTER, MRI_WINDOW_WIDTH)
286
+ # st.image(Image.fromarray(windowed_image), caption="Original MRI Visualization", use_column_width=True)
287
+
288
+ # # Expand to three channels
289
+ # windowed_image = np.expand_dims(windowed_image, axis=2)
290
+ # windowed_image = np.tile(windowed_image, [1, 1, 3])
291
+
292
+ # # Ensure both are of float32 type
293
+ # windowed_image = windowed_image.astype(np.float32)
294
+
295
+ # # Normalize to [0, 1] range
296
+ # windowed_image = np.float32(windowed_image) / 255
297
+
298
+ # # Build the CAM (Class Activation Map)
299
+ # target_layers = [mri_model.model.norm]
300
+ # cam = GradCAM(model=mri_model, target_layers=target_layers, reshape_transform=reshape_transform, use_cuda=True)
301
+ # grayscale_cam = cam(input_tensor=image_tensor, targets=[ClassifierOutputTarget(CAM_CLASS_ID)])
302
+ # grayscale_cam = grayscale_cam[0, :]
303
+
304
+ # # Now you can safely call the show_cam_on_image function
305
+ # visualization = show_cam_on_image(windowed_image, grayscale_cam, use_rgb=True)
306
+ # st.image(Image.fromarray(visualization), caption="CAM MRI Visualization", use_column_width=True)