# Import required libraries | |
import os | |
import io | |
import torch | |
import tempfile | |
import numpy as np | |
import streamlit as st | |
# Import utility and custom functions | |
from PIL import Image | |
from Util.DICOM import DICOM_Utils | |
from Util.Custom_Model import Build_Custom_Model, reshape_transform | |
# Import additional MONAI and PyTorch Grad-CAM utilities | |
from monai.config import print_config | |
from monai.utils import set_determinism | |
from monai.networks.nets import SEResNet50 | |
from monai.transforms import ( | |
Activations, | |
EnsureChannelFirst, | |
AsDiscrete, | |
Compose, | |
LoadImage, | |
RandFlip, | |
RandRotate, | |
RandZoom, | |
ScaleIntensity, | |
AsChannelFirst, | |
AddChannel, | |
RandSpatialCrop, | |
ScaleIntensityRangePercentiles, | |
Resize, | |
) | |
from pytorch_grad_cam import GradCAM | |
from pytorch_grad_cam.utils.image import show_cam_on_image | |
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget | |
# ---------------------------------------------------------------------------
# Reproducibility
# ---------------------------------------------------------------------------
# Seed handed to set_determinism() and torch.manual_seed() at start-up.
SEED: int = 0

# ---------------------------------------------------------------------------
# Settings shared by the CT and MRI models
# ---------------------------------------------------------------------------
# Number of output units of the classifier head (passed as num_classes).
NUM_CLASSES: int = 1
# True  -> networks are built with Build_Custom_Model(<model name>, ...).
# False -> MONAI's SEResNet50 is used and the *_MODEL_NAME values are ignored.
CUSTOM_MODEL_FLAG: bool = True
# Spatial size (two dimensions) given to MONAI's Resize transform.
SPATIAL_SIZE = [224, 224]
# When True, print the qualified name of every submodule of both models.
LIST_MODEL_MODULES: bool = False
# Output index targeted by Grad-CAM via ClassifierOutputTarget.
CAM_CLASS_ID: int = 0

# ---------------------------------------------------------------------------
# CT model
# ---------------------------------------------------------------------------
CT_MODEL_DIRECTORY: str = "models/CLOTS/CT"          # checkpoint directory
CT_MODEL_FILE_NAME: str = "best_metric_model.pth"    # checkpoint file in that directory
CT_MODEL_NAME: str = "swin_base_patch4_window7_224"  # architecture id for Build_Custom_Model
CT_INFERENCE_THRESHOLD: float = 0.5                  # positive when probability >= this

# ---------------------------------------------------------------------------
# MRI model
# ---------------------------------------------------------------------------
MRI_MODEL_DIRECTORY: str = "models/CLOTS/MRI"         # checkpoint directory
MRI_MODEL_FILE_NAME: str = "best_metric_model.pth"    # checkpoint file in that directory
MRI_MODEL_NAME: str = "swin_base_patch4_window7_224"  # architecture id for Build_Custom_Model
MRI_INFERENCE_THRESHOLD: float = 0.5                  # positive when probability >= this

# ---------------------------------------------------------------------------
# Display windowing (initial values and bounds of the sidebar number inputs)
# ---------------------------------------------------------------------------
DEFAULT_CT_WINDOW_CENTER: int = 40
DEFAULT_CT_WINDOW_WIDTH: int = 100
DEFAULT_MRI_WINDOW_CENTER: int = 400
DEFAULT_MRI_WINDOW_WIDTH: int = 1000
# Allowed range for either modality's window center.
WINDOW_CENTER_MIN: int = -600
WINDOW_CENTER_MAX: int = 1000
# Allowed range for either modality's window width (must stay positive).
WINDOW_WIDTH_MIN: int = 1
WINDOW_WIDTH_MAX: int = 3000
# MONAI pipelines applied to the path of an uploaded DICOM file. Every pipeline
# starts with LoadImage(image_only=True) followed by AsChannelFirst() (which,
# with MONAI's default channel_dim=-1, moves the last axis to the front).
# NOTE(review): AsChannelFirst is deprecated in MONAI and removed in newer
# releases. EnsureChannelFirst is NOT a drop-in swap: it picks the channel
# axis from image metadata. TODO confirm the shape LoadImage returns for these
# DICOMs before migrating.

# Native-resolution image with no intensity scaling. In the (commented-out)
# analysis code this feeds the downloadable PNG.
original_transforms = Compose(
    [
        LoadImage(image_only=True),
        AsChannelFirst(),
    ]
)

# Resized to SPATIAL_SIZE, still with no intensity scaling. In the
# (commented-out) analysis code this feeds the on-screen image and the
# Grad-CAM overlay.
cam_transforms = Compose(
    [
        LoadImage(image_only=True),
        AsChannelFirst(),
        Resize(spatial_size=SPATIAL_SIZE),
    ]
)

# Model input. Intensities are rescaled from the 20th/80th percentiles,
# without clipping; relative=True also scales the b_min/b_max target range
# (see the MONAI docs). The result is then resized to SPATIAL_SIZE.
eval_transforms = Compose(
    [
        LoadImage(image_only=True),
        AsChannelFirst(),
        ScaleIntensityRangePercentiles(
            lower=20,
            upper=80,
            b_min=0.0,
            b_max=1.0,
            clip=False,
            relative=True,
        ),
        Resize(spatial_size=SPATIAL_SIZE),
    ]
)
def image_to_bytes(image):
    """Encode an image as PNG and return the encoded bytes.

    The image is written to an in-memory buffer, so nothing touches the
    filesystem. The result is suitable as the ``data`` of
    ``st.download_button``.

    Args:
        image: An object exposing PIL's ``save(fp, format=...)`` method,
            typically a ``PIL.Image.Image``.

    Returns:
        bytes: The PNG-encoded image.
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format='PNG')
        # Read the buffer before the context manager closes it.
        return png_buffer.getvalue()
# Reproducibility: MONAI's global determinism helper first, then an explicit
# torch seed (same seed for both).
set_determinism(seed=SEED)
torch.manual_seed(SEED)

# Parameters
# Run on CUDA when a GPU is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Checkpoint directories. A throwaway temporary directory is created only when
# the corresponding *_MODEL_DIRECTORY constant is None.
ct_root_dir = CT_MODEL_DIRECTORY if CT_MODEL_DIRECTORY is not None else tempfile.mkdtemp()
mri_root_dir = MRI_MODEL_DIRECTORY if MRI_MODEL_DIRECTORY is not None else tempfile.mkdtemp()
# Streamlit re-executes this entire script on every widget interaction; each
# sidebar windowing change triggers one. Without caching, both networks would
# be rebuilt and their checkpoints re-read from disk on every such rerun.
# st.cache_resource (Streamlit >= 1.18) keeps one loaded model per distinct
# argument tuple across reruns and sessions. On older Streamlit releases that
# lack it, fall back to a pass-through decorator, which means no caching and
# the original behaviour.
_cache_model_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_model_resource
def load_model(root_dir, model_name, model_file_name):
    """Build a classifier, load its trained weights, and return it in eval mode.

    The architecture depends on ``CUSTOM_MODEL_FLAG``:

    * True: ``Build_Custom_Model(model_name, NUM_CLASSES, pretrained=False)``.
    * False: MONAI ``SEResNet50`` (2-D, one input channel). In this case
      ``model_name`` is ignored.

    When ``st.cache_resource`` is available, the SAME model instance is
    returned for repeated calls with equal arguments, across reruns and
    browser sessions. Callers must therefore not mutate it (e.g. attach hooks
    without removing them). An updated checkpoint on disk is only picked up
    after the cache is cleared or the app restarts.

    Args:
        root_dir: Directory containing the checkpoint file.
        model_name: Architecture identifier forwarded to ``Build_Custom_Model``.
        model_file_name: Checkpoint file name inside ``root_dir``. Its content
            is passed to ``load_state_dict``, so it must be a state dict.

    Returns:
        torch.nn.Module: The model with weights loaded, on ``device``, in eval
        mode.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    if CUSTOM_MODEL_FLAG:
        model = Build_Custom_Model(model_name, NUM_CLASSES, pretrained=False)
    else:
        model = SEResNet50(spatial_dims=2, in_channels=1, num_classes=NUM_CLASSES)
    model = model.to(device)

    checkpoint_path = os.path.join(root_dir, model_file_name)
    # NOTE(review): torch.load unpickles the file. This is acceptable only for
    # trusted, bundled checkpoints. On torch >= 1.13, weights_only=True would
    # restrict loading of a plain state_dict to tensors.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    return model
# Build the CT and MRI classifiers from their respective checkpoints.
ct_model = load_model(ct_root_dir, CT_MODEL_NAME, CT_MODEL_FILE_NAME)
mri_model = load_model(mri_root_dir, MRI_MODEL_NAME, MRI_MODEL_FILE_NAME)

# Optional debugging aid: print every submodule's qualified name, all CT-model
# names first, then all MRI-model names.
if LIST_MODEL_MODULES:
    for loaded_model in (ct_model, mri_model):
        for module_name, _ in loaded_model.named_modules():
            print(module_name)
# Page header.
st.title("Analyze")

# Sidebar: integer window center/width controls for CT and MRI display.
# Inside `with st.sidebar:`, st.* calls render into the sidebar.
with st.sidebar:
    st.header("Windowing Parameters for DICOM")
    CT_WINDOW_CENTER = st.number_input(
        "CT Window Center",
        min_value=WINDOW_CENTER_MIN,
        max_value=WINDOW_CENTER_MAX,
        value=DEFAULT_CT_WINDOW_CENTER,
        step=1,
    )
    CT_WINDOW_WIDTH = st.number_input(
        "CT Window Width",
        min_value=WINDOW_WIDTH_MIN,
        max_value=WINDOW_WIDTH_MAX,
        value=DEFAULT_CT_WINDOW_WIDTH,
        step=1,
    )
    MRI_WINDOW_CENTER = st.number_input(
        "MRI Window Center",
        min_value=WINDOW_CENTER_MIN,
        max_value=WINDOW_CENTER_MAX,
        value=DEFAULT_MRI_WINDOW_CENTER,
        step=1,
    )
    MRI_WINDOW_WIDTH = st.number_input(
        "MRI Window Width",
        min_value=WINDOW_WIDTH_MIN,
        max_value=WINDOW_WIDTH_MAX,
        value=DEFAULT_MRI_WINDOW_WIDTH,
        step=1,
    )

# Main area: CT DICOM upload (".dcm" files only).
uploaded_ct_file = st.file_uploader("Upload a candidate CT DICOM", type=["dcm"])
# NOTE(review): the CT/MRI analysis code that consumed this upload is commented
# out below. As a result, uploaded_ct_file is currently never read and the MRI
# uploader is not rendered. Before re-enabling that code:
#   (1) its GradCAM(...) calls hard-code use_cuda=True, which breaks on
#       CPU-only hosts (newer pytorch-grad-cam releases also dropped that
#       argument; verify against the installed version);
#   (2) NamedTemporaryFile(delete=False) files are never removed, so every
#       upload leaks a file;
#   (3) the CT and MRI branches are near-duplicates and could share one helper.
# if uploaded_ct_file is not None: | |
# # Save the uploaded file to a temporary location | |
# with tempfile.NamedTemporaryFile(delete=False, suffix=".dcm") as temp_file: | |
# temp_file.write(uploaded_ct_file.getvalue()) | |
# # Apply evaluation transforms to the DICOM image for model prediction | |
# image_tensor = eval_transforms(temp_file.name).unsqueeze(0).to(device) | |
# # Predict | |
# with torch.no_grad(): | |
# outputs = ct_model(image_tensor).sigmoid().to("cpu").numpy() | |
# prob = outputs[0][0] | |
# CLOTS_CLASSIFICATION = False | |
# if(prob >= CT_INFERENCE_THRESHOLD): | |
# CLOTS_CLASSIFICATION=True | |
# st.header("CT Classification") | |
# st.subheader(f"Ischaemic Stroke : {CLOTS_CLASSIFICATION}") | |
# st.subheader(f"Confidence : {prob * 100:.1f}%") | |
# # Load the original DICOM image for download | |
# download_image_tensor = original_transforms(temp_file.name).unsqueeze(0).to(device) | |
# download_image = download_image_tensor.squeeze() | |
# # Transform the download image and apply windowing | |
# transformed_download_image = DICOM_Utils.transform_image_for_display(download_image) | |
# windowed_download_image = DICOM_Utils.apply_windowing(transformed_download_image, CT_WINDOW_CENTER, CT_WINDOW_WIDTH) | |
# # Streamlit button to trigger image download | |
# image_data = image_to_bytes(Image.fromarray(windowed_download_image)) | |
# st.download_button( | |
# label="Download CT Image", | |
# data=image_data, | |
# file_name="downloaded_ct_image.png", | |
# mime="image/png" | |
# ) | |
# # Load the original DICOM image for display | |
# display_image_tensor = cam_transforms(temp_file.name).unsqueeze(0).to(device) | |
# display_image = display_image_tensor.squeeze() | |
# # Transform the image and apply windowing | |
# transformed_image = DICOM_Utils.transform_image_for_display(display_image) | |
# windowed_image = DICOM_Utils.apply_windowing(transformed_image, CT_WINDOW_CENTER, CT_WINDOW_WIDTH) | |
# st.image(Image.fromarray(windowed_image), caption="Original CT Visualization", use_column_width=True) | |
# # Expand to three channels | |
# windowed_image = np.expand_dims(windowed_image, axis=2) | |
# windowed_image = np.tile(windowed_image, [1, 1, 3]) | |
# # Ensure both are of float32 type | |
# windowed_image = windowed_image.astype(np.float32) | |
# # Normalize to [0, 1] range | |
# windowed_image = np.float32(windowed_image) / 255 | |
# # Build the CAM (Class Activation Map) | |
# target_layers = [ct_model.model.norm] | |
# cam = GradCAM(model=ct_model, target_layers=target_layers, reshape_transform=reshape_transform, use_cuda=True) | |
# grayscale_cam = cam(input_tensor=image_tensor, targets=[ClassifierOutputTarget(CAM_CLASS_ID)]) | |
# grayscale_cam = grayscale_cam[0, :] | |
# # Now you can safely call the show_cam_on_image function | |
# visualization = show_cam_on_image(windowed_image, grayscale_cam, use_rgb=True) | |
# st.image(Image.fromarray(visualization), caption="CAM CT Visualization", use_column_width=True) | |
# uploaded_mri_file = st.file_uploader("Upload a candidate MRI DICOM", type=["dcm"]) | |
# if uploaded_mri_file is not None: | |
# # Save the uploaded file to a temporary location | |
# with tempfile.NamedTemporaryFile(delete=False, suffix=".dcm") as temp_file: | |
# temp_file.write(uploaded_mri_file.getvalue()) | |
# # Apply evaluation transforms to the DICOM image for model prediction | |
# image_tensor = eval_transforms(temp_file.name).unsqueeze(0).to(device) | |
# # Predict | |
# with torch.no_grad(): | |
# outputs = mri_model(image_tensor).sigmoid().to("cpu").numpy() | |
# prob = outputs[0][0] | |
# CLOTS_CLASSIFICATION = False | |
# if(prob >= MRI_INFERENCE_THRESHOLD): | |
# CLOTS_CLASSIFICATION=True | |
# st.header("MRI Classification") | |
# st.subheader(f"Ischaemic Stroke : {CLOTS_CLASSIFICATION}") | |
# st.subheader(f"Confidence : {prob * 100:.1f}%") | |
# # Load the original DICOM image for download | |
# download_image_tensor = original_transforms(temp_file.name).unsqueeze(0).to(device) | |
# download_image = download_image_tensor.squeeze() | |
# # Transform the download image and apply windowing | |
# transformed_download_image = DICOM_Utils.transform_image_for_display(download_image) | |
# windowed_download_image = DICOM_Utils.apply_windowing(transformed_download_image, MRI_WINDOW_CENTER, MRI_WINDOW_WIDTH) | |
# # Streamlit button to trigger image download | |
# image_data = image_to_bytes(Image.fromarray(windowed_download_image)) | |
# st.download_button( | |
# label="Download MRI Image", | |
# data=image_data, | |
# file_name="downloaded_mri_image.png", | |
# mime="image/png" | |
# ) | |
# # Load the original DICOM image for display | |
# display_image_tensor = cam_transforms(temp_file.name).unsqueeze(0).to(device) | |
# display_image = display_image_tensor.squeeze() | |
# # Transform the image and apply windowing | |
# transformed_image = DICOM_Utils.transform_image_for_display(display_image) | |
# windowed_image = DICOM_Utils.apply_windowing(transformed_image, MRI_WINDOW_CENTER, MRI_WINDOW_WIDTH) | |
# st.image(Image.fromarray(windowed_image), caption="Original MRI Visualization", use_column_width=True) | |
# # Expand to three channels | |
# windowed_image = np.expand_dims(windowed_image, axis=2) | |
# windowed_image = np.tile(windowed_image, [1, 1, 3]) | |
# # Ensure both are of float32 type | |
# windowed_image = windowed_image.astype(np.float32) | |
# # Normalize to [0, 1] range | |
# windowed_image = np.float32(windowed_image) / 255 | |
# # Build the CAM (Class Activation Map) | |
# target_layers = [mri_model.model.norm] | |
# cam = GradCAM(model=mri_model, target_layers=target_layers, reshape_transform=reshape_transform, use_cuda=True) | |
# grayscale_cam = cam(input_tensor=image_tensor, targets=[ClassifierOutputTarget(CAM_CLASS_ID)]) | |
# grayscale_cam = grayscale_cam[0, :] | |
# # Now you can safely call the show_cam_on_image function | |
# visualization = show_cam_on_image(windowed_image, grayscale_cam, use_rgb=True) | |
# st.image(Image.fromarray(visualization), caption="CAM MRI Visualization", use_column_width=True) | |