# Streamlit page that classifies an uploaded CT or MRI DICOM slice with a
# pre-trained model and renders the windowed image next to a Grad-CAM overlay.

# Import required libraries
import io
import os
import tempfile

import numpy as np
import streamlit as st
import torch
from PIL import Image

# Import additional MONAI and PyTorch Grad-CAM utilities
# NOTE(review): AsChannelFirst / AddChannel are deprecated in later MONAI
# releases -- confirm the pinned MONAI version still ships them.
from monai.config import print_config
from monai.utils import set_determinism
from monai.networks.nets import SEResNet50
from monai.transforms import (
    Activations,
    EnsureChannelFirst,
    AsDiscrete,
    Compose,
    LoadImage,
    RandFlip,
    RandRotate,
    RandZoom,
    ScaleIntensity,
    AsChannelFirst,
    AddChannel,
    RandSpatialCrop,
    ScaleIntensityRangePercentiles,
    Resize,
)
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# Import utility and custom functions
from Util.DICOM import DICOM_Utils
from Util.Custom_Model import Build_Custom_Model, reshape_transform

# (Int) Random seed
SEED = 0
# (Int) Number of model outputs
NUM_CLASSES = 1
# (String) CT Model directory
CT_MODEL_DIRECTORY = "models/CLOTS/CT"
# (String) MRI Model directory
MRI_MODEL_DIRECTORY = "models/CLOTS/MRI"
# (Boolean) Use custom model
CUSTOM_MODEL_FLAG = True
# (List[int]) Image size
SPATIAL_SIZE = [224, 224]
# (String) CT Model file name
CT_MODEL_FILE_NAME = "best_metric_model.pth"
# (String) MRI Model file name
MRI_MODEL_FILE_NAME = "best_metric_model.pth"
# (Boolean) List model modules
LIST_MODEL_MODULES = False
# (String) CT Model name
CT_MODEL_NAME = "swin_base_patch4_window7_224"
# (String) MRI Model name
MRI_MODEL_NAME = "swin_base_patch4_window7_224"
# (Float) CT Model inference threshold
CT_INFERENCE_THRESHOLD = 0.5
# (Float) MRI Model inference threshold
MRI_INFERENCE_THRESHOLD = 0.5
# (Int) Display CAM Class ID
CAM_CLASS_ID = 0
# (Int) Default Window Center for CT image display
DEFAULT_CT_WINDOW_CENTER = 40
# (Int) Default Window Width for CT image display
DEFAULT_CT_WINDOW_WIDTH = 100
# (Int) Default Window Center for MRI image display
DEFAULT_MRI_WINDOW_CENTER = 400
# (Int) Default Window Width for MRI image display
DEFAULT_MRI_WINDOW_WIDTH = 1000
# (Int) Minimum value for Window Center
WINDOW_CENTER_MIN = -600
# (Int) Maximum value for Window Center
WINDOW_CENTER_MAX = 1000
# (Int) Minimum value for Window Width
WINDOW_WIDTH_MIN = 1
# (Int) Maximum value for Window Width
WINDOW_WIDTH_MAX = 3000

# Evaluation Transforms (model input)
eval_transforms = Compose(
    [
        LoadImage(image_only=True),
        AsChannelFirst(),
        ScaleIntensityRangePercentiles(
            lower=20, upper=80, b_min=0.0, b_max=1.0, clip=False, relative=True
        ),
        Resize(spatial_size=SPATIAL_SIZE),
    ]
)

# CAM Original Transforms (un-normalised image resized to the model input size)
cam_original_transforms = Compose(
    [
        LoadImage(image_only=True),
        AsChannelFirst(),
        Resize(spatial_size=SPATIAL_SIZE),
    ]
)

# Original Transforms (un-normalised image at its native size, for download)
original_transforms = Compose(
    [
        LoadImage(image_only=True),
        AsChannelFirst(),
    ]
)


def image_to_bytes(image):
    """Return *image* (a PIL Image) encoded as PNG bytes for downloading."""
    byte_stream = io.BytesIO()
    image.save(byte_stream, format='PNG')
    return byte_stream.getvalue()


set_determinism(seed=SEED)
torch.manual_seed(SEED)

# Parameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ct_root_dir = tempfile.mkdtemp() if CT_MODEL_DIRECTORY is None else CT_MODEL_DIRECTORY
mri_root_dir = tempfile.mkdtemp() if MRI_MODEL_DIRECTORY is None else MRI_MODEL_DIRECTORY


def load_model(root_dir, model_name, model_file_name):
    """Build a classifier, load its weights and switch it to eval mode.

    Args:
        root_dir: Directory containing the weights file.
        model_name: Architecture name handed to ``Build_Custom_Model``; only
            used when ``CUSTOM_MODEL_FLAG`` is true.
        model_file_name: File name of the saved ``state_dict`` in *root_dir*.

    Returns:
        The model on ``device`` with the stored weights loaded.
    """
    if CUSTOM_MODEL_FLAG:
        model = Build_Custom_Model(model_name, NUM_CLASSES, pretrained=False).to(device)
    else:
        model = SEResNet50(spatial_dims=2, in_channels=1, num_classes=NUM_CLASSES).to(device)
    # map_location lets weights saved on a GPU load on a CPU-only host.
    weights_path = os.path.join(root_dir, model_file_name)
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()
    return model


def _save_upload_to_temp(uploaded_file):
    """Write the uploaded bytes to a closed temporary ``.dcm`` file; return its path."""
    with tempfile.NamedTemporaryFile(delete=False, suffix=".dcm") as temp_file:
        temp_file.write(uploaded_file.getvalue())
    # The handle is closed (and therefore flushed) before anyone reads the path.
    return temp_file.name


def _windowed_view(dicom_path, transforms, window_center, window_width):
    """Load *dicom_path* through *transforms* and return the windowed display image."""
    image = transforms(dicom_path).unsqueeze(0).to(device).squeeze()
    transformed = DICOM_Utils.transform_image_for_display(image)
    return DICOM_Utils.apply_windowing(transformed, window_center, window_width)


def _analyze_upload(uploaded_file, model, modality, inference_threshold,
                    window_center, window_width):
    """Classify one uploaded DICOM and render the result, image and Grad-CAM.

    Args:
        uploaded_file: Object returned by ``st.file_uploader`` (not ``None``).
        model: Loaded classifier for this modality.
        modality: Label used in headings, captions and the download file name
            (``"CT"`` or ``"MRI"``).
        inference_threshold: Probability at or above which the case is
            reported as positive.
        window_center: Window centre passed to ``DICOM_Utils.apply_windowing``.
        window_width: Window width passed to ``DICOM_Utils.apply_windowing``.
    """
    dicom_path = _save_upload_to_temp(uploaded_file)
    try:
        # Apply evaluation transforms to the DICOM image for model prediction
        image_tensor = eval_transforms(dicom_path).unsqueeze(0).to(device)

        # Predict (gradients are not needed for the forward pass alone)
        with torch.no_grad():
            outputs = model(image_tensor).sigmoid().to("cpu").numpy()
        prob = outputs[0][0]
        clots_classification = bool(prob >= inference_threshold)

        st.header(f"{modality} Classification")
        st.subheader(f"Ischaemic Stroke : {clots_classification}")
        st.subheader(f"Confidence : {prob * 100:.1f}%")

        # Native-size windowed image offered for download
        windowed_download_image = _windowed_view(
            dicom_path, original_transforms, window_center, window_width
        )
        image_data = image_to_bytes(Image.fromarray(windowed_download_image))
        st.download_button(
            label=f"Download {modality} Image",
            data=image_data,
            file_name=f"downloaded_{modality.lower()}_image.png",
            mime="image/png"
        )

        # Model-size windowed image for on-page display and as the CAM backdrop
        windowed_image = _windowed_view(
            dicom_path, cam_original_transforms, window_center, window_width
        )
        st.image(Image.fromarray(windowed_image),
                 caption=f"Original {modality} Visualization",
                 use_column_width=True)

        # Expand to three channels and normalise to [0, 1] float32.
        # NOTE(review): the /255 assumes apply_windowing yields 8-bit values -- confirm.
        rgb_image = np.tile(np.expand_dims(windowed_image, axis=2), [1, 1, 3])
        rgb_image = rgb_image.astype(np.float32) / 255

        # Build the CAM (Class Activation Map). This must run outside
        # torch.no_grad() because Grad-CAM back-propagates through the model.
        # NOTE(review): `model.model.norm` presumably exists only on the custom
        # model built when CUSTOM_MODEL_FLAG is true -- confirm for SEResNet50.
        target_layers = [model.model.norm]
        cam = GradCAM(
            model=model,
            target_layers=target_layers,
            reshape_transform=reshape_transform,
            # Follow the selected device instead of forcing CUDA on CPU-only hosts.
            use_cuda=device.type == "cuda",
        )
        grayscale_cam = cam(input_tensor=image_tensor,
                            targets=[ClassifierOutputTarget(CAM_CLASS_ID)])
        grayscale_cam = grayscale_cam[0, :]

        visualization = show_cam_on_image(rgb_image, grayscale_cam, use_rgb=True)
        st.image(Image.fromarray(visualization),
                 caption=f"CAM {modality} Visualization",
                 use_column_width=True)
    finally:
        # The temp file was created with delete=False, so remove it explicitly.
        os.remove(dicom_path)


ct_model = load_model(ct_root_dir, CT_MODEL_NAME, CT_MODEL_FILE_NAME)
mri_model = load_model(mri_root_dir, MRI_MODEL_NAME, MRI_MODEL_FILE_NAME)

if LIST_MODEL_MODULES:
    for ct_name, _ in ct_model.named_modules():
        print(ct_name)
    for mri_name, _ in mri_model.named_modules():
        print(mri_name)

# Initialize Streamlit
st.title("Analyze")

# Use Streamlit's number_input to adjust WINDOW_CENTER and WINDOW_WIDTH
st.sidebar.header("Windowing Parameters for DICOM")
CT_WINDOW_CENTER = st.sidebar.number_input(
    "CT Window Center", min_value=WINDOW_CENTER_MIN, max_value=WINDOW_CENTER_MAX,
    value=DEFAULT_CT_WINDOW_CENTER, step=1)
CT_WINDOW_WIDTH = st.sidebar.number_input(
    "CT Window Width", min_value=WINDOW_WIDTH_MIN, max_value=WINDOW_WIDTH_MAX,
    value=DEFAULT_CT_WINDOW_WIDTH, step=1)
MRI_WINDOW_CENTER = st.sidebar.number_input(
    "MRI Window Center", min_value=WINDOW_CENTER_MIN, max_value=WINDOW_CENTER_MAX,
    value=DEFAULT_MRI_WINDOW_CENTER, step=1)
MRI_WINDOW_WIDTH = st.sidebar.number_input(
    "MRI Window Width", min_value=WINDOW_WIDTH_MIN, max_value=WINDOW_WIDTH_MAX,
    value=DEFAULT_MRI_WINDOW_WIDTH, step=1)

uploaded_ct_file = st.file_uploader("Upload a candidate CT DICOM", type=["dcm"])
if uploaded_ct_file is not None:
    _analyze_upload(uploaded_ct_file, ct_model, "CT", CT_INFERENCE_THRESHOLD,
                    CT_WINDOW_CENTER, CT_WINDOW_WIDTH)

uploaded_mri_file = st.file_uploader("Upload a candidate MRI DICOM", type=["dcm"])
if uploaded_mri_file is not None:
    _analyze_upload(uploaded_mri_file, mri_model, "MRI", MRI_INFERENCE_THRESHOLD,
                    MRI_WINDOW_CENTER, MRI_WINDOW_WIDTH)