# Last commit by harisankar95: 4dc2bc3 ("fixed predictions")
"""
gradio app.py
for semantic segmentation
"""
import os
import cv2
import gradio as gr
import numpy as np
from otfgt import mask2sbd
def gen_sbd(image, mask):
    """Build boundary and segmentation overlays for ``gr.AnnotatedImage``.

    The image is downscaled (aspect ratio preserved) so that it fits within
    1280x720, and the mask is resized with nearest-neighbour sampling to the
    same size as the (possibly downscaled) image.

    Mask value ``0`` is background, and value ``k`` in ``1..len(ID2LABEL)``
    is the class ``ID2LABEL[k - 1]``. Any other value (e.g. a 255 "void"
    label) is treated as background instead of crashing.

    Args:
        image: RGB image array from the first input (presumably ``(H, W, 3)``).
        mask: Greyscale label mask from the second input (``(H, W)``). The
            caller's array is never modified.

    Returns:
        Two ``(image, annotations)`` tuples, one per output component.
        Each annotation is a ``(binary_mask, label_name)`` pair:
        - the first tuple holds the semantic boundaries produced by
          ``mask2sbd``;
        - the second holds the per-class segmentation regions.

    Raises:
        gr.Error: if the image or the mask is missing.
    """
    if image is None or mask is None:
        raise gr.Error("Please provide both an image and a segmentation mask.")

    max_w, max_h = 1280, 720
    h, w = image.shape[:2]
    if w > max_w or h > max_h:
        scale = max(w / max_w, h / max_h)
        # round() avoids float error shaving a pixel off the limiting side;
        # max(1, ...) guards against extreme aspect ratios collapsing to 0.
        h = max(1, round(h / scale))
        w = max(1, round(w / scale))
        image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
    if mask.shape[:2] != (h, w):
        # Nearest-neighbour keeps label values intact (no blended class ids).
        mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)

    num_classes = len(ID2LABEL)
    # Work on a signed copy: arithmetic cannot wrap, and the input is untouched.
    labels = np.asarray(mask).astype(np.int32)
    labels[(labels < 0) | (labels > num_classes)] = 0

    # One binary channel per id: channel 0 is background, channel k is class k.
    channel_ids = np.arange(num_classes + 1)
    binary_labels = (labels[None, :, :] == channel_ids[:, None, None]).astype(np.uint8)
    sbd = mask2sbd(binary_labels, ignore_indices=[])
    # Drop the background channel so that sbd[i] corresponds to ID2LABEL[i].
    sbd = sbd[1:]
    boundary_ids = np.unique(np.nonzero(sbd == 1)[0])
    value = [(sbd[i], ID2LABEL[int(i)]) for i in boundary_ids]

    value_segmentation = [
        (labels == k, ID2LABEL[int(k) - 1]) for k in np.unique(labels) if k != 0
    ]
    return (image, value), (image, value_segmentation)
# Optional Hugging Face access token taken from the environment; None if unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Class id -> class name, with ids numbered 0..17 in the order listed below.
ID2LABEL = dict(
    enumerate(
        (
            "road",
            "dirt",
            "gravel",
            "rock",
            "grass",
            "vegetation",
            "tree",
            "obstacle",
            "animals",
            "person",
            "bicycle",
            "vehicle",
            "water",
            "boat",
            "building",
            "roof",
            "sky",
            "drone",
        )
    )
)
# Input components: the RGB photo and its greyscale label mask, both as numpy arrays.
input_1 = gr.Image(label="Image (RGB)", type="numpy", image_mode="RGB")
input_2 = gr.Image(label="Segmentation Mask (Greyscale)", type="numpy", image_mode="L")
INPUTS = [input_1, input_2]

# Output components: one overlay for boundaries, one for segmentation regions.
output_1, output_2 = (
    gr.AnnotatedImage(label="Boundary Mask"),
    gr.AnnotatedImage(label="Segmentation Mask"),
)
OUTPUTS = [output_1, output_2]
TITLE = "Semantic Boundary Generation"
DESCRIPTION = "Semantic Boundary Generation based on [paper](https://arxiv.org/pdf/2304.09427.pdf)."
theme = gr.themes.Monochrome(
primary_hue="indigo",
secondary_hue="blue",
neutral_hue="slate",
radius_size=gr.themes.sizes.radius_sm,
font=[
gr.themes.GoogleFont("Open Sans"),
"ui-sans-serif",
"system-ui",
"sans-serif",
],
)
# Absolute directory containing this script; example files live in ./examples.
cur_dir = os.path.dirname(os.path.abspath(__file__))

# Each example is an [image, mask] pair sharing a stem:
# "<stem>.jpg" for the image and "<stem>_mask.png" for the label mask.
EXAMPLES = [
    [f"{cur_dir}/examples/{stem}.jpg", f"{cur_dir}/examples/{stem}_mask.png"]
    for stem in (
        "aeroscape_1",
        "aeroscape_2",
        "floodnet_1",
        "floodnet_2",
        "floodnet_3",
        "floodnet_4",
        "floodnet_5",
        "udd_1",
        "udd_2",
    )
]
# Single-function Gradio interface wiring gen_sbd to the components above.
# cache_examples=True asks Gradio to precompute outputs for EXAMPLES.
demo = gr.Interface(
    fn=gen_sbd,
    inputs=INPUTS,
    outputs=OUTPUTS,
    examples=EXAMPLES,
    cache_examples=True,
    title=TITLE,
    description=DESCRIPTION,
    theme=theme,
    live=False,
    allow_flagging="never",
)

if __name__ == "__main__":
    # Serve the app only when executed as a script, not when imported.
    demo.launch()