"""Gradio demo: volumetric brain-tumor segmentation with a MONAI bundle.

Loads a pre-trained MONAI BraTS segmentation bundle from the Hugging Face Hub,
runs it on an uploaded multi-channel MRI volume, and displays one slice (along
the last spatial axis) of the first input channel next to the first channel
of the thresholded prediction.
"""
import os

import gradio as gr
import torch
from monai import bundle
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    Orientationd,
    NormalizeIntensityd,
    Activationsd,
    AsDiscreted,
    ScaleIntensityd,
)

# Must name the BraTS bundle fetched below: bundle.load() stores the download
# under <bundle_dir>/<name> (MONAI's default bundle_dir is
# <torch hub dir>/bundle), and BUNDLE_PATH is used to read its inference config.
BUNDLE_NAME = 'brats_mri_segmentation_v0.1.0'
BUNDLE_PATH = os.path.join(torch.hub.get_dir(), 'bundle', BUNDLE_NAME)

title = "Segment Brain Tumors with MONAI!"
description = """
## Brain Tumor Segmentation 🧠
A pre-trained model for volumetric (3D) segmentation of brain tumor subregions from multimodal MRIs based on BraTS 2018 data.

## To run 🚀
Upload an image file in the format: 4 channel MRI (4 aligned MRIs T1c, T1, T2, FLAIR at 1x1x1 mm)

## Disclaimer ⚠️
This is an example, not to be used for diagnostic purposes.

## References 👀
1. Myronenko, Andriy. "3D MRI brain tumor segmentation using autoencoder regularization." International MICCAI Brainlesion Workshop. Springer, Cham, 2018. https://arxiv.org/abs/1810.11654.
2. Menze BH, et al. "The Multimodal Brain Tumor Image Segmentation Benchmark (BRATS)", IEEE Transactions on Medical Imaging 34(10), 1993-2024 (2015) DOI: 10.1109/TMI.2014.2377694
3. Bakas S, et al. "Advancing The Cancer Genome Atlas glioma MRI collections with expert segmentation labels and radiomic features", Nature Scientific Data, 4:170117 (2017) DOI:10.1038/sdata.2017.117
"""

# Download (on first run) and load the bundle's TorchScript network.
model, _, _ = bundle.load(
    name=BUNDLE_NAME,
    source='huggingface_hub',
    repo='katielink/brats_mri_segmentation_v0.1.0',
    load_ts_module=True,
)
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# The inferer (e.g. sliding-window settings) comes from the bundle's own config
# so inference matches how the bundle was meant to be run.
parser = bundle.load_bundle_config(BUNDLE_PATH, 'inference.json')

preproc_transforms = Compose(
    [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys="image"),
        Orientationd(keys=["image"], axcodes="RAS"),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)
inferer = parser.get_parsed_content('inferer', lazy=True, eval_expr=True, instantiate=True)
post_transforms = Compose(
    [
        # Per-channel sigmoid, then binarize the prediction at 0.5.
        Activationsd(keys='pred', sigmoid=True),
        AsDiscreted(keys='pred', threshold=0.5),
        # Rescale the (z-score normalized) input to [0, 1] for display.
        ScaleIntensityd(keys='image', minv=0., maxv=1.),
    ]
)


def _input_path(input_file):
    """Return the filesystem path of a Gradio ``File`` upload.

    Depending on the Gradio version, the upload arrives either as a tempfile
    wrapper exposing ``.name`` or as a plain path string; both are accepted.

    Raises:
        ValueError: if no file was uploaded.
    """
    if input_file is None:
        raise ValueError("Please upload an input file.")
    return getattr(input_file, "name", input_file)


def _clamp_slice_index(z_axis, depth):
    """Convert a slider value to an integer slice index within ``[0, depth - 1]``."""
    return min(max(int(z_axis), 0), depth - 1)


def predict(input_file, z_axis, model=model, device=device):
    """Segment an uploaded MRI volume and return one slice of input and mask.

    Args:
        input_file: Gradio ``File`` upload (tempfile wrapper or path string).
        z_axis: Requested slice index along the last spatial axis. It is
            clamped to the volume's depth because the slider range (0-200) can
            exceed it (BraTS volumes are commonly 155 slices deep).
        model: Network to run; defaults to the bundle model loaded above.
        device: Torch device string to run inference on.

    Returns:
        Tuple ``(input_slice, pred_slice)`` of 2-D numpy arrays: channel 0 of
        the [0, 1]-scaled input and channel 0 of the binarized prediction.
    """
    data = {'image': [_input_path(input_file)]}
    data = preproc_transforms(data)

    model.to(device)
    model.eval()
    with torch.no_grad():
        inputs = data['image'].to(device)
        # Prepend a batch dimension of size 1 before inference.
        data['pred'] = inferer(inputs=inputs[None, ...], network=model)
    data = post_transforms(data)

    input_image = data['image'].numpy()
    pred_image = data['pred'].cpu().detach().numpy()

    z = _clamp_slice_index(z_axis, input_image.shape[-1])
    # Input channel 0 is T1c per the documented upload format. Prediction
    # channel 0 is assumed to be the tumor-core (TC) map — TODO confirm
    # against the bundle's label channel order.
    input_t1c_image = input_image[0, :, :, z]
    pred_tc_image = pred_image[0, 0, :, :, z]
    return input_t1c_image, pred_tc_image


iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.File(label='Input file'),
        gr.Slider(0, 200, label='z-axis', value=100),
    ],
    outputs=[
        gr.Image(label='T1C image'),
        gr.Image(label='Segmentation'),
    ],
    title=title,
    description=description,
)

iface.launch()