"""Streamlit page that classifies uploaded MRI / CT DICOM slices and shows Grad-CAM overlays."""

# Import required libraries
import io
import os

import numpy as np
import pydicom
import streamlit as st
import torch

# Import utility and custom functions
from PIL import Image
from Util.DICOM import DICOM_Utils
from Util.Custom_Model import Build_Custom_Model, reshape_transform

# Import additional MONAI and PyTorch Grad-CAM utilities
from monai.config import print_config
from monai.utils import set_determinism
from monai.networks.nets import SEResNet50
from monai.transforms import (
    Activations,
    EnsureChannelFirst,
    AsDiscrete,
    Compose,
    RandFlip,
    RandRotate,
    RandZoom,
    ScaleIntensity,
    AsChannelFirst,
    AddChannel,
    RandSpatialCrop,
    ScaleIntensityRangePercentiles,
    Resize,
)
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# (Int) Random seed
SEED = 0
# (Int) Number of model outputs
NUM_CLASSES = 1
# (String) CT model directory
CT_MODEL_DIRECTORY = "models/CLOTS/CT"
# (String) MRI model directory
MRI_MODEL_DIRECTORY = "models/CLOTS/MRI"
# (Boolean) Use custom model
CUSTOM_MODEL_FLAG = True
# (List[int]) Image size
SPATIAL_SIZE = [224, 224]
# (String) CT model file name
CT_MODEL_FILE_NAME = "best_metric_model.pth"
# (String) MRI model file name
MRI_MODEL_FILE_NAME = "best_metric_model.pth"
# (Boolean) List model modules
LIST_MODEL_MODULES = False
# (String) CT model name
CT_MODEL_NAME = "swin_base_patch4_window7_224"
# (String) MRI model name
MRI_MODEL_NAME = "swin_base_patch4_window7_224"
# (Float) CT model inference threshold
CT_INFERENCE_THRESHOLD = 0.5
# (Float) MRI model inference threshold
MRI_INFERENCE_THRESHOLD = 0.5
# (Int) Display CAM Class ID
CAM_CLASS_ID = 0
# (Int) Default Window Center for CT image display
DEFAULT_CT_WINDOW_CENTER = 40
# (Int) Default Window Width for CT image display
DEFAULT_CT_WINDOW_WIDTH = 100
# (Int) Default Window Center for MRI image display
DEFAULT_MRI_WINDOW_CENTER = 400
# (Int) Default Window Width for MRI image display
DEFAULT_MRI_WINDOW_WIDTH = 1000
# (Int) Minimum value for Window Center
WINDOW_CENTER_MIN = -600
# (Int) Maximum value for Window Center
WINDOW_CENTER_MAX = 1000
# (Int) Minimum value for Window Width
WINDOW_WIDTH_MIN = 1
# (Int) Maximum value for Window Width
WINDOW_WIDTH_MAX = 3000

# Evaluation Transforms (model input)
eval_transforms = Compose(
    [
        AsChannelFirst(),
        ScaleIntensityRangePercentiles(
            lower=20, upper=80, b_min=0.0, b_max=1.0, clip=False, relative=True
        ),
        Resize(spatial_size=SPATIAL_SIZE),
    ]
)

# CAM Transforms (resized, un-normalised image used as the overlay background)
cam_transforms = Compose(
    [
        AsChannelFirst(),
        Resize(spatial_size=SPATIAL_SIZE),
    ]
)

# Original Transforms (full-resolution image offered for download)
original_transforms = Compose(
    [
        AsChannelFirst(),
    ]
)


def image_to_bytes(image):
    """Serialize a PIL image to PNG bytes (used as the download payload)."""
    byte_stream = io.BytesIO()
    image.save(byte_stream, format='PNG')
    return byte_stream.getvalue()


def bytes_to_megabytes(file_size_bytes):
    """Return the size as a string such as "1.25 MB" (1 MB = 1024 * 1024 bytes)."""
    # Rounding to 2 decimal places for readability
    file_size_megabytes = round(file_size_bytes / (1024 * 1024), 2)
    return str(file_size_megabytes) + " MB"


def meta_tensor_to_numpy(meta_tensor):
    """Convert a PyTorch MetaTensor to a float32 NumPy array on the CPU."""
    # .cpu() first so the conversion also works for CUDA tensors
    torch_tensor = meta_tensor.cpu().to(dtype=torch.float32)
    return torch_tensor.detach().numpy()


set_determinism(seed=SEED)
torch.manual_seed(SEED)

# Parameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
USE_CUDA = device.type == "cuda"


def load_model(root_dir, model_name, model_file_name):
    """Build the network, load its weights from ``root_dir/model_file_name`` and set eval mode.

    ``model_name`` is only used when CUSTOM_MODEL_FLAG is true; otherwise a
    2D single-channel SEResNet50 is built.
    """
    if CUSTOM_MODEL_FLAG:
        model = Build_Custom_Model(model_name, NUM_CLASSES, pretrained=False).to(device)
    else:
        model = SEResNet50(spatial_dims=2, in_channels=1, num_classes=NUM_CLASSES).to(device)
    weights_path = os.path.join(root_dir, model_file_name)
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()
    return model


def _windowed_image(transforms, dicom_array, window_center, window_width):
    """Apply ``transforms`` to the slice, drop singleton dims and window it for display."""
    image_numpy = meta_tensor_to_numpy(transforms(dicom_array).squeeze())
    # NOTE(review): apply_windowing is assumed to return an 8-bit (0-255) 2D
    # array, since the result is fed to Image.fromarray and later divided by 255.
    return DICOM_Utils.apply_windowing(image_numpy, window_center, window_width)


def _analyze_upload(uploaded_file, model, label, threshold, window_center, window_width):
    """Run classification, download button, display and Grad-CAM for one uploaded DICOM.

    Args:
        uploaded_file: Streamlit uploaded file holding a DICOM slice.
        model: Network returned by ``load_model`` for this modality.
        label: Modality label used in UI texts and the download file name ("MRI" / "CT").
        threshold: Probability at or above which the slice is reported positive.
        window_center: Window center used for the displayed / downloaded image.
        window_width: Window width used for the displayed / downloaded image.
    """
    # Read DICOM pixels as float32 and append a trailing channel axis (H, W, 1)
    dicom_data = pydicom.dcmread(uploaded_file)
    dicom_array = dicom_data.pixel_array.astype(np.float32)[:, :, np.newaxis]

    # To check file details
    file_details = {
        "File_Name": uploaded_file.name,
        "File_Type": uploaded_file.type,
        "File_Size": bytes_to_megabytes(uploaded_file.size),
        "File_Dimension": str((dicom_array.shape[0], dicom_array.shape[1])),
    }
    st.write(file_details)

    # Build the model input from THIS upload. The original CT branch skipped
    # this step and silently reused the MRI tensor (or raised NameError).
    transformed_array = eval_transforms(dicom_array)
    image_tensor = transformed_array.clone().detach().unsqueeze(0).to(device)

    # Predict
    with torch.no_grad():
        outputs = model(image_tensor).sigmoid().to("cpu").numpy()
    prob = outputs[0][0]
    # bool() keeps the rendered text a plain "True"/"False"
    clots_classification = bool(prob >= threshold)

    st.header(f"{label} Classification")
    st.subheader(f"Ischaemic Stroke : {clots_classification}")
    st.subheader(f"Confidence : {prob * 100:.1f}%")

    # Full-resolution windowed image for download
    windowed_download_image = _windowed_image(
        original_transforms, dicom_array, window_center, window_width
    )
    st.download_button(
        label=f"Download {label} Image",
        data=image_to_bytes(Image.fromarray(windowed_download_image)),
        file_name=f"downloaded_{label.lower()}_image.png",
        mime="image/png",
    )

    # Resized windowed image for display and as CAM background
    windowed_image = _windowed_image(cam_transforms, dicom_array, window_center, window_width)
    st.image(
        Image.fromarray(windowed_image),
        caption=f"Original {label} Visualization",
        use_column_width=True,
    )

    # Expand to three channels and normalize to [0, 1] float32, as required
    # by show_cam_on_image
    windowed_image = np.expand_dims(windowed_image, axis=2)
    windowed_image = np.tile(windowed_image, [1, 1, 3])
    windowed_image = windowed_image.astype(np.float32) / 255

    # Build the CAM (Class Activation Map)
    target_layers = [model.model.norm]
    cam = GradCAM(
        model=model,
        target_layers=target_layers,
        reshape_transform=reshape_transform,
        use_cuda=USE_CUDA,
    )
    grayscale_cam = cam(
        input_tensor=image_tensor,
        targets=[ClassifierOutputTarget(CAM_CLASS_ID)],
    )
    grayscale_cam = grayscale_cam[0, :]

    visualization = show_cam_on_image(windowed_image, grayscale_cam, use_rgb=True)
    st.image(
        Image.fromarray(visualization),
        caption=f"CAM {label} Visualization",
        use_column_width=True,
    )


# NOTE(review): both models are loaded at module level, i.e. on every Streamlit
# rerun of this script; consider caching if the installed Streamlit supports it.
ct_model = load_model(CT_MODEL_DIRECTORY, CT_MODEL_NAME, CT_MODEL_FILE_NAME)
mri_model = load_model(MRI_MODEL_DIRECTORY, MRI_MODEL_NAME, MRI_MODEL_FILE_NAME)

if LIST_MODEL_MODULES:
    for ct_name, _ in ct_model.named_modules():
        print(ct_name)
    for mri_name, _ in mri_model.named_modules():
        print(mri_name)

# Initialize Streamlit
st.title("Analyze")

# Use Streamlit's number_input to adjust WINDOW_CENTER and WINDOW_WIDTH
st.sidebar.header("Windowing Parameters for DICOM")
MRI_WINDOW_CENTER = st.sidebar.number_input(
    "MRI Window Center",
    min_value=WINDOW_CENTER_MIN,
    max_value=WINDOW_CENTER_MAX,
    value=DEFAULT_MRI_WINDOW_CENTER,
    step=1,
)
MRI_WINDOW_WIDTH = st.sidebar.number_input(
    "MRI Window Width",
    min_value=WINDOW_WIDTH_MIN,
    max_value=WINDOW_WIDTH_MAX,
    value=DEFAULT_MRI_WINDOW_WIDTH,
    step=1,
)
CT_WINDOW_CENTER = st.sidebar.number_input(
    "CT Window Center",
    min_value=WINDOW_CENTER_MIN,
    max_value=WINDOW_CENTER_MAX,
    value=DEFAULT_CT_WINDOW_CENTER,
    step=1,
)
CT_WINDOW_WIDTH = st.sidebar.number_input(
    "CT Window Width",
    min_value=WINDOW_WIDTH_MIN,
    max_value=WINDOW_WIDTH_MAX,
    value=DEFAULT_CT_WINDOW_WIDTH,
    step=1,
)

uploaded_mri_file = st.file_uploader("Upload a candidate MRI DICOM", type=["dcm"])
if uploaded_mri_file is not None:
    _analyze_upload(
        uploaded_mri_file,
        mri_model,
        "MRI",
        MRI_INFERENCE_THRESHOLD,
        MRI_WINDOW_CENTER,
        MRI_WINDOW_WIDTH,
    )

uploaded_ct_file = st.file_uploader("Upload a candidate CT DICOM", type=["dcm"])
if uploaded_ct_file is not None:
    _analyze_upload(
        uploaded_ct_file,
        ct_model,
        "CT",
        CT_INFERENCE_THRESHOLD,
        CT_WINDOW_CENTER,
        CT_WINDOW_WIDTH,
    )