""" gradio app.py for semantic segmentation """ import os import cv2 import gradio as gr import numpy as np from otfgt import mask2sbd def gen_sbd(image, mask): h, w = image.shape[:2] if w > 1280 or h > 720: resize_factor = max(w / 1280, h / 720) h = int(h / resize_factor) w = int(w / resize_factor) image = cv2.resize(image, (w, h)) mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST) binary_labels = np.zeros((19, mask.shape[0], mask.shape[1]), dtype=np.uint8) unique_labels = np.unique(mask) for label in unique_labels: binary_labels[label] = mask == label sbd = mask2sbd(binary_labels, ignore_indices=[]) # remove the first channel (background) sbd = sbd[1:] unique_boundary_labels = np.unique(np.where(sbd == 1)[0]) value = [(sbd[x], ID2LABEL[x]) for x in unique_boundary_labels] # change 0 to 255 mask[mask == 0] = 255 # reduce the entries by 1 mask -= 1 unique_labels = np.unique(mask) # remove 254 (background) from unique_labels unique_labels = unique_labels[unique_labels != 254] value_segmentation = [(mask == x, ID2LABEL[x]) for x in unique_labels] return (image, value), (image, value_segmentation) HF_TOKEN = os.environ.get("HF_TOKEN", None) ID2LABEL = { # id: label 0: "road", 1: "dirt", 2: "gravel", 3: "rock", 4: "grass", 5: "vegetation", 6: "tree", 7: "obstacle", 8: "animals", 9: "person", 10: "bicycle", 11: "vehicle", 12: "water", 13: "boat", 14: "building", 15: "roof", 16: "sky", 17: "drone", } input_1 = gr.Image( image_mode="RGB", type="numpy", label="Image (RGB)", ) input_2 = gr.Image( image_mode="L", type="numpy", label="Segmentation Mask (Greyscale)", ) INPUTS = [input_1, input_2] output_1 = gr.AnnotatedImage( label="Boundary Mask", ) output_2 = gr.AnnotatedImage( label="Segmentation Mask", ) OUTPUTS = [output_1, output_2] TITLE = "Semantic Boundary Generation" DESCRIPTION = "Semantic Boundary Generation based on [paper](https://arxiv.org/pdf/2304.09427.pdf)." 
theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_sm,
    font=[
        gr.themes.GoogleFont("Open Sans"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
)

# Absolute directory of this script; example files live in its "examples" folder.
cur_dir = os.path.dirname(os.path.abspath(__file__))

# Each stem names an example pair: "<stem>.jpg" (RGB image) and
# "<stem>_mask.png" (segmentation mask), listed in display order.
_EXAMPLE_STEMS = (
    "aeroscape_1",
    "aeroscape_2",
    "floodnet_1",
    "floodnet_2",
    "floodnet_3",
    "floodnet_4",
    "floodnet_5",
    "udd_1",
    "udd_2",
)
EXAMPLES = [
    [
        f"{cur_dir}/examples/{stem}.jpg",
        f"{cur_dir}/examples/{stem}_mask.png",
    ]
    for stem in _EXAMPLE_STEMS
]

demo = gr.Interface(
    fn=gen_sbd,
    inputs=INPUTS,
    outputs=OUTPUTS,
    title=TITLE,
    description=DESCRIPTION,
    examples=EXAMPLES,
    cache_examples=True,
    live=False,
    theme=theme,
    allow_flagging="never",
)


if __name__ == "__main__":
    demo.launch()