import os
import gradio as gr
import torch
from monai import bundle
from monai.transforms import (
Compose,
LoadImaged,
EnsureChannelFirstd,
Orientationd,
NormalizeIntensityd,
Activationsd,
AsDiscreted,
ScaleIntensityd,
)
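# MONAI bundle to fetch; downloaded files are cached under the torch hub directory.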
BUNDLE_NAME = 'brats_mri_segmentation_v0.1.0'
BUNDLE_PATH = os.path.join(torch.hub.get_dir(), 'bundle', BUNDLE_NAME)
title = "Segment Brain Tumors with MONAI!"
description = """
## Brain Tumor Segmentation 🧠
A pre-trained model for volumetric (3D) segmentation of brain tumor subregions from multimodal MRIs based on BraTS 2018 data.
## To run πŸš€
Upload an image file containing a 4-channel MRI (4 aligned MRI sequences, T1c, T1, T2, and FLAIR, at 1x1x1 mm resolution).
## Disclaimer ⚠️
This is an example, not to be used for diagnostic purposes.
## References πŸ‘€
1. Myronenko, Andriy. "3D MRI brain tumor segmentation using autoencoder regularization." International MICCAI Brainlesion Workshop. Springer, Cham, 2018. https://arxiv.org/abs/1810.11654
2. Menze BH, et al. "The Multimodal Brain Tumor Image Segmentation Benchmark (BRATS)." IEEE Transactions on Medical Imaging 34(10), 1993-2024 (2015). DOI: 10.1109/TMI.2014.2377694
3. Bakas S, et al. "Advancing The Cancer Genome Atlas glioma MRI collections with expert segmentation labels and radiomic features." Nature Scientific Data, 4:170117 (2017). DOI: 10.1038/sdata.2017.117
"""
#examples = 'examples/'
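# Download the pre-trained bundle from the Hugging Face Hub; with
# load_ts_module=True, bundle.load returns the TorchScript model along with
# its metadata and extra files, which we discard here.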
model, _, _ = bundle.load(
    name=BUNDLE_NAME,
    source='huggingface_hub',
    repo='katielink/brats_mri_segmentation_v0.1.0',
    load_ts_module=True,
)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
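# Parse the bundle's inference config so we can reuse its inferer definition.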
parser = bundle.load_bundle_config(BUNDLE_PATH, 'inference.json')
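# Preprocessing: load the volume, move channels first, reorient to RAS, and
# normalize intensities per channel over nonzero voxels.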
preproc_transforms = Compose(
[
LoadImaged(keys=["image"]),
EnsureChannelFirstd(keys="image"),
Orientationd(keys=["image"], axcodes="RAS"),
NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
]
)
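# Instantiate the inferer defined in inference.json (typically a
# sliding-window inferer for this bundle).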
inferer = parser.get_parsed_content('inferer', lazy=True, eval_expr=True, instantiate=True)
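# Post-processing: sigmoid activation and 0.5 threshold on the prediction,
# plus rescaling the input image to [0, 1] for display.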
post_transforms = Compose(
[
Activationsd(keys='pred', sigmoid=True),
AsDiscreted(keys='pred', threshold=0.5),
ScaleIntensityd(keys='image', minv=0., maxv=1.)
]
)
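# Run inference on the uploaded volume and return one axial slice of the T1c
# input alongside the corresponding tumor core segmentation slice.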
def predict(input_file, z_axis, model=model, device=device):
data = {'image': [input_file.name]}
data = preproc_transforms(data)
model.to(device)
model.eval()
with torch.no_grad():
inputs = data['image'].to(device)
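        # Add a batch dimension and run the bundle's inferer over the volume.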
data['pred'] = inferer(inputs=inputs[None,...], network=model)
data = post_transforms(data)
input_image = data['image'].numpy()
pred_image = data['pred'].cpu().detach().numpy()
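    # Slice a single axial plane for display: input channel 0 is T1c, and
    # prediction channel 0 is the tumor core (TC); the remaining input
    # channels (T1, T2, FLAIR) and prediction channels (ET, WT) are available
    # in the commented lines below.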
input_t1c_image = input_image[0, :, :, z_axis]
#input_t1_image = input_image[1, :, :, z_axis]
#input_t2_image = input_image[2, :, :, z_axis]
#input_flair_image = input_image[3, :, :, z_axis]
pred_tc_image = pred_image[0, 0, :, :, z_axis]
#pred_et_image = pred_image[0, 1, :, :, z_axis]
#pred_wt_image = pred_image[0, 2, :, :, z_axis]
    return input_t1c_image, pred_tc_image
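# Build the Gradio UI: a file upload plus a z-axis slider to pick the slice,
# with the input slice and its segmentation shown side by side.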
iface = gr.Interface(
fn=predict,
inputs=[
gr.File(label='Input file'),
gr.Slider(0, 200, label='z-axis', value=100)
],
outputs=[
gr.Image(label='T1C image'),
gr.Image(label='Segmentation'),
],
title=title,
description=description,
#examples=examples,
)
iface.launch()