|
|
|
import os |
|
import io |
|
import torch |
|
import tempfile |
|
import numpy as np |
|
import streamlit as st |
|
|
|
|
|
from PIL import Image |
|
from Util.DICOM import DICOM_Utils |
|
from Util.Custom_Model import Build_Custom_Model, reshape_transform |
|
|
|
|
|
from monai.config import print_config |
|
from monai.utils import set_determinism |
|
from monai.networks.nets import SEResNet50 |
|
from monai.transforms import ( |
|
Activations, |
|
EnsureChannelFirst, |
|
AsDiscrete, |
|
Compose, |
|
LoadImage, |
|
RandFlip, |
|
RandRotate, |
|
RandZoom, |
|
ScaleIntensity, |
|
AsChannelFirst, |
|
AddChannel, |
|
RandSpatialCrop, |
|
ScaleIntensityRangePercentiles, |
|
Resize, |
|
) |
|
from pytorch_grad_cam import GradCAM |
|
from pytorch_grad_cam.utils.image import show_cam_on_image |
|
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget |
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Shared configuration
# ---------------------------------------------------------------------------

# Seed handed to MONAI's set_determinism and to torch.manual_seed.
SEED = 0

# Number of model outputs; the app reads outputs[0][0] as a single sigmoid probability.
NUM_CLASSES = 1

# True: build models with Build_Custom_Model(<model name>); False: MONAI SEResNet50.
CUSTOM_MODEL_FLAG = True

# Spatial size every image is resized to for inference and the Grad-CAM overlay.
SPATIAL_SIZE = [224, 224]

# When True, print every named sub-module of both loaded models.
LIST_MODEL_MODULES = False

# Output index that Grad-CAM is asked to explain.
CAM_CLASS_ID = 0

# ---------------------------------------------------------------------------
# CT model
# ---------------------------------------------------------------------------

CT_MODEL_DIRECTORY = "models/CLOTS/CT"
CT_MODEL_FILE_NAME = "best_metric_model.pth"
CT_MODEL_NAME = "swin_base_patch4_window7_224"
# A sigmoid probability >= this value is reported as positive.
CT_INFERENCE_THRESHOLD = 0.5

# ---------------------------------------------------------------------------
# MRI model
# ---------------------------------------------------------------------------

MRI_MODEL_DIRECTORY = "models/CLOTS/MRI"
MRI_MODEL_FILE_NAME = "best_metric_model.pth"
MRI_MODEL_NAME = "swin_base_patch4_window7_224"
# A sigmoid probability >= this value is reported as positive.
MRI_INFERENCE_THRESHOLD = 0.5

# ---------------------------------------------------------------------------
# Display windowing (sidebar defaults and bounds)
# ---------------------------------------------------------------------------

DEFAULT_CT_WINDOW_CENTER = 40
DEFAULT_CT_WINDOW_WIDTH = 100
DEFAULT_MRI_WINDOW_CENTER = 400
DEFAULT_MRI_WINDOW_WIDTH = 1000

WINDOW_CENTER_MIN = -600
WINDOW_CENTER_MAX = 1000
WINDOW_WIDTH_MIN = 1
WINDOW_WIDTH_MAX = 3000
|
|
|
|
|
def _load_channel_first_steps():
    """Return fresh LoadImage + AsChannelFirst steps shared by every pipeline below."""
    return [LoadImage(image_only=True), AsChannelFirst()]


# Model input: load, channel-first, percentile-based intensity rescale to
# [0, 1] (no clipping), then resize to SPATIAL_SIZE.
eval_transforms = Compose(
    _load_channel_first_steps()
    + [
        ScaleIntensityRangePercentiles(lower=20, upper=80, b_min=0.0, b_max=1.0, clip=False, relative=True),
        Resize(spatial_size=SPATIAL_SIZE),
    ]
)

# Same geometry as the model input but without intensity rescaling; this is
# the image shown on screen and used as the Grad-CAM background.
cam_original_transforms = Compose(
    _load_channel_first_steps() + [Resize(spatial_size=SPATIAL_SIZE)]
)

# Native resolution, no rescaling; used for the downloadable PNG.
original_transforms = Compose(_load_channel_first_steps())
|
|
|
|
|
def image_to_bytes(image):
    """Encode *image* as PNG and return the encoded bytes.

    ``image`` must provide a PIL-style ``save(fp, format=...)`` method; in this
    app it is always a ``PIL.Image.Image``.
    """
    with io.BytesIO() as buffer:
        image.save(buffer, format='PNG')
        encoded = buffer.getvalue()
    return encoded
|
|
|
# Seed MONAI's determinism helper and torch's RNG with the same SEED.
set_determinism(seed=SEED)
torch.manual_seed(SEED)

# Prefer a CUDA device when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


def _model_root(directory):
    """Return *directory*, or a newly created temporary directory when it is None."""
    if directory is None:
        return tempfile.mkdtemp()
    return directory


ct_root_dir = _model_root(CT_MODEL_DIRECTORY)
mri_root_dir = _model_root(MRI_MODEL_DIRECTORY)
|
|
|
def load_model(root_dir, model_name, model_file_name, map_location):
    """Build a classifier, load its saved weights, and switch it to eval mode.

    Args:
        root_dir: Directory that holds the checkpoint file.
        model_name: Architecture name passed to ``Build_Custom_Model``; unused
            when ``CUSTOM_MODEL_FLAG`` is False (SEResNet50 is built instead).
        model_file_name: Checkpoint file name inside ``root_dir``.
        map_location: Forwarded to ``torch.load`` to control where the stored
            tensors are placed while loading.

    Returns:
        The model, moved to the module-level ``device`` and in eval mode.
    """
    if CUSTOM_MODEL_FLAG:
        model = Build_Custom_Model(model_name, NUM_CLASSES, pretrained=False)
    else:
        model = SEResNet50(spatial_dims=2, in_channels=1, num_classes=NUM_CLASSES)
    model = model.to(device)

    checkpoint_path = os.path.join(root_dir, model_file_name)
    # NOTE(review): torch.load unpickles the file, so only point this at trusted
    # checkpoints. The file is assumed to contain a plain state_dict.
    state_dict = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(state_dict)

    model.eval()
    return model
|
|
|
# NOTE(review): Streamlit reruns the whole script on widget interaction, so both
# checkpoints are reloaded on every rerun; consider caching if this is slow.
ct_model = load_model(ct_root_dir, CT_MODEL_NAME, CT_MODEL_FILE_NAME, device)
mri_model = load_model(mri_root_dir, MRI_MODEL_NAME, MRI_MODEL_FILE_NAME, device)

if LIST_MODEL_MODULES:
    # Print every named sub-module: all of the CT model first, then the MRI model.
    for loaded_model in (ct_model, mri_model):
        for module_name, _module in loaded_model.named_modules():
            print(module_name)
|
|
|
|
|
st.title("Analyze")


def _window_number_input(label, lower, upper, default):
    """Add an integer-step sidebar number input for one windowing parameter."""
    return st.sidebar.number_input(label, min_value=lower, max_value=upper, value=default, step=1)


# Sidebar controls for the display windowing applied to CT and MRI images.
st.sidebar.header("Windowing Parameters for DICOM")

CT_WINDOW_CENTER = _window_number_input(
    "CT Window Center", WINDOW_CENTER_MIN, WINDOW_CENTER_MAX, DEFAULT_CT_WINDOW_CENTER
)
CT_WINDOW_WIDTH = _window_number_input(
    "CT Window Width", WINDOW_WIDTH_MIN, WINDOW_WIDTH_MAX, DEFAULT_CT_WINDOW_WIDTH
)
MRI_WINDOW_CENTER = _window_number_input(
    "MRI Window Center", WINDOW_CENTER_MIN, WINDOW_CENTER_MAX, DEFAULT_MRI_WINDOW_CENTER
)
MRI_WINDOW_WIDTH = _window_number_input(
    "MRI Window Width", WINDOW_WIDTH_MIN, WINDOW_WIDTH_MAX, DEFAULT_MRI_WINDOW_WIDTH
)
|
|
|
def _display_window(image_tensor, window_center, window_width):
    """Drop singleton dims, then apply DICOM_Utils display transform and windowing."""
    # Display-only path: the tensor never reaches the model, so it is not moved to `device`.
    image = image_tensor.squeeze()
    transformed = DICOM_Utils.transform_image_for_display(image)
    return DICOM_Utils.apply_windowing(transformed, window_center, window_width)


def _grad_cam_overlay(model, input_tensor, windowed_image):
    """Return an RGB Grad-CAM heat map of *model* on *input_tensor*, drawn over *windowed_image*."""
    # Grey (H, W) -> (H, W, 3) float32 scaled by 1/255, as show_cam_on_image
    # expects values in [0, 1] (assumes an 8-bit display image).
    rgb_image = np.tile(np.expand_dims(windowed_image, axis=2), [1, 1, 3]).astype(np.float32) / 255

    # No `use_cuda` argument: load_model and the eval pipeline already put the
    # model and input on `device`. Forcing use_cuda=True breaks CPU-only hosts.
    # `model.model.norm` / reshape_transform only fit the custom model, which
    # is why the caller checks CUSTOM_MODEL_FLAG first.
    cam = GradCAM(model=model, target_layers=[model.model.norm], reshape_transform=reshape_transform)
    grayscale_cam = cam(input_tensor=input_tensor, targets=[ClassifierOutputTarget(CAM_CLASS_ID)])
    return show_cam_on_image(rgb_image, grayscale_cam[0, :], use_rgb=True)


def _analyze_dicom_upload(uploaded_file, modality, model, threshold, window_center, window_width):
    """Classify one uploaded DICOM and render the result, a PNG download and Grad-CAM.

    Args:
        uploaded_file: Streamlit upload object; only ``getvalue()`` is used.
        modality: Label used in headings, captions and file names ("CT"/"MRI").
        model: Classifier returned by ``load_model``.
        threshold: A sigmoid probability >= this value is reported as positive.
        window_center: Display window center for ``DICOM_Utils.apply_windowing``.
        window_width: Display window width for ``DICOM_Utils.apply_windowing``.
    """
    # LoadImage reads from a path, so persist the upload. The file is closed
    # (flushed) before it is read and removed afterwards so reruns don't leak files.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".dcm") as temp_file:
        temp_file.write(uploaded_file.getvalue())
    dicom_path = temp_file.name

    try:
        image_tensor = eval_transforms(dicom_path).unsqueeze(0).to(device)

        with torch.no_grad():
            outputs = model(image_tensor).sigmoid().to("cpu").numpy()
        prob = outputs[0][0]
        is_positive = bool(prob >= threshold)

        st.header(f"{modality} Classification")
        st.subheader(f"Ischaemic Stroke : {is_positive}")
        # NOTE: `prob` is the positive-class sigmoid output; it is shown as-is
        # even when the prediction is negative.
        st.subheader(f"Confidence : {prob * 100:.1f}%")

        # Native-resolution windowed image offered as a PNG download.
        download_image = _display_window(original_transforms(dicom_path), window_center, window_width)
        st.download_button(
            label=f"Download {modality} Image",
            data=image_to_bytes(Image.fromarray(download_image)),
            file_name=f"downloaded_{modality.lower()}_image.png",
            mime="image/png",
        )

        # Resized windowed image: shown directly and reused as the CAM background.
        display_image = _display_window(cam_original_transforms(dicom_path), window_center, window_width)
        st.image(Image.fromarray(display_image), caption=f"Original {modality} Visualization", use_column_width=True)

        if CUSTOM_MODEL_FLAG:
            visualization = _grad_cam_overlay(model, image_tensor, display_image)
            st.image(Image.fromarray(visualization), caption=f"CAM {modality} Visualization", use_column_width=True)
        else:
            st.info("Grad-CAM visualization is only available for the custom model.")
    finally:
        # Best-effort cleanup: a failed delete must not mask an earlier error.
        try:
            os.remove(dicom_path)
        except OSError:
            pass


uploaded_ct_file = st.file_uploader("Upload a candidate CT DICOM", type=["dcm"])

if uploaded_ct_file is not None:
    _analyze_dicom_upload(
        uploaded_ct_file,
        modality="CT",
        model=ct_model,
        threshold=CT_INFERENCE_THRESHOLD,
        window_center=CT_WINDOW_CENTER,
        window_width=CT_WINDOW_WIDTH,
    )

uploaded_mri_file = st.file_uploader("Upload a candidate MRI DICOM", type=["dcm"])

if uploaded_mri_file is not None:
    _analyze_dicom_upload(
        uploaded_mri_file,
        modality="MRI",
        model=mri_model,
        threshold=MRI_INFERENCE_THRESHOLD,
        window_center=MRI_WINDOW_CENTER,
        window_width=MRI_WINDOW_WIDTH,
    )
|
|