import sys import os import glob import shutil import torch import argparse import mediapy import cv2 import numpy as np import gradio as gr from skimage import color, img_as_ubyte from monai import transforms, data os.system("git clone https://github.com/darraghdog/Project-MONAI-research-contributions pmrc") sys.path.append("pmrc/SwinUNETR/BTCV") from swinunetr import SwinUnetrModelForInference, SwinUnetrConfig ffmpeg_path = shutil.which('ffmpeg') mediapy.set_ffmpeg(ffmpeg_path) # Load model model = SwinUnetrModelForInference.from_pretrained('darragh/swinunetr-btcv-tiny') model.eval() # Pull files from github input_files = glob.glob('pmrc/SwinUNETR/BTCV/dataset/imagesSampleTs/*.nii.gz') input_files = dict((f.split('/')[-1], f) for f in input_files) # Load and process dicom with monai transforms test_transform = transforms.Compose( [ transforms.LoadImaged(keys=["image"]), transforms.AddChanneld(keys=["image"]), transforms.Spacingd(keys="image", pixdim=(1.5, 1.5, 2.0), mode="bilinear"), transforms.ScaleIntensityRanged(keys=["image"], a_min=-175.0, a_max=250.0, b_min=0.0, b_max=1.0, clip=True), # transforms.Resized(keys=["image"], spatial_size = (256,256,-1)), transforms.ToTensord(keys=["image"]), ]) # Create Data Loader def create_dl(test_files): ds = test_transform(test_files) loader = data.DataLoader(ds, batch_size=1, shuffle=False) return loader # Inference and video generation def generate_dicom_video(selected_file, n_frames): # Data processor test_file = input_files[selected_file] test_files = [{'image': test_file}] dl = create_dl(test_files) batch = next(iter(dl)) # Select dicom slices tst_inputs = batch["image"] tst_inputs = tst_inputs[:,:,:,:,-n_frames:] # Inference with torch.no_grad(): outputs = model(tst_inputs, (96,96,96), 8, overlap=0.5, mode="gaussian") tst_outputs = torch.softmax(outputs.logits, 1) tst_outputs = torch.argmax(tst_outputs, axis=1) # Write frames to video for inp, outp in zip(tst_inputs, tst_outputs): frames = [] for idx in range(inp.shape[-1]): # Segmentation seg = outp[:,:,idx].numpy().astype(np.uint8) # Input dicom frame img = (inp[0,:,:,idx]*255).numpy().astype(np.uint8) img = cv2.cvtColor(img,cv2.COLOR_GRAY2RGB) frame = color.label2rgb(seg,img, bg_label = 0) frame = img_as_ubyte(frame) frame = np.concatenate((img, frame), 1) frames.append(frame) mediapy.write_video("dicom.mp4", frames, fps=4) return 'dicom.mp4' theme = 'dark-peach' with gr.Blocks(theme=theme) as demo: gr.Markdown('''