# -*- coding: utf-8 -*-
"""Gradio demo: semantic segmentation of aerial imagery with SegFormer-B0.

The script loads the ``alanoix/segformer_b0_flair_one`` checkpoint, applies
dynamic int8 quantization to its linear layers, and serves a Gradio interface
that returns the input image blended with a colour-coded class map plus a
legend.
"""
import os
import numpy as np
# NOTE: ``import glob`` further down rebinds ``glob`` to the module; the
# original import order is kept so the final binding does not change.
from glob import glob
import matplotlib.pyplot as plt
import matplotlib
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
import pandas as pd
import gc
import random
import math
import glob
import torch
import gradio as gr
from PIL import Image
import cv2
from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor
import albumentations as aug

# Class names of the segmentation model; the list index is the class id.
classes = [
    'None', 'building', 'pervious surface', 'impervious surface', 'bare soil',
    'water', 'coniferous', 'deciduous', 'brushwood', 'vineyard',
    'herbaceous vegetation', 'agricultural land', 'plowed land',
]
id2label = dict(enumerate(classes))
print(id2label)
label2id = {v: k for k, v in id2label.items()}
num_labels = len(id2label)

# Side length (pixels) of the square prediction rendered by ``query_image``.
PREDICTION_SIZE = 512

segformer_b0_rgb_model = SegformerForSemanticSegmentation.from_pretrained(
    "alanoix/segformer_b0_flair_one",
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)
# Resizing, rescaling and normalisation are all disabled here: normalisation
# is done by ``test_transform`` below before the extractor is called.
segformer_rgb_feature_extractor = SegformerFeatureExtractor(
    ignore_index=0,
    reduce_labels=False,
    do_resize=False,
    do_rescale=False,
    do_normalize=False,
)
segformer_b0_rgb_model = torch.quantization.quantize_dynamic(
    segformer_b0_rgb_model, {torch.nn.Linear}, dtype=torch.qint8
)

MEAN = np.array([0.44050665, 0.45704361, 0.42254708])
STD = np.array([0.20264351, 0.1782405, 0.17575739])

test_transform = aug.Compose([
    aug.Normalize(mean=MEAN, std=STD),
])

# Dynamically quantized (qint8) Linear layers only have CPU kernels, so the
# model has to stay on the CPU even when a CUDA device is available; moving
# it to the GPU would make the forward pass fail.
device = torch.device("cpu")
segformer_b0_rgb_model = segformer_b0_rgb_model.to(device)

# Random colour palette, regenerated on every start of the script.
class_colors = [
    (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
    for _ in range(5000)
]

# Default IMAGE_ORDERING = channels_last
IMAGE_ORDERING = "channels_last"


def get_colored_segmentation_image(seg_arr, n_classes, colors=class_colors):
    """Turn a 2-D array of class ids into a 3-channel colour image.

    Args:
        seg_arr: 2-D array whose values are class ids.
        n_classes: Ids ``0 .. n_classes - 1`` are coloured; pixels holding
            any other value stay black.
        colors: Sequence of 3-component colours indexed by class id.

    Returns:
        Float array of shape ``(height, width, 3)``.
    """
    height, width = seg_arr.shape[0], seg_arr.shape[1]
    seg_img = np.zeros((height, width, 3))
    for class_id in range(n_classes):
        # Class masks are disjoint, so a plain assignment is sufficient.
        seg_img[seg_arr[:, :] == class_id] = colors[class_id][:3]
    return seg_img


def get_legends(class_names, colors=class_colors):
    """Render a legend image with one 25-pixel row per class.

    Each row shows the class name on a white background followed by a
    filled swatch in the class colour.

    Args:
        class_names: Names to list, in class-id order.
        colors: Sequence of colours indexed by class id.

    Returns:
        ``uint8`` array of shape ``(25 * len(class_names) + 25, 125, 3)``.
    """
    legend = np.zeros(((len(class_names) * 25) + 25, 125, 3), dtype="uint8") + 255
    for row, (class_name, color) in enumerate(zip(class_names, colors)):
        top = row * 25
        # OpenCV expects plain Python ints for colour components.
        swatch_color = tuple(int(component) for component in color)
        cv2.putText(legend, class_name, (5, top + 17),
                    cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1)
        cv2.rectangle(legend, (100, top), (125, top + 25), swatch_color, -1)
    return legend


def overlay_seg_image(inp_img, seg_img):
    """Blend the segmentation image 50/50 with the input image.

    The segmentation image is first resized (nearest neighbour) to the
    height and width of ``inp_img``.

    Returns:
        ``uint8`` array with the spatial size of ``inp_img``.
    """
    original_h, original_w = inp_img.shape[0], inp_img.shape[1]
    seg_img = cv2.resize(seg_img, (original_w, original_h),
                         interpolation=cv2.INTER_NEAREST)
    return (inp_img / 2 + seg_img / 2).astype('uint8')


def concat_lenends(seg_img, legend_img):
    """Place the legend to the left of the segmentation image.

    The canvas is as tall as the taller of the two images and is filled
    with the value of the legend's top-left pixel before both images are
    pasted at the top.

    Returns:
        ``uint8`` array of shape ``(max_height, legend_w + seg_w, 3)``.
    """
    legend_h, legend_w = legend_img.shape[0], legend_img.shape[1]
    canvas_h = np.maximum(seg_img.shape[0], legend_h)
    canvas_w = seg_img.shape[1] + legend_w

    out_img = np.full((canvas_h, canvas_w, 3), legend_img[0, 0, 0], dtype='uint8')
    out_img[:legend_h, :legend_w] = np.copy(legend_img)
    out_img[:seg_img.shape[0], legend_w:] = np.copy(seg_img)
    return out_img


def visualize_segmentation(seg_arr, inp_img=None, n_classes=None,
                           colors=class_colors, class_names=None,
                           overlay_img=False, show_legends=False,
                           prediction_width=None, prediction_height=None):
    """Build a displayable image from a class-id map.

    Args:
        seg_arr: 2-D array of class ids.
        inp_img: Optional source image; when given, the coloured map is
            resized to its size.
        n_classes: Number of classes to colour. Defaults to the largest id
            present in ``seg_arr`` plus one.
        colors: Sequence of colours indexed by class id.
        class_names: Names used for the legend (required when
            ``show_legends`` is true).
        overlay_img: Blend the coloured map with ``inp_img``.
        show_legends: Attach a legend to the left of the result.
        prediction_width: Optional output width; only applied together with
            ``prediction_height``.
        prediction_height: Optional output height; only applied together
            with ``prediction_width``.

    Returns:
        The rendered image as a NumPy array.

    Raises:
        ValueError: If ``overlay_img`` is set without ``inp_img`` or
            ``show_legends`` is set without ``class_names``.
    """
    if n_classes is None:
        # Class ids are zero-based, so the count is the largest id plus one;
        # using the bare maximum would leave the highest class uncoloured.
        n_classes = int(np.max(seg_arr)) + 1

    seg_img = get_colored_segmentation_image(seg_arr, n_classes, colors=colors)

    if inp_img is not None:
        original_h, original_w = inp_img.shape[0], inp_img.shape[1]
        seg_img = cv2.resize(seg_img, (original_w, original_h),
                             interpolation=cv2.INTER_NEAREST)

    if (prediction_height is not None) and (prediction_width is not None):
        target_size = (prediction_width, prediction_height)
        seg_img = cv2.resize(seg_img, target_size,
                             interpolation=cv2.INTER_NEAREST)
        if inp_img is not None:
            inp_img = cv2.resize(inp_img, target_size)

    if overlay_img:
        if inp_img is None:
            raise ValueError("overlay_img=True requires inp_img")
        seg_img = overlay_seg_image(inp_img, seg_img)

    if show_legends:
        if class_names is None:
            raise ValueError("show_legends=True requires class_names")
        legend_img = get_legends(class_names, colors=colors)
        seg_img = concat_lenends(seg_img, legend_img)

    return seg_img


def query_image(img):
    """Segment an image and return the overlay with a legend.

    Args:
        img: Image array as delivered by the ``gr.Image`` input component.

    Returns:
        Image showing the input blended with the predicted class colours,
        rendered at ``PREDICTION_SIZE`` x ``PREDICTION_SIZE`` with a legend
        attached on the left.

    Raises:
        ValueError: If no image was supplied.
    """
    if img is None:
        raise ValueError("No input image provided")

    image_to_pred = test_transform(image=img)['image']
    pixel_values = segformer_rgb_feature_extractor(
        image_to_pred, return_tensors="pt"
    ).pixel_values.to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = segformer_b0_rgb_model(pixel_values=pixel_values)
    logits = outputs.logits.cpu().numpy()

    # Class probabilities along the channel axis, then channels-last so the
    # map can be upsampled with ``tf.image.resize``.
    probabilities = tf.nn.softmax(logits, axis=1)
    probabilities = tf.transpose(probabilities, perm=[0, 2, 3, 1])
    probabilities = tf.image.resize(
        probabilities, size=[PREDICTION_SIZE, PREDICTION_SIZE], method="bilinear"
    )

    result = np.squeeze(np.argmax(probabilities, axis=-1)).astype(np.uint8)

    return visualize_segmentation(
        result,
        img,
        n_classes=num_labels,
        colors=class_colors,
        overlay_img=True,
        show_legends=True,
        class_names=classes,
        prediction_width=PREDICTION_SIZE,
        prediction_height=PREDICTION_SIZE,
    )


demo = gr.Interface(
    query_image,
    inputs=[gr.Image()],
    outputs="image",
    title="Image Segmentation on aerial imagery",
    description="model finetuned on IGN flair-one dataset",
    examples=[
        "IMG_011942.jpeg",
        "IMG_005339.jpeg",
        "IMG_004753.jpeg",
        "IMG_011617.jpeg",
        "IMG_003022.jpeg",
    ],
)

demo.launch()  # debug=True