akhaliq HF staff commited on
Commit
7c8b533
·
1 Parent(s): 1f131cb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +384 -0
app.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import argparse
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from collections import OrderedDict
7
+ import torch.nn.functional as F
8
+ from torch.utils import data
9
+ import torchvision.transforms as transform
10
+ from torch.nn.parallel.scatter_gather import gather
11
+ from additional_utils.models import LSeg_MultiEvalModule
12
+ from modules.lseg_module import LSegModule
13
+ import cv2
14
+ import math
15
+ import types
16
+ import functools
17
+ import torchvision.transforms as torch_transforms
18
+ import copy
19
+ import itertools
20
+ from PIL import Image
21
+ import matplotlib.pyplot as plt
22
+ import clip
23
+ from encoding.models.sseg import BaseNet
24
+ import matplotlib as mpl
25
+ import matplotlib.colors as mplc
26
+ import matplotlib.figure as mplfigure
27
+ import matplotlib.patches as mpatches
28
+ from matplotlib.backends.backend_agg import FigureCanvasAgg
29
+ from data import get_dataset
30
+ import torchvision.transforms as transforms
31
+
32
+ import gradio as gr
33
+
34
+ model_name = "convnext_xlarge_in22k"
35
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
36
+ def get_new_pallete(num_cls):
37
+ n = num_cls
38
+ pallete = [0]*(n*3)
39
+ for j in range(0,n):
40
+ lab = j
41
+ pallete[j*3+0] = 0
42
+ pallete[j*3+1] = 0
43
+ pallete[j*3+2] = 0
44
+ i = 0
45
+ while (lab > 0):
46
+ pallete[j*3+0] |= (((lab >> 0) & 1) << (7-i))
47
+ pallete[j*3+1] |= (((lab >> 1) & 1) << (7-i))
48
+ pallete[j*3+2] |= (((lab >> 2) & 1) << (7-i))
49
+ i = i + 1
50
+ lab >>= 3
51
+ return pallete
52
+
53
+ def get_new_mask_pallete(npimg, new_palette, out_label_flag=False, labels=None):
54
+ """Get image color pallete for visualizing masks"""
55
+ # put colormap
56
+ out_img = Image.fromarray(npimg.squeeze().astype('uint8'))
57
+ out_img.putpalette(new_palette)
58
+
59
+ if out_label_flag:
60
+ assert labels is not None
61
+ u_index = np.unique(npimg)
62
+ patches = []
63
+ for i, index in enumerate(u_index):
64
+ label = labels[index]
65
+ cur_color = [new_palette[index * 3] / 255.0, new_palette[index * 3 + 1] / 255.0, new_palette[index * 3 + 2] / 255.0]
66
+ red_patch = mpatches.Patch(color=cur_color, label=label)
67
+ patches.append(red_patch)
68
+ return out_img, patches
69
+
70
+ @st.cache(allow_output_mutation=True)
71
+ def load_model():
72
+ class Options:
73
+ def __init__(self):
74
+ parser = argparse.ArgumentParser(description="PyTorch Segmentation")
75
+ # model and dataset
76
+ parser.add_argument(
77
+ "--model", type=str, default="encnet", help="model name (default: encnet)"
78
+ )
79
+ parser.add_argument(
80
+ "--backbone",
81
+ type=str,
82
+ default="clip_vitl16_384",
83
+ help="backbone name (default: resnet50)",
84
+ )
85
+ parser.add_argument(
86
+ "--dataset",
87
+ type=str,
88
+ default="ade20k",
89
+ help="dataset name (default: pascal12)",
90
+ )
91
+ parser.add_argument(
92
+ "--workers", type=int, default=16, metavar="N", help="dataloader threads"
93
+ )
94
+ parser.add_argument(
95
+ "--base-size", type=int, default=520, help="base image size"
96
+ )
97
+ parser.add_argument(
98
+ "--crop-size", type=int, default=480, help="crop image size"
99
+ )
100
+ parser.add_argument(
101
+ "--train-split",
102
+ type=str,
103
+ default="train",
104
+ help="dataset train split (default: train)",
105
+ )
106
+ parser.add_argument(
107
+ "--aux", action="store_true", default=False, help="Auxilary Loss"
108
+ )
109
+ parser.add_argument(
110
+ "--se-loss",
111
+ action="store_true",
112
+ default=False,
113
+ help="Semantic Encoding Loss SE-loss",
114
+ )
115
+ parser.add_argument(
116
+ "--se-weight", type=float, default=0.2, help="SE-loss weight (default: 0.2)"
117
+ )
118
+ parser.add_argument(
119
+ "--batch-size",
120
+ type=int,
121
+ default=16,
122
+ metavar="N",
123
+ help="input batch size for \
124
+ training (default: auto)",
125
+ )
126
+ parser.add_argument(
127
+ "--test-batch-size",
128
+ type=int,
129
+ default=16,
130
+ metavar="N",
131
+ help="input batch size for \
132
+ testing (default: same as batch size)",
133
+ )
134
+ # cuda, seed and logging
135
+ parser.add_argument(
136
+ "--no-cuda",
137
+ action="store_true",
138
+ default=False,
139
+ help="disables CUDA training",
140
+ )
141
+ parser.add_argument(
142
+ "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
143
+ )
144
+ # checking point
145
+ parser.add_argument(
146
+ "--weights", type=str, default='', help="checkpoint to test"
147
+ )
148
+ # evaluation option
149
+ parser.add_argument(
150
+ "--eval", action="store_true", default=False, help="evaluating mIoU"
151
+ )
152
+ parser.add_argument(
153
+ "--export",
154
+ type=str,
155
+ default=None,
156
+ help="put the path to resuming file if needed",
157
+ )
158
+ parser.add_argument(
159
+ "--acc-bn",
160
+ action="store_true",
161
+ default=False,
162
+ help="Re-accumulate BN statistics",
163
+ )
164
+ parser.add_argument(
165
+ "--test-val",
166
+ action="store_true",
167
+ default=False,
168
+ help="generate masks on val set",
169
+ )
170
+ parser.add_argument(
171
+ "--no-val",
172
+ action="store_true",
173
+ default=False,
174
+ help="skip validation during training",
175
+ )
176
+
177
+ parser.add_argument(
178
+ "--module",
179
+ default='lseg',
180
+ help="select model definition",
181
+ )
182
+
183
+ # test option
184
+ parser.add_argument(
185
+ "--data-path", type=str, default='../datasets/', help="path to test image folder"
186
+ )
187
+
188
+ parser.add_argument(
189
+ "--no-scaleinv",
190
+ dest="scale_inv",
191
+ default=True,
192
+ action="store_false",
193
+ help="turn off scaleinv layers",
194
+ )
195
+
196
+ parser.add_argument(
197
+ "--widehead", default=False, action="store_true", help="wider output head"
198
+ )
199
+
200
+ parser.add_argument(
201
+ "--widehead_hr",
202
+ default=False,
203
+ action="store_true",
204
+ help="wider output head",
205
+ )
206
+ parser.add_argument(
207
+ "--ignore_index",
208
+ type=int,
209
+ default=-1,
210
+ help="numeric value of ignore label in gt",
211
+ )
212
+
213
+ parser.add_argument(
214
+ "--label_src",
215
+ type=str,
216
+ default="default",
217
+ help="how to get the labels",
218
+ )
219
+
220
+ parser.add_argument(
221
+ "--arch_option",
222
+ type=int,
223
+ default=0,
224
+ help="which kind of architecture to be used",
225
+ )
226
+
227
+ parser.add_argument(
228
+ "--block_depth",
229
+ type=int,
230
+ default=0,
231
+ help="how many blocks should be used",
232
+ )
233
+
234
+ parser.add_argument(
235
+ "--activation",
236
+ choices=['lrelu', 'tanh'],
237
+ default="lrelu",
238
+ help="use which activation to activate the block",
239
+ )
240
+
241
+ self.parser = parser
242
+
243
+ def parse(self):
244
+ args = self.parser.parse_args(args=[])
245
+ args.cuda = not args.no_cuda and torch.cuda.is_available()
246
+ print(args)
247
+ return args
248
+
249
+ args = Options().parse()
250
+
251
+ torch.manual_seed(args.seed)
252
+ args.test_batch_size = 1
253
+ alpha=0.5
254
+
255
+ args.scale_inv = False
256
+ args.widehead = True
257
+ args.dataset = 'ade20k'
258
+ args.backbone = 'clip_vitl16_384'
259
+ args.weights = 'checkpoints/demo_e200.ckpt'
260
+ args.ignore_index = 255
261
+
262
+ module = LSegModule.load_from_checkpoint(
263
+ checkpoint_path=args.weights,
264
+ data_path=args.data_path,
265
+ dataset=args.dataset,
266
+ backbone=args.backbone,
267
+ aux=args.aux,
268
+ num_features=256,
269
+ aux_weight=0,
270
+ se_loss=False,
271
+ se_weight=0,
272
+ base_lr=0,
273
+ batch_size=1,
274
+ max_epochs=0,
275
+ ignore_index=args.ignore_index,
276
+ dropout=0.0,
277
+ scale_inv=args.scale_inv,
278
+ augment=False,
279
+ no_batchnorm=False,
280
+ widehead=args.widehead,
281
+ widehead_hr=args.widehead_hr,
282
+ map_locatin="cpu",
283
+ arch_option=0,
284
+ block_depth=0,
285
+ activation='lrelu',
286
+ )
287
+
288
+ input_transform = module.val_transform
289
+
290
+ # dataloader
291
+ loader_kwargs = (
292
+ {"num_workers": args.workers, "pin_memory": True} if args.cuda else {}
293
+ )
294
+
295
+ # model
296
+ if isinstance(module.net, BaseNet):
297
+ model = module.net
298
+ else:
299
+ model = module
300
+
301
+ model = model.eval()
302
+ model = model.cpu()
303
+ scales = (
304
+ [0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25]
305
+ if args.dataset == "citys"
306
+ else [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
307
+ )
308
+
309
+ model.mean = [0.5, 0.5, 0.5]
310
+ model.std = [0.5, 0.5, 0.5]
311
+ evaluator = LSeg_MultiEvalModule(
312
+ model, scales=scales, flip=True
313
+ ).cuda()
314
+ evaluator.eval()
315
+
316
+ transform = transforms.Compose(
317
+ [
318
+ transforms.ToTensor(),
319
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
320
+ transforms.Resize([360,480]),
321
+ ]
322
+ )
323
+
324
+ return evaluator, transform
325
+
326
+ """
327
+ # LSeg Demo
328
+ """
329
+ lseg_model, lseg_transform = load_model()
330
+
331
+ # to be revised
332
+ uploaded_file = gr.inputs.Image(type='pil')
333
+ input_labels = st.text_input("Input labels", value="dog, grass, other")
334
+ gr.outputs.Label(type="confidences",num_top_classes=5)
335
+ st.write("The labels are", input_labels)
336
+
337
+ image = Image.open(uploaded_file)
338
+ pimage = lseg_transform(np.array(image)).unsqueeze(0)
339
+
340
+ labels = []
341
+ for label in input_labels.split(","):
342
+ labels.append(label.strip())
343
+
344
+ with torch.no_grad():
345
+ outputs = lseg_model.parallel_forward(pimage, labels)
346
+
347
+ predicts = [
348
+ torch.max(output, 1)[1].cpu().numpy()
349
+ for output in outputs
350
+ ]
351
+
352
+ image = pimage[0].permute(1,2,0)
353
+ image = image * 0.5 + 0.5
354
+ image = Image.fromarray(np.uint8(255*image)).convert("RGBA")
355
+
356
+ pred = predicts[0]
357
+ new_palette = get_new_pallete(len(labels))
358
+ mask, patches = get_new_mask_pallete(pred, new_palette, out_label_flag=True, labels=labels)
359
+ seg = mask.convert("RGBA")
360
+
361
+ fig = plt.figure()
362
+ plt.subplot(121)
363
+ plt.imshow(image)
364
+ plt.axis('off')
365
+
366
+ plt.subplot(122)
367
+ plt.imshow(seg)
368
+ plt.legend(handles=patches, loc='upper right', bbox_to_anchor=(1.3, 1), prop={'size': 5})
369
+ plt.axis('off')
370
+
371
+ plt.tight_layout()
372
+
373
+ #st.image([image,seg], width=700, caption=["Input image", "Segmentation"])
374
+ st.pyplot(fig)
375
+
376
+ title = "LSeg"
377
+
378
+ description = "Gradio demo for LSeg for semantic segmentation. To use it, simply upload your image, or click one of the examples to load them, then add any label set"
379
+
380
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2201.03546' target='_blank'>Language-driven Semantic Segmentation</a> | <a href='hhttps://github.com/isl-org/lang-seg' target='_blank'>Github Repo</a></p>"
381
+
382
+ examples = ['test.jpeg']
383
+
384
+ gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, analytics_enabled=False, examples=examples).launch(enable_queue=True)