"""Gradio demo that segments brain tumors in MRI volumes with a MONAI bundle.

At import time the script loads a pretrained TorchScript model from the
Hugging Face Hub, builds the pre-/post-processing pipelines, and assembles a
``gr.Blocks`` UI in which the user uploads a volume, picks a slice, and views
the input sequences next to the predicted segmentation channels.
"""
import os

import gradio as gr
import torch
from monai import bundle
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    Orientationd,
    NormalizeIntensityd,
    Activationsd,
    AsDiscreted,
    ScaleIntensityd,
)

# Define the bundle name and path for downloading.
# NOTE(review): the name says "spleen_ct" although the weights are fetched
# from the BraTS repo below. It is passed as `name` to bundle.load() and is
# the last component of BUNDLE_PATH, so it is left unchanged here -- confirm
# before renaming.
BUNDLE_NAME = 'spleen_ct_segmentation_v0.1.0'
BUNDLE_PATH = os.path.join(torch.hub.get_dir(), 'bundle', BUNDLE_NAME)

# Title and description
# NOTE(review): the title is rendered through gr.Markdown as plain text; any
# heading markup it once had is not present in this file -- confirm.
title = '\n\nSegment Brain Tumors with MONAI! 🧠\n\n'

description = """
## 🚀 To run
Upload a brain MRI image file, or try out one of the examples below!
If you want to see a different slice, update the slider.
More details on the model can be found [here!](https://huggingface.co/katielink/brats_mri_segmentation_v0.1.0)

## ⚠️ Disclaimer
This is an example, not to be used for diagnostic purposes.
"""

references = """
## 👀 References
1. Myronenko, Andriy. "3D MRI brain tumor segmentation using autoencoder regularization." International MICCAI Brainlesion Workshop. Springer, Cham, 2018. https://arxiv.org/abs/1810.11654.
2. Menze BH, et al. "The Multimodal Brain Tumor Image Segmentation Benchmark (BRATS)", IEEE Transactions on Medical Imaging 34(10), 1993-2024 (2015) DOI: 10.1109/TMI.2014.2377694
3. Bakas S, et al. "Advancing The Cancer Genome Atlas glioma MRI collections with expert segmentation labels and radiomic features", Nature Scientific Data, 4:170117 (2017) DOI:10.1038/sdata.2017.117
"""

# Each example is [path to volume, slice index], matching the demo inputs.
examples = [
    ['examples/BRATS_485.nii.gz', 65],
    ['examples/BRATS_486.nii.gz', 80],
]

# Load the MONAI pretrained model from Hugging Face Hub
model, _, _ = bundle.load(
    name=BUNDLE_NAME,
    source='huggingface_hub',
    repo='katielink/brats_mri_segmentation_v0.1.0',
    load_ts_module=True,
)

# Use GPU if available
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the parser from the MONAI bundle's inference config
parser = bundle.load_bundle_config(BUNDLE_PATH, 'inference.json')

# Compose the preprocessing transforms
preproc_transforms = Compose(
    [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys="image"),
        Orientationd(keys=["image"], axcodes="RAS"),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)

# Get the inferer from the bundle's inference config
inferer = parser.get_parsed_content(
    'inferer',
    lazy=True,
    eval_expr=True,
    instantiate=True,
)

# Compose the postprocessing transforms
post_transforms = Compose(
    [
        Activationsd(keys='pred', sigmoid=True),
        AsDiscreted(keys='pred', threshold=0.5),
        # Rescale the input image to [0, 1] for display
        ScaleIntensityd(keys='image', minv=0., maxv=1.),
    ]
)


# Define the predict function for the demo
def predict(input_file, z_axis, model=model, device=device):
    """Segment an uploaded volume and return one slice of inputs and labels.

    Args:
        input_file: Uploaded file object; its ``.name`` attribute is used as
            the path handed to the preprocessing transforms.
        z_axis: Requested slice index along the last axis of the volume. It
            is converted to ``int`` and clamped to the available range, so an
            out-of-range slider value shows the nearest valid slice instead
            of raising ``IndexError``.
        model: Network passed to the inferer (defaults to the loaded bundle
            model).
        device: Torch device string the model and input are moved to.

    Returns:
        A pair of lists of 2D arrays: the four image channels at the chosen
        slice, and the three prediction channels at the same slice.
    """
    # Load and process data in MONAI format
    data = {'image': [input_file.name]}
    data = preproc_transforms(data)

    # Run inference and post-process predicted labels
    model.to(device)
    model.eval()
    with torch.no_grad():
        inputs = data['image'].to(device)
        # inputs[None, ...] adds a leading batch dimension for the inferer
        data['pred'] = inferer(inputs=inputs[None, ...], network=model)
    data = post_transforms(data)

    # Convert tensors back to numpy arrays
    image = data['image'].numpy()
    pred = data['pred'].cpu().detach().numpy()

    # The slider range is fixed (0-200) and may deliver a float, while the
    # volume depth depends on the uploaded file: coerce and clamp the index.
    depth = min(image.shape[-1], pred.shape[-1])
    z = min(max(int(z_axis), 0), depth - 1)

    # Magnetic resonance imaging sequences
    # NOTE(review): channel-to-sequence mapping is taken from the original
    # comments -- verify against the channel order of the input data.
    t1c = image[0, :, :, z]    # T1-weighted, post contrast
    t1 = image[1, :, :, z]     # T1-weighted, pre contrast
    t2 = image[2, :, :, z]     # T2-weighted
    flair = image[3, :, :, z]  # FLAIR

    # BraTS labels (first index is the batch dimension added above)
    tc = pred[0, 0, :, :, z]  # Tumor core
    wt = pred[0, 1, :, :, z]  # Whole tumor
    et = pred[0, 2, :, :, z]  # Enhancing tumor

    return [t1c, t1, t2, flair], [tc, wt, et]


# Use blocks to set up a more complex demo
with gr.Blocks() as demo:

    # Show title and description
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        # Get the input file and slice slider as inputs
        input_file = gr.File(label='input file')
        z_axis = gr.Slider(0, 200, label='slice', value=50)

    with gr.Row():
        # Show the button with custom label
        button = gr.Button("Segment Tumor!")

    with gr.Row():
        with gr.Column():
            # Show the input image with different MR sequences
            input_image = gr.Gallery(
                label='input MRI sequences (T1+, T1, T2, FLAIR)'
            )

        with gr.Column():
            # Show the segmentation labels
            output_segmentation = gr.Gallery(
                label='output segmentations (TC, WT, ET)'
            )

    # Run prediction on button click
    button.click(
        predict,
        inputs=[input_file, z_axis],
        outputs=[input_image, output_segmentation],
    )

    # Have some example for the user to try out
    examples = gr.Examples(
        examples=examples,
        inputs=[input_file, z_axis],
        outputs=[input_image, output_segmentation],
        fn=predict,
        cache_examples=False,
    )

    # Show references at the bottom of the demo
    gr.Markdown(references)

# Launch the demo
demo.launch()