# Captured from a Hugging Face Spaces page; Space status at capture time: "Build error".
import os | |
import gradio as gr | |
import torch | |
from monai import bundle | |
from monai.transforms import ( | |
Compose, | |
LoadImaged, | |
EnsureChannelFirstd, | |
Orientationd, | |
NormalizeIntensityd, | |
Activationsd, | |
AsDiscreted, | |
ScaleIntensityd, | |
) | |
# Name of the MONAI bundle used by this app. It must describe the BraTS
# brain-tumor bundle that is actually downloaded; the previous value
# ('spleen_ct_segmentation_v0.1.0') was a copy-paste leftover from a spleen
# demo. The same name is also used as the local directory below.
BUNDLE_NAME = 'brats_mri_segmentation_v0.1.0'

# Local directory of the bundle under torch's hub cache: <hub_dir>/bundle/<name>.
BUNDLE_PATH = os.path.join(torch.hub.get_dir(), 'bundle', BUNDLE_NAME)
# Heading shown at the top of the Gradio interface.
title = "Segment Brain Tumors with MONAI!"

# Markdown rendered under the title. The emoji in this text had been garbled
# by a character-encoding error. The brain and warning emoji were
# recoverable. The ones after "To run" and "References" were not, so they
# are omitted rather than left as stray characters.
description = """
## Brain Tumor Segmentation 🧠
A pre-trained model for volumetric (3D) segmentation of brain tumor subregions from multimodal MRIs based on BraTS 2018 data.
## To run
Upload an image file containing a 4-channel MRI (4 aligned MRIs: T1c, T1, T2, FLAIR at 1x1x1 mm).
## Disclaimer ⚠️
This is an example, not to be used for diagnostic purposes.
## References
1. Myronenko, Andriy. "3D MRI brain tumor segmentation using autoencoder regularization." International MICCAI Brainlesion Workshop. Springer, Cham, 2018. https://arxiv.org/abs/1810.11654.
2. Menze BH, et al. "The Multimodal Brain Tumor Image Segmentation Benchmark (BRATS)", IEEE Transactions on Medical Imaging 34(10), 1993-2024 (2015) DOI: 10.1109/TMI.2014.2377694
3. Bakas S, et al. "Advancing The Cancer Genome Atlas glioma MRI collections with expert segmentation labels and radiomic features", Nature Scientific Data, 4:170117 (2017) DOI:10.1038/sdata.2017.117
"""
# Run inference on the first CUDA GPU when one is present, else on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Fetch the BraTS bundle from the Hugging Face Hub and load its TorchScript
# network. bundle.load is expected to return a 3-tuple here; only the
# network is kept. The inference config is later read from BUNDLE_PATH, so
# the download is assumed to land there (NOTE(review): confirm against the
# installed MONAI version).
# TODO: example inputs could be added (e.g. an 'examples/' folder).
model, _, _ = bundle.load(
    name=BUNDLE_NAME,
    source='huggingface_hub',
    repo='katielink/brats_mri_segmentation_v0.1.0',
    load_ts_module=True,
)
# Read the bundle's inference config and build its "inferer" entry. The
# preprocessing pipeline below is defined in code rather than taken from
# that config.
parser = bundle.load_bundle_config(BUNDLE_PATH, 'inference.json')
inferer = parser.get_parsed_content(
    'inferer',
    lazy=True,
    eval_expr=True,
    instantiate=True,
)

# Preprocessing for an uploaded volume, in order:
# 1. load it,
# 2. put channels first,
# 3. reorient to RAS,
# 4. normalise intensities per channel using non-zero voxels.
preproc_transforms = Compose([
    LoadImaged(keys=["image"]),
    EnsureChannelFirstd(keys="image"),
    Orientationd(keys=["image"], axcodes="RAS"),
    NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
])
# Post-processing, in order:
# - sigmoid on the network output ("pred"),
# - binarise it at 0.5,
# - rescale the input image ("image") intensities into [0, 1].
post_transforms = Compose([
    Activationsd(keys="pred", sigmoid=True),
    AsDiscreted(keys="pred", threshold=0.5),
    ScaleIntensityd(keys="image", minv=0.0, maxv=1.0),
])
def predict(input_file, z_axis, model=model, device=device):
    """Segment an uploaded MRI volume and return one slice of input and mask.

    Args:
        input_file: Gradio file object; its ``.name`` attribute is used as the
            path of the image to load.
        z_axis: Slice index along the last spatial axis. It is cast to ``int``
            and clamped to the volume's depth, because the UI slider range is
            fixed and not tied to the size of the uploaded volume.
        model: Network to run; defaults to the module-level bundle model.
        device: Torch device string used for inference.

    Returns:
        Tuple ``(input_slice, pred_slice)`` of 2D numpy arrays:
        - channel 0 of the intensity-rescaled input,
        - channel 0 of the thresholded prediction, at the chosen slice.
    """
    data = {'image': [input_file.name]}
    data = preproc_transforms(data)

    model.to(device)
    model.eval()
    with torch.no_grad():
        inputs = data['image'].to(device)
        # The inferer expects a batch dimension; add one of size 1.
        data['pred'] = inferer(inputs=inputs[None, ...], network=model)
    data = post_transforms(data)

    input_image = data['image'].detach().cpu().numpy()
    pred_image = data['pred'].detach().cpu().numpy()

    # Slider values may arrive as floats and may exceed the volume depth.
    # Convert to int and clamp so indexing cannot fail on a valid volume.
    depth = min(input_image.shape[-1], pred_image.shape[-1])
    z = min(max(int(z_axis), 0), depth - 1)

    # Input channel 0 is shown as T1c, following the T1c, T1, T2, FLAIR
    # order stated in the UI description. Per the original naming,
    # prediction channel 0 is the tumor-core ("tc") mask (NOTE(review):
    # confirm against the bundle's label order).
    input_t1c_image = input_image[0, :, :, z]
    pred_tc_image = pred_image[0, 0, :, :, z]
    return input_t1c_image, pred_tc_image
# Gradio UI: an uploaded MRI volume plus a slice index in, two 2D images out.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.File(label='Input file'),
        # step=1 keeps the slider on integer slice indices.
        gr.Slider(0, 200, label='z-axis', value=100, step=1),
    ],
    outputs=[
        gr.Image(label='T1C image'),
        gr.Image(label='Segmentation'),
    ],
    title=title,
    description=description,
    # TODO: pass examples=... once example input files are bundled.
)

# Launch the server only when run as a script, not when the module is imported.
if __name__ == "__main__":
    iface.launch()