File size: 6,821 Bytes
a903e67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbf6114
1245f76
1f04f7b
a903e67
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
# -*- coding: utf-8 -*-

import os
import numpy as np
from glob import glob
import matplotlib.pyplot as plt
import matplotlib
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
import pandas as pd
import gc
import random
import math
import glob
import torch
import gradio as gr
from PIL import Image
import cv2


# Land-cover class names; a name's position in this list is its integer label id.
classes = [
    'None',
    'building',
    'pervious surface',
    'impervious surface',
    'bare soil',
    'water',
    'coniferous',
    'deciduous',
    'brushwood',
    'vineyard',
    'herbaceous vegetation',
    'agricultural land',
    'plowed land',
]
# Forward (id -> name) and reverse (name -> id) lookup tables handed to the model config.
id2label = dict(enumerate(classes))
print(id2label)
label2id = {name: idx for idx, name in id2label.items()}
num_labels = len(id2label)

from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor

# Load the "alanoix/segformer_b0_flair_one" checkpoint, configured with the
# label tables defined earlier in this file.
segformer_b0_rgb_model = SegformerForSemanticSegmentation.from_pretrained(
    "alanoix/segformer_b0_flair_one",
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)

# Every preprocessing switch is turned off here; normalisation is done
# separately by `test_transform` before the extractor is called.
segformer_rgb_feature_extractor = SegformerFeatureExtractor(
    ignore_index=0,
    reduce_labels=False,
    do_resize=False,
    do_rescale=False,
    do_normalize=False,
)

# Replace the model's nn.Linear layers with dynamically quantized int8 versions.
segformer_b0_rgb_model = torch.quantization.quantize_dynamic(
    segformer_b0_rgb_model,
    {torch.nn.Linear},
    dtype=torch.qint8,
)


import albumentations as aug
# Per-channel mean / standard deviation used to normalise input images
# (presumably statistics of the training imagery -- TODO confirm the source).
MEAN = np.array([0.44050665, 0.45704361, 0.42254708])
STD = np.array([0.20264351, 0.1782405 , 0.17575739])

# Inference-time preprocessing: normalisation only, no augmentation.
test_transform = aug.Compose([
    aug.Normalize(mean=MEAN, std=STD),
])

# The model was dynamically quantized to int8 earlier in this file, and
# PyTorch's dynamic-quantized Linear kernels exist for CPU only. Picking
# "cuda" whenever it is available would make every forward pass fail on a
# GPU host, so inference is pinned to the CPU.
device = torch.device("cpu")
segformer_b0_rgb_model = segformer_b0_rgb_model.to(device)

# Palette of 5000 random (0-255, 0-255, 0-255) triples; the entry at index i
# is the colour used for class id i. The generator is not seeded here, so the
# palette differs from one process start to the next.
class_colors = [
    tuple(random.randint(0, 255) for _ in range(3))
    for _ in range(5000)
]


# Default IMAGE_ORDERING = channels_last
IMAGE_ORDERING = "channels_last"


def get_colored_segmentation_image(seg_arr, n_classes, colors=class_colors):
    """Turn a 2-D label map into a 3-channel colour image.

    Args:
        seg_arr: 2-D array of integer class ids.
        n_classes: Ids ``0 .. n_classes - 1`` are painted; pixels holding any
            other value stay zero.
        colors: Sequence of 3-value colours indexed by class id.

    Returns:
        A float array of shape ``(height, width, 3)`` in which each pixel
        carries the colour of its class.
    """
    height, width = seg_arr.shape[0], seg_arr.shape[1]
    canvas = np.zeros((height, width, 3))

    for class_id in range(n_classes):
        mask = seg_arr[:, :] == class_id
        # Each pixel matches at most one class id, so adding the per-class
        # contributions never mixes two colours.
        for channel in range(3):
            canvas[:, :, channel] += (mask * colors[class_id][channel]).astype('uint8')

    return canvas


def get_legends(class_names, colors=class_colors):
    """Render a legend image with one 25-pixel-high row per class name.

    Each row holds the class name in black text followed by a filled swatch
    of the class colour occupying columns 100-125.

    Args:
        class_names: Sequence of class name strings, in class-id order.
        colors: Sequence of 3-value colours; entry ``i`` is paired with
            ``class_names[i]``.

    Returns:
        A white-background ``uint8`` image of shape
        ``(len(class_names) * 25 + 25, 125, 3)``.
    """
    row_height = 25
    n_classes = len(class_names)
    legend = np.full((n_classes * row_height + row_height, 125, 3), 255,
                     dtype="uint8")

    for row, (label, rgb) in enumerate(zip(class_names, colors[:n_classes])):
        top = row * row_height
        swatch = tuple(int(value) for value in rgb)
        cv2.putText(legend, label, (5, top + 17),
                    cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1)
        # The swatch is drawn after the text, so it covers any text that
        # reaches past column 100.
        cv2.rectangle(legend, (100, top), (125, top + row_height), swatch, -1)

    return legend


def overlay_seg_image(inp_img, seg_img):
    """Blend a segmentation image over the input image at equal weight.

    Args:
        inp_img: Input image array; its first two dimensions set the output size.
        seg_img: Coloured segmentation image; resized to the input's size with
            nearest-neighbour interpolation before blending.

    Returns:
        The 50/50 blend of both images, cast to ``uint8``.
    """
    target_h, target_w = inp_img.shape[0], inp_img.shape[1]
    resized_seg = cv2.resize(seg_img, (target_w, target_h),
                             interpolation=cv2.INTER_NEAREST)

    blended = inp_img / 2 + resized_seg / 2
    return blended.astype('uint8')


def concat_lenends(seg_img, legend_img):
    """Place the legend to the left of the segmentation image on one canvas.

    Args:
        seg_img: Segmentation (or overlay) image, ``(height, width, 3)``.
        legend_img: Legend image, ``(height, width, 3)``.

    Returns:
        An image as tall as the taller input and as wide as both inputs
        together. Both inputs are top-aligned; leftover area is filled with
        the value of the legend's top-left pixel (first channel).
    """
    legend_h, legend_w = legend_img.shape[0], legend_img.shape[1]
    seg_h, seg_w = seg_img.shape[0], seg_img.shape[1]

    canvas_shape = (np.maximum(seg_h, legend_h), seg_w + legend_w, 3)
    background = legend_img[0, 0, 0]
    canvas = np.zeros(canvas_shape).astype('uint8') + background

    # Legend occupies the left columns, the segmentation everything after it.
    canvas[:legend_h, :legend_w] = np.copy(legend_img)
    canvas[:seg_h, legend_w:] = np.copy(seg_img)

    return canvas


def visualize_segmentation(seg_arr, inp_img=None, n_classes=None,
                           colors=class_colors, class_names=None,
                           overlay_img=False, show_legends=False,
                           prediction_width=None, prediction_height=None):
    """Build a colour visualisation of a label map.

    Args:
        seg_arr: 2-D array of integer class ids.
        inp_img: Optional input image. When given, the coloured segmentation
            is first resized to its height and width.
        n_classes: Number of classes to colour. When ``None`` it is derived
            from the largest id present in ``seg_arr``.
        colors: Sequence of 3-value colours indexed by class id.
        class_names: Class names for the legend; required when
            ``show_legends`` is true.
        overlay_img: Blend the segmentation over ``inp_img`` (which must then
            be provided).
        show_legends: Attach a legend to the left of the result.
        prediction_width: Output width; only used when ``prediction_height``
            is also given.
        prediction_height: Output height; only used when ``prediction_width``
            is also given.

    Returns:
        The visualisation image as an array.

    Raises:
        AssertionError: If ``overlay_img`` is set without ``inp_img``, or
            ``show_legends`` is set without ``class_names``.
    """
    if n_classes is None:
        # Class ids are zero-based, so the class count is one more than the
        # largest id; using the bare maximum would leave the highest class
        # uncoloured.
        n_classes = int(np.max(seg_arr)) + 1

    seg_img = get_colored_segmentation_image(seg_arr, n_classes, colors=colors)

    if inp_img is not None:
        original_h = inp_img.shape[0]
        original_w = inp_img.shape[1]
        # Nearest-neighbour keeps class colours crisp instead of mixing them
        # at region borders.
        seg_img = cv2.resize(seg_img, (original_w, original_h), interpolation=cv2.INTER_NEAREST)

    if (prediction_height is not None) and (prediction_width is not None):
        seg_img = cv2.resize(seg_img, (prediction_width, prediction_height), interpolation=cv2.INTER_NEAREST)
        if inp_img is not None:
            inp_img = cv2.resize(inp_img,
                                 (prediction_width, prediction_height))

    if overlay_img:
        assert inp_img is not None, "overlay_img=True requires inp_img"
        seg_img = overlay_seg_image(inp_img, seg_img)

    if show_legends:
        assert class_names is not None, "show_legends=True requires class_names"
        legend_img = get_legends(class_names, colors=colors)

        seg_img = concat_lenends(seg_img, legend_img)

    return seg_img

def query_image(img):
    """Segment an image and return the overlay with a class legend.

    Args:
        img: Image array supplied by the Gradio image input (presumably an
            H x W x 3 RGB array -- TODO confirm), or ``None`` when the form is
            submitted without an image.

    Returns:
        The array produced by ``visualize_segmentation``: the 512 x 512
        blend of input and predicted classes with the legend attached on the
        left. ``None`` when ``img`` is ``None``.
    """
    if img is None:
        # Nothing to segment; returning None leaves the output empty instead
        # of raising inside the preprocessing transform.
        return None

    image_to_pred = test_transform(image=img)['image']

    pixel_values = segformer_rgb_feature_extractor(image_to_pred, return_tensors="pt").pixel_values.to(device)

    # Inference only: no_grad avoids building an autograd graph for a forward
    # pass whose gradients are never used.
    with torch.no_grad():
        outputs_segformer_b0_rgb = segformer_b0_rgb_model(pixel_values=pixel_values)
    pred_segformer_b0_rgb = outputs_segformer_b0_rgb.logits.cpu().detach().numpy()

    # Softmax over axis 1 (the class axis), then an average over a
    # one-element list of model outputs.
    pred = np.mean(np.array([K.softmax(pred_segformer_b0_rgb, axis=1)]), axis=0)
    # Move the class axis last and upsample the probabilities to 512 x 512
    # before taking the per-pixel arg-max.
    pred = tf.image.resize(tf.transpose(pred, perm=[0, 2, 3, 1]), size=[512, 512], method="bilinear")
    pred = np.argmax(pred, axis=-1)
    pred = np.squeeze(pred)
    result = pred.astype(np.uint8)

    # Use the module-level class list and count so the legend cannot drift
    # from the label tables the model was configured with.
    seg_img = visualize_segmentation(result, img, n_classes=num_labels,
                                     colors=class_colors, overlay_img=True,
                                     show_legends=True,
                                     class_names=classes,
                                     prediction_width=512,
                                     prediction_height=512)

    return seg_img

# Gradio front end: one image in, one image out, served by `query_image`.
demo = gr.Interface(
    fn=query_image,
    inputs=[gr.Image()],
    outputs="image",
    title="Image Segmentation on aerial imagery",
    description="Segformer b0 finetuned on IGN flair-one dataset, more information here : https://github.com/vlarmet/flair_ign_2nd_place ",
    examples=[
        "IMG_011942.jpeg",
        "IMG_005339.jpeg",
        "IMG_004753.jpeg",
        "IMG_011617.jpeg",
        "IMG_003022.jpeg",
    ],
)

demo.launch()  # debug=True