ahmedbrs committed
Commit 2790d30 · verified · 1 Parent(s): 91c3039

Update app.py

Files changed (1)
  1. app.py +510 -510
app.py CHANGED
@@ -1,510 +1,510 @@
1
- import gradio as gr
2
- from PIL import Image
3
- import torch
4
- from torchvision.transforms import InterpolationMode
5
-
6
- BICUBIC = InterpolationMode.BICUBIC
7
- from utils import setup, get_similarity_map, get_noun_phrase, rgb_to_hsv, hsv_to_rgb
8
- from vpt.launch import default_argument_parser
9
- from collections import OrderedDict
10
- import numpy as np
11
- import matplotlib.pyplot as plt
12
- import models
13
- import string
14
- import nltk
15
- nltk.download('punkt')
16
- nltk.download('averaged_perceptron_tagger')
17
- from nltk.tokenize import word_tokenize
18
- import torchvision
19
-
20
- args = default_argument_parser().parse_args()
21
- cfg = setup(args)
22
-
23
- multi_classes = False
24
-
25
- device = "cuda" if torch.cuda.is_available() else "cpu"
26
- Ours, preprocess = models.load("CS-ViT-B/16", device=device, cfg=cfg, train_bool=False)
27
- state_dict = torch.load("sketch_seg_best_miou.pth", map_location=device)
28
-
29
- # Trained on 2 GPUs (DataParallel), so strip the "module." prefix to load on a single device
30
- new_state_dict = OrderedDict()
31
- for k, v in state_dict.items():
32
- name = k[7:] # remove `module.`
33
- new_state_dict[name] = v
34
- Ours.load_state_dict(new_state_dict)
35
- Ours.eval()
36
- print("Model loaded successfully")
37
-
38
-
39
- def run(sketch, caption, threshold, seed=None):  # seed is unused; the default avoids a TypeError since click() passes only 3 inputs
40
- # pick a random palette offset (0-8) for the class colors
41
- color_seed = np.random.randint(0, 9)
42
-
43
- # set the candidate classes here
44
- caption = caption.replace('\n',' ')
45
- translator = str.maketrans('', '', string.punctuation)
46
- caption = caption.translate(translator).lower()
47
- words = word_tokenize(caption)
48
- classes = get_noun_phrase(words)
49
- if len(classes) == 0 or not multi_classes:
50
- classes = [caption]
51
-
52
- # print(classes)
53
-
54
- colors = plt.get_cmap("Set1").colors
55
- classes_colors = colors[color_seed:len(classes)+color_seed]
56
-
57
- sketch2 = sketch['composite']
58
-
59
- # when the drawing tool is used
60
- if sketch2[:,:,0:3].sum() == 0:
61
- temp = sketch2[:,:,3]
62
- # invert it
63
- temp = 255 - temp
64
- sketch2 = np.repeat(temp[:, :, np.newaxis], 3, axis=2)
65
- temp2= np.full_like(temp, 255)
66
- sketch2 = np.dstack((sketch2, temp2))
67
-
68
- sketch2 = np.array(sketch2)
69
- pil_img = Image.fromarray(sketch2).convert('RGB')
70
- sketch_tensor = preprocess(pil_img).unsqueeze(0).to(device)
71
- # torchvision.utils.save_image(sketch_tensor, 'sketch_tensor.png')
72
-
73
- with torch.no_grad():
74
- text_features = models.encode_text_with_prompt_ensemble(Ours, classes, device, no_module=True)
75
- redundant_features = models.encode_text_with_prompt_ensemble(Ours, [""], device, no_module=True)
76
-
77
- num_of_tokens = 3
78
- with torch.no_grad():
79
- sketch_features = Ours.encode_image(sketch_tensor, layers=[12],
80
- text_features=text_features - redundant_features, mode="test").squeeze(0)
81
- sketch_features = sketch_features / sketch_features.norm(dim=1, keepdim=True)
82
- similarity = sketch_features @ (text_features - redundant_features).t()
83
- patches_similarity = similarity[0, num_of_tokens + 1:, :]  # keep only patch tokens (skip CLS and the learned prompt tokens)
84
- pixel_similarity = get_similarity_map(patches_similarity.unsqueeze(0), pil_img.size).cpu()
85
- # visualize_attention_maps_with_tokens(pixel_similarity, classes)
86
- pixel_similarity[pixel_similarity < threshold] = 0
87
- pixel_similarity_array = pixel_similarity.cpu().numpy().transpose(2, 0, 1)
88
-
89
-
90
- # display_segmented_sketch(pixel_similarity_array, sketch2, classes, classes_colors, live=True)
91
-
92
- # Find the class index with the highest similarity for each pixel
93
- class_indices = np.argmax(pixel_similarity_array, axis=0)
94
- # Create an HSV image placeholder
95
- hsv_image = np.zeros(class_indices.shape + (3,)) # Shape (H, W, 3)
96
- hsv_image[..., 2] = 1 # Set Value to 1 for a white base
97
-
98
- # Set the hue and value channels
99
- for i, color in enumerate(classes_colors):
100
- rgb_color = np.array(color).reshape(1, 1, 3)
101
- hsv_color = rgb_to_hsv(rgb_color)
102
- mask = class_indices == i
103
- if i < len(classes): # classes with an assigned color: hue from the palette, saturation/value from similarity
104
- hsv_image[..., 0][mask] = hsv_color[0, 0, 0] # Hue
105
- hsv_image[..., 1][mask] = pixel_similarity_array[i][mask] > 0 # Saturation
106
- hsv_image[..., 2][mask] = pixel_similarity_array[i][mask] # Value
107
- else: # any class without an assigned color falls back to black
108
- hsv_image[..., 0][mask] = 0 # Hue doesn't matter for black
109
- hsv_image[..., 1][mask] = 0 # Saturation set to 0
110
- hsv_image[..., 2][mask] = 0 # Value set to 0, making it black
111
-
112
- mask_tensor_org = sketch2[:,:,0]/255
113
- hsv_image[mask_tensor_org==1] = [0,0,1]
114
-
115
- # Convert the HSV image back to RGB to display and save
116
- rgb_image = hsv_to_rgb(hsv_image)
117
-
118
-
119
- if len(classes) > 1:
120
- # Calculate centroids and render class names
121
- for i, class_name in enumerate(classes):
122
- mask = class_indices == i
123
- if np.any(mask):
124
- y, x = np.nonzero(mask)
125
- centroid_x, centroid_y = np.mean(x), np.mean(y)
126
- plt.text(centroid_x, centroid_y, class_name, color=classes_colors[i], ha='center', va='center', fontsize=10,
127
- bbox=dict(facecolor='lightgrey', edgecolor='none', boxstyle='round,pad=0.2', alpha=0.8))
128
-
129
- # Display the image with class names
130
- plt.imshow(rgb_image)
131
- plt.axis('off')
132
- plt.tight_layout()
133
- # plt.savefig(f'poster_vis/{classes[0]}.png', bbox_inches='tight', pad_inches=0)
134
- plt.savefig('output.png', bbox_inches='tight', pad_inches=0)
135
- plt.close()
136
-
137
- # rgb_image = Image.open(f'poster_vis/{classes[0]}.png')
138
- rgb_image = Image.open('output.png')
139
-
140
- return rgb_image
141
-
142
-
143
-
144
- scripts = """
145
- async () => {
146
- // START gallery format
147
- // Get all image elements with the class "image_gallery"
148
- var images = document.querySelectorAll('.image_gallery');
149
- var originalParent = document.querySelector('#component-0');
150
- // Create a new parent div element
151
- var parentDiv = document.createElement('div');
152
- var beforeDiv= document.querySelector('.table-wrap').parentElement;
153
- parentDiv.id = "gallery_container";
154
-
155
- // Loop through each image, append it to the parent div, and remove it from its original parent
156
- images.forEach(function(image , index ) {
157
- // Append the image to the parent div
158
- parentDiv.appendChild(image);
159
-
160
- // Add click event listener to each image
161
- image.addEventListener('click', function() {
162
- let nth_ch = index+1
163
- document.querySelector('.tr-body:nth-child(' + nth_ch + ')').click()
164
- console.log('.tr-body:nth-child(' + nth_ch + ')');
165
- });
166
-
167
- // (appendChild above already moved the image out of its original parent)
168
- });
169
-
170
-
171
- // Get a reference to the original parent of the images
172
- var originalParent = document.querySelector('#component-0');
173
-
174
- // Append the new parent div to the original parent
175
- originalParent.insertBefore(parentDiv, beforeDiv);
176
-
177
- // END gallery format
178
-
179
- // START confidence span
180
-
181
- // Grab the span inside the confidence slider's label
182
- var selectedDiv = document.querySelector("label[for='range_id_0'] > span")
183
-
184
- // Get the text content of the div
185
- var textContent = selectedDiv.textContent;
186
-
187
- // Find the text before the first colon ':'
188
- var colonIndex = textContent.indexOf(':');
189
- var textBeforeColon = textContent.substring(0, colonIndex);
190
-
191
- // Wrap the text before colon with a span element
192
- var spanElement = document.createElement('span');
193
- spanElement.textContent = textBeforeColon;
194
-
195
- // Replace the original text with the modified text containing the span
196
- selectedDiv.innerHTML = textContent.replace(textBeforeColon, spanElement.outerHTML);
197
-
198
- // START format the column names :
199
- // Get all header cells of the examples table ('.tr-head > th')
200
- var elements = document.querySelectorAll('.tr-head > th');
201
-
202
- // Iterate over each element
203
- elements.forEach(function(element) {
204
- // Get the text content of the element
205
- var text = element.textContent.trim();
206
-
207
- // Remove ":" from the text
208
- var wordWithoutColon = text.replace(':', '');
209
-
210
- // Split the text into words
211
- var words = wordWithoutColon.split(' ');
212
-
213
- // Keep only the first word
214
- var firstWord = words[0];
215
-
216
- // Set the text content of the element to the first word
217
- element.textContent = firstWord;
218
- });
219
-
220
- document.querySelector('input[type=number]').disabled = true;
221
- }
222
- """
223
-
224
- css="""
225
-
226
- gradio-app {
227
- background-color: white !important;
228
- }
229
-
230
- .white-bg {
231
- background-color: white !important;
232
- }
233
-
234
- .gray-border {
235
- border: 1px solid dimgrey !important;
236
- }
237
-
238
- .border-radius {
239
- border-radius: 8px !important;
240
- }
241
-
242
- .black-text {
243
- color : black !important;
244
- }
245
-
246
- th {
247
- color : black !important;
248
-
249
- }
250
-
251
- tr {
252
- background-color: white !important;
253
- color: black !important;
254
- }
255
-
256
- td {
257
- border-bottom : 1px solid black !important;
258
- }
259
-
260
- label[data-testid="block-label"] {
261
- background: white;
262
- color: black;
263
- font-weight: bold;
264
- }
265
-
266
- .controls-wrap button:disabled {
267
- color: gray !important;
268
- background-color: white !important;
269
- }
270
-
271
- .controls-wrap button:not(:disabled) {
272
- color: black !important;
273
- background-color: white !important;
274
-
275
- }
276
-
277
- .source-wrap button {
278
- color: black !important;
279
- }
280
-
281
- .toolbar-wrap button {
282
- color: black !important;
283
- }
284
-
285
- .empty.wrap {
286
- color: black !important;
287
- }
288
-
289
-
290
- textarea {
291
- background-color : #f7f9f8 !important;
292
- color : #afb0b1 !important
293
- }
294
-
295
-
296
- input[data-testid="number-input"] {
297
- background-color : #f7f9f8 !important;
298
- color : black !important
299
- }
300
-
301
- tr > th {
302
- border-bottom : 1px solid black !important;
303
- }
304
-
305
- tr:hover {
306
- background: #f7f9f8 !important;
307
- }
308
-
309
- #component-19{
310
- justify-content: center !important;
311
- }
312
-
313
- #component-19 > button {
314
- flex: none !important;
315
- background-color : black !important;
316
- font-weight: bold !important;
317
-
318
- }
319
-
320
- .bold {
321
- font-weight: bold !important;
322
- }
323
-
324
- span[data-testid="block-info"]{
325
- color: black !important;
326
- font-weight: bold !important;
327
- }
328
-
329
- #component-14 > div {
330
- background-color : white !important;
331
-
332
- }
333
-
334
- button[aria-label="Clear"] {
335
- background-color : white !important;
336
- color: black !important;
337
-
338
- }
339
-
340
- #gallery_container {
341
- display: flex;
342
- flex-wrap: wrap;
343
- justify-content: start;
344
- }
345
-
346
- .image_gallery {
347
- margin-bottom: 1rem;
348
- margin-right: 1rem;
349
- }
350
-
351
- label[for='range_id_0'] > span > span {
352
- text-decoration: underline;
353
- }
354
-
355
- label[for='range_id_0'] > span > span {
356
- font-size: medium !important;
357
- }
358
-
359
- .underline {
360
- text-decoration: underline;
361
- }
362
-
363
-
364
- .mt-mb-1{
365
- margin-top: 1rem;
366
- margin-bottom: 1rem;
367
- }
368
-
369
- #gallery_container + div {
370
- visibility: hidden;
371
- height: 10px;
372
- }
373
-
374
- input[type=number][disabled] {
375
- background-color: rgb(247, 249, 248) !important;
376
- color: black !important;
377
- -webkit-text-fill-color: black !important;
378
- }
379
-
380
- #component-13 {
381
- display: flex;
382
- flex-direction: column;
383
- align-items: center;
384
- }
385
-
386
- """
387
-
388
-
389
- with gr.Blocks(js=scripts, css=css, theme='gstaff/xkcd') as demo:
390
- gr.HTML("<h1 class='black-text' style='text-align: center;'>Open Vocabulary Scene Sketch Semantic Understanding</div>")
391
- gr.HTML("<div class='black-text'></div>")
392
- # gr.HTML("<div class='black-text' style='text-align: center;'><a href='https://ahmedbourouis.github.io/ahmed-bourouis/'>Ahmed Bourouis</a>,<a href='https://profiles.stanford.edu/judith-fan'>Judith Ellen Fan</a>, <a href='https://yulia.gryaditskaya.com/'>Yulia Gryaditskaya</a></div>")
393
- gr.HTML("<div class='black-text' style='text-align: center;'>Ahmed Bourouis, Judith Ellen Fan, Yulia Gryaditskaya</div>")
394
- gr.HTML("<div class='black-text' style='text-align: center;' >CVPR, 2024</p>")
395
- gr.HTML("<div style='text-align: center;'><p><a href='https://ahmedbourouis.github.io/Scene_Sketch_Segmentation/'>Project page</a></p></div>")
396
-
397
-
398
- # gr.Markdown( "Scene Sketch Semantic Segmentation.", elem_classes=["black-txt" , "h1"] )
399
- # gr.Markdown( "Open Vocabulary Scene Sketch Semantic Understanding", elem_classes=["black-txt" , "p"] )
400
- # gr.Markdown( "Open Vocabulary Scene Sketch Semantic Understanding", elem_classes=["black-txt" , "p"] )
401
- # gr.Markdown( "")
402
-
403
-
404
- with gr.Row():
405
- with gr.Column():
406
- # in_image = gr.Image( label="Sketch", type="pil", sources="upload" , height=512 )
407
- in_canvas_image = gr.Sketchpad(
408
- # value=Image.new('RGB', (512, 512), color=(255, 255, 255)),
409
- brush=gr.Brush(colors=["#000000"], color_mode="fixed" , default_size=2),
410
- image_mode="RGBA",elem_classes=["white-bg", "gray-border" , "border-radius" ,"own-shadow" ] ,
411
- label="Sketch" , canvas_size=(512,512) ,sources=['upload'],
412
- interactive=True , layers= False, transforms=[]
413
- )
414
- query_selector = 'button[aria-label="Upload button"]'
415
-
416
- # with gr.Row():
417
- # segment_btn.click(fn=run, inputs=[in_image, in_textbox, in_slider], outputs=[out_image])
418
- upload_draw_btn = gr.HTML(f"""
419
- <div id="upload_draw_group" class="svelte-15lo0d8 stretch">
420
- <button class="sm black-text white-bg gray-border border-radius own-shadow svelte-cmf5ev bold" id="upload_btn" onclick="return document.querySelector('.source-wrap button').click()"> Upload a new sketch</button>
421
- <button class="sm black-text white-bg gray-border border-radius own-shadow svelte-cmf5ev bold" id="draw_btn" onclick="return document.querySelector('.controls-wrap button:nth-child(3)').click()"> Draw a new sketch</button>
422
- </div>
423
- """)
424
-
425
- # in_textbox = gr.Textbox( lines=2, elem_classes=["white-bg", "gray-border" , "border-radius" ,"own-shadow" ] ,label="Caption your Sketch!", placeholder="Include the categories that you want the AI to segment. \n e.g. 'giraffe, clouds' or 'a boy flying a kite' ")
426
-
427
- with gr.Column():
428
- out_image = gr.Image( value=Image.new('RGB', (512, 512), color=(255, 255, 255)),
429
- elem_classes=["white-bg", "gray-border" , "border-radius" ,"own-shadow" ] ,
430
- type="pil", label="Segmented Sketch" ) #, height=512, width=512)
431
-
432
- # # gr.HTML("<h3 class='black-text'> <span class='black-text underline'>Confidence:</span> Adjust AI agent confidence in guessing categories </div>")
433
- # in_slider = gr.Slider(elem_classes=["white-bg", "gray-border" , "border-radius" ,"own-shadow" ] ,
434
- # info="Adjust AI agent confidence in guessing categories",
435
- # label="Confidence:",
436
- # value=0.5 , interactive=True, step=0.05, minimum=0, maximum=1)
437
-
438
- with gr.Row():
439
- with gr.Column():
440
- in_textbox = gr.Textbox( lines=2, elem_classes=["white-bg", "gray-border" , "border-radius" ,"own-shadow" ] ,label="Caption your Sketch!", placeholder="Include the categories that you want the AI to segment. \n e.g. 'giraffe, clouds' or 'a boy flying a kite' ")
441
-
442
- with gr.Column():
443
- # gr.HTML("<h3 class='black-text'> <span class='black-text underline'>Confidence:</span> Adjust AI agent confidence in guessing categories </div>")
444
- in_slider = gr.Slider(elem_classes=["white-bg", "gray-border" , "border-radius" ,"own-shadow" ] ,
445
- info="Adjust AI agent confidence in guessing categories",
446
- label="Confidence:",
447
- value=0.5 , interactive=True, step=0.05, minimum=0, maximum=1)
448
-
449
- with gr.Row():
450
- segment_btn = gr.Button( 'Segment it* !' , elem_classes=["white-bg", "gray-border" , "border-radius" ,"own-shadow" , 'bold' , 'mt-mb-1' ] , size="sm")
451
- segment_btn.click(fn=run, inputs=[in_canvas_image , in_textbox , in_slider ], outputs=[out_image])
452
- gallery_label = gr.HTML("<h3 class='black-text'> <span class='black-text underline'>Gallery:</span> <span style='color: grey;'>you can click on any of the example sketches below to start segmenting them (or even drawing over them)</span> </h3>")
453
-
454
- gallery= gr.HTML(f"""
455
- <div>
456
- {gr.Image( elem_classes=["image_gallery"] , label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/sketch_1.png', height=200, width=200)}
457
- {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/sketch_2.png', height=200, width=200)}
458
- {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/sketch_3.png', height=200, width=200)}
459
- {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000004068.png', height=200, width=200)}
460
- {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000004546.png', height=200, width=200)}
461
- {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000005076.png', height=200, width=200)}
462
- {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000006336.png', height=200, width=200)}
463
- {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000011766.png', height=200, width=200)}
464
- {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000024458.png', height=200, width=200)}
465
- {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000024931.png', height=200, width=200)}
466
- {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000034214.png', height=200, width=200)}
467
- {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000260974.png', height=200, width=200)}
468
- {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000268340.png', height=200, width=200)}
469
- {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000305414.png', height=200, width=200)}
470
- {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000484246.png', height=200, width=200)}
471
- {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000549338.png', height=200, width=200)}
472
- {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000038116.png', height=200, width=200)}
473
- {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000221509.png', height=200, width=200)}
474
- {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000246066.png', height=200, width=200)}
475
- {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000001611.png', height=200, width=200)}
476
- </div>
477
- """)
478
-
479
- examples = gr.Examples(
480
- examples_per_page=30,
481
- examples=[
482
- ['demo/sketch_1.png', 'giraffe looking at you', 0.6],
483
- ['demo/sketch_2.png', 'a kite flying in the sky', 0.6],
484
- ['demo/sketch_3.png', 'a girl playing', 0.6],
485
- ['demo/000000004068.png', 'car going so fast', 0.6],
486
- ['demo/000000004546.png', 'mountains in the background', 0.6],
487
- ['demo/000000005076.png', 'huge tree', 0.6],
488
- ['demo/000000006336.png', 'nice three sheeps', 0.6],
489
- ['demo/000000011766.png', 'bird minding its own business', 0.6],
490
- ['demo/000000024458.png', 'horse with a mask on', 0.6],
491
- ['demo/000000024931.png', 'some random person', 0.6],
492
- ['demo/000000034214.png', 'a cool kid on a skateboard', 0.6],
493
- ['demo/000000260974.png', 'the chair on the left', 0.6],
494
- ['demo/000000268340.png', 'stop sign', 0.6],
495
- ['demo/000000305414.png', 'a lonely elephant roaming around', 0.6],
496
- ['demo/000000484246.png', 'giraffe with a loong neck', 0.6],
497
- ['demo/000000549338.png', 'two donkeys trying to be smart', 0.6],
498
- ['demo/000000038116.png', 'a bat next to a kid', 0.6],
499
- ['demo/000000221509.png', 'funny looking cow', 0.6],
500
- ['demo/000000246066.png', 'bench in the park', 0.6],
501
- ['demo/000000001611.png', 'trees in the background', 0.6]
502
- ],
503
- inputs=[in_canvas_image, in_textbox , in_slider],
504
- fn=run,
505
- # cache_examples=True,
506
- )
507
-
508
- gr.HTML("<h5 class='black-text' style='text-align: left;'>*For optimal performance, use a commercial Nvidia RTX 3090 (this demo runs on a basic 2 vCPU).</h5>")
509
- gr.HTML("<h5 class='black-text' style='text-align: left;'>*We compare the entire caption to the scene sketch and threshold most similar pixels, without extracting individual classes.</h5>")
510
- demo.launch(share=False)
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import torch
4
+ from torchvision.transforms import InterpolationMode
5
+
6
+ BICUBIC = InterpolationMode.BICUBIC
7
+ from utils import setup, get_similarity_map, get_noun_phrase, rgb_to_hsv, hsv_to_rgb
8
+ from vpt.launch import default_argument_parser
9
+ from collections import OrderedDict
10
+ import numpy as np
11
+ import matplotlib.pyplot as plt
12
+ import models
13
+ import string
14
+ import nltk
15
+ nltk.download('punkt')
16
+ nltk.download('averaged_perceptron_tagger')
17
+ from nltk.tokenize import word_tokenize
18
+ import torchvision
19
+
20
+ args = default_argument_parser().parse_args()
21
+ cfg = setup(args)
22
+
23
+ multi_classes = False
24
+
25
+ device = "cuda" if torch.cuda.is_available() else "cpu"
26
+ Ours, preprocess = models.load("CS-ViT-B/16", device=device, cfg=cfg, train_bool=False)
27
+ state_dict = torch.load("sketch_seg_best_miou.pth", map_location=device)
28
+
29
+ # Trained on 2 GPUs (DataParallel), so strip the "module." prefix to load on a single device
30
+ new_state_dict = OrderedDict()
31
+ for k, v in state_dict.items():
32
+ name = k[7:] # remove `module.`
33
+ new_state_dict[name] = v
34
+ Ours.load_state_dict(new_state_dict)
35
+ Ours.eval()
36
+ print("Model loaded successfully")
37
+
38
+
39
+ def run(sketch, caption, threshold, seed=None):  # seed is unused; the default avoids a TypeError since click() passes only 3 inputs
40
+ # pick a random palette offset (0-8) for the class colors
41
+ color_seed = np.random.randint(0, 9)
42
+
43
+ # set the candidate classes here
44
+ caption = caption.replace('\n',' ')
45
+ translator = str.maketrans('', '', string.punctuation)
46
+ caption = caption.translate(translator).lower()
47
+ words = word_tokenize(caption)
48
+ classes = get_noun_phrase(words)
49
+ if len(classes) == 0 or not multi_classes:
50
+ classes = [caption]
51
+
52
+ # print(classes)
53
+
54
+ colors = plt.get_cmap("Set1").colors
55
+ classes_colors = colors[color_seed:len(classes)+color_seed]
56
+
57
+ sketch2 = sketch['composite']
58
+
59
+ # when the drawing tool is used
60
+ if sketch2[:,:,0:3].sum() == 0:
61
+ temp = sketch2[:,:,3]
62
+ # invert it
63
+ temp = 255 - temp
64
+ sketch2 = np.repeat(temp[:, :, np.newaxis], 3, axis=2)
65
+ temp2= np.full_like(temp, 255)
66
+ sketch2 = np.dstack((sketch2, temp2))
67
+
68
+ sketch2 = np.array(sketch2)
69
+ pil_img = Image.fromarray(sketch2).convert('RGB')
70
+ sketch_tensor = preprocess(pil_img).unsqueeze(0).to(device)
71
+ # torchvision.utils.save_image(sketch_tensor, 'sketch_tensor.png')
72
+
73
+ with torch.no_grad():
74
+ text_features = models.encode_text_with_prompt_ensemble(Ours, classes, device, no_module=True)
75
+ redundant_features = models.encode_text_with_prompt_ensemble(Ours, [""], device, no_module=True)
76
+
77
+ num_of_tokens = 3
78
+ with torch.no_grad():
79
+ sketch_features = Ours.encode_image(sketch_tensor, layers=[12],
80
+ text_features=text_features - redundant_features, mode="test").squeeze(0)
81
+ sketch_features = sketch_features / sketch_features.norm(dim=1, keepdim=True)
82
+ similarity = sketch_features @ (text_features - redundant_features).t()
83
+ patches_similarity = similarity[0, num_of_tokens + 1:, :]  # keep only patch tokens (skip CLS and the learned prompt tokens)
84
+ pixel_similarity = get_similarity_map(patches_similarity.unsqueeze(0), pil_img.size).cpu()
85
+ # visualize_attention_maps_with_tokens(pixel_similarity, classes)
86
+ pixel_similarity[pixel_similarity < threshold] = 0
87
+ pixel_similarity_array = pixel_similarity.cpu().numpy().transpose(2, 0, 1)
88
+
89
+
90
+ # display_segmented_sketch(pixel_similarity_array, sketch2, classes, classes_colors, live=True)
91
+
92
+ # Find the class index with the highest similarity for each pixel
93
+ class_indices = np.argmax(pixel_similarity_array, axis=0)
94
+ # Create an HSV image placeholder
95
+ hsv_image = np.zeros(class_indices.shape + (3,)) # Shape (H, W, 3)
96
+ hsv_image[..., 2] = 1 # Set Value to 1 for a white base
97
+
98
+ # Set the hue and value channels
99
+ for i, color in enumerate(classes_colors):
100
+ rgb_color = np.array(color).reshape(1, 1, 3)
101
+ hsv_color = rgb_to_hsv(rgb_color)
102
+ mask = class_indices == i
103
+ if i < len(classes): # classes with an assigned color: hue from the palette, saturation/value from similarity
104
+ hsv_image[..., 0][mask] = hsv_color[0, 0, 0] # Hue
105
+ hsv_image[..., 1][mask] = pixel_similarity_array[i][mask] > 0 # Saturation
106
+ hsv_image[..., 2][mask] = pixel_similarity_array[i][mask] # Value
107
+ else: # any class without an assigned color falls back to black
108
+ hsv_image[..., 0][mask] = 0 # Hue doesn't matter for black
109
+ hsv_image[..., 1][mask] = 0 # Saturation set to 0
110
+ hsv_image[..., 2][mask] = 0 # Value set to 0, making it black
111
+
112
+ mask_tensor_org = sketch2[:,:,0]/255
113
+ hsv_image[mask_tensor_org==1] = [0,0,1]
114
+
115
+ # Convert the HSV image back to RGB to display and save
116
+ rgb_image = hsv_to_rgb(hsv_image)
117
+
118
+
119
+ if len(classes) > 1:
120
+ # Calculate centroids and render class names
121
+ for i, class_name in enumerate(classes):
122
+ mask = class_indices == i
123
+ if np.any(mask):
124
+ y, x = np.nonzero(mask)
125
+ centroid_x, centroid_y = np.mean(x), np.mean(y)
126
+ plt.text(centroid_x, centroid_y, class_name, color=classes_colors[i], ha='center', va='center', fontsize=10,
127
+ bbox=dict(facecolor='lightgrey', edgecolor='none', boxstyle='round,pad=0.2', alpha=0.8))
128
+
129
+ # Display the image with class names
130
+ plt.imshow(rgb_image)
131
+ plt.axis('off')
132
+ plt.tight_layout()
133
+ # plt.savefig(f'poster_vis/{classes[0]}.png', bbox_inches='tight', pad_inches=0)
134
+ plt.savefig('output.png', bbox_inches='tight', pad_inches=0)
135
+ plt.close()
136
+
137
+ # rgb_image = Image.open(f'poster_vis/{classes[0]}.png')
138
+ rgb_image = Image.open('output.png')
139
+
140
+ return rgb_image
141
+
142
+
143
+
144
+ scripts = """
145
+ async () => {
146
+ // START gallery format
147
+ // Get all image elements with the class "image_gallery"
148
+ var images = document.querySelectorAll('.image_gallery');
149
+ var originalParent = document.querySelector('#component-0');
150
+ // Create a new parent div element
151
+ var parentDiv = document.createElement('div');
152
+ var beforeDiv= document.querySelector('.table-wrap').parentElement;
153
+ parentDiv.id = "gallery_container";
154
+
155
+ // Loop through each image, append it to the parent div, and remove it from its original parent
156
+ images.forEach(function(image , index ) {
157
+ // Append the image to the parent div
158
+ parentDiv.appendChild(image);
159
+
160
+ // Add click event listener to each image
161
+ image.addEventListener('click', function() {
162
+ let nth_ch = index+1
163
+ document.querySelector('.tr-body:nth-child(' + nth_ch + ')').click()
164
+ console.log('.tr-body:nth-child(' + nth_ch + ')');
165
+ });
166
+
167
+ // (appendChild above already moved the image out of its original parent)
168
+ });
169
+
170
+
171
+ // Get a reference to the original parent of the images
172
+ var originalParent = document.querySelector('#component-0');
173
+
174
+ // Append the new parent div to the original parent
175
+ originalParent.insertBefore(parentDiv, beforeDiv);
176
+
177
+ // END gallery format
178
+
179
+ // START confidence span
180
+
181
+ // Grab the span inside the confidence slider's label
182
+ var selectedDiv = document.querySelector("label[for='range_id_0'] > span")
183
+
184
+ // Get the text content of the div
185
+ var textContent = selectedDiv.textContent;
186
+
187
+ // Find the text before the first colon ':'
188
+ var colonIndex = textContent.indexOf(':');
189
+ var textBeforeColon = textContent.substring(0, colonIndex);
190
+
191
+ // Wrap the text before colon with a span element
192
+ var spanElement = document.createElement('span');
193
+ spanElement.textContent = textBeforeColon;
194
+
195
+ // Replace the original text with the modified text containing the span
196
+ selectedDiv.innerHTML = textContent.replace(textBeforeColon, spanElement.outerHTML);
197
+
198
+ // START format the column names :
199
+ // Get all header cells of the examples table ('.tr-head > th')
200
+ var elements = document.querySelectorAll('.tr-head > th');
201
+
202
+ // Iterate over each element
203
+ elements.forEach(function(element) {
204
+ // Get the text content of the element
205
+ var text = element.textContent.trim();
206
+
207
+ // Remove ":" from the text
208
+ var wordWithoutColon = text.replace(':', '');
209
+
210
+ // Split the text into words
211
+ var words = wordWithoutColon.split(' ');
212
+
213
+ // Keep only the first word
214
+ var firstWord = words[0];
215
+
216
+ // Set the text content of the element to the first word
217
+ element.textContent = firstWord;
218
+ });
219
+
220
+ document.querySelector('input[type=number]').disabled = true;
221
+ }
222
+ """
223
+
224
+ css="""
225
+
226
+ gradio-app {
227
+ background-color: white !important;
228
+ }
229
+
230
+ .white-bg {
231
+ background-color: white !important;
232
+ }
233
+
234
+ .gray-border {
235
+ border: 1px solid dimgrey !important;
236
+ }
237
+
238
+ .border-radius {
239
+ border-radius: 8px !important;
240
+ }
241
+
242
+ .black-text {
243
+ color : black !important;
244
+ }
245
+
246
+ th {
247
+ color : black !important;
248
+
249
+ }
250
+
251
+ tr {
252
+ background-color: white !important;
253
+ color: black !important;
254
+ }
255
+
256
+ td {
257
+ border-bottom : 1px solid black !important;
258
+ }
259
+
260
+ label[data-testid="block-label"] {
261
+ background: white;
262
+ color: black;
263
+ font-weight: bold;
264
+ }
265
+
266
+ .controls-wrap button:disabled {
267
+ color: gray !important;
268
+ background-color: white !important;
269
+ }
270
+
271
+ .controls-wrap button:not(:disabled) {
272
+ color: black !important;
273
+ background-color: white !important;
274
+
275
+ }
276
+
277
+ .source-wrap button {
278
+ color: black !important;
279
+ }
280
+
281
+ .toolbar-wrap button {
282
+ color: black !important;
283
+ }
284
+
285
+ .empty.wrap {
286
+ color: black !important;
287
+ }
288
+
289
+
290
+ textarea {
291
+ background-color : #f7f9f8 !important;
292
+ color : #afb0b1 !important
293
+ }
294
+
295
+
296
+ input[data-testid="number-input"] {
297
+ background-color : #f7f9f8 !important;
298
+ color : black !important
299
+ }
300
+
301
+ tr > th {
302
+ border-bottom : 1px solid black !important;
303
+ }
304
+
305
+ tr:hover {
306
+ background: #f7f9f8 !important;
307
+ }
308
+
309
+ #component-19{
310
+ justify-content: center !important;
311
+ }
312
+
313
+ #component-19 > button {
314
+ flex: none !important;
315
+ background-color : black !important;
316
+ font-weight: bold !important;
317
+
318
+ }
319
+
320
+ .bold {
321
+ font-weight: bold !important;
322
+ }
323
+
324
+ span[data-testid="block-info"]{
325
+ color: black !important;
326
+ font-weight: bold !important;
327
+ }
328
+
329
+ #component-14 > div {
330
+ background-color : white !important;
331
+
332
+ }
333
+
334
+ button[aria-label="Clear"] {
335
+ background-color : white !important;
336
+ color: black !important;
337
+
338
+ }
339
+
340
+ #gallery_container {
341
+ display: flex;
342
+ flex-wrap: wrap;
343
+ justify-content: start;
344
+ }
345
+
346
+ .image_gallery {
347
+ margin-bottom: 1rem;
348
+ margin-right: 1rem;
349
+ }
350
+
351
+ label[for='range_id_0'] > span > span {
352
+ text-decoration: underline;
353
+ }
354
+
355
+ label[for='range_id_0'] > span > span {
356
+ font-size: medium !important;
357
+ }
358
+
359
+ .underline {
360
+ text-decoration: underline;
361
+ }
362
+
363
+
364
+ .mt-mb-1{
365
+ margin-top: 1rem;
366
+ margin-bottom: 1rem;
367
+ }
368
+
369
+ #gallery_container + div {
370
+ visibility: hidden;
371
+ height: 10px;
372
+ }
373
+
374
+ input[type=number][disabled] {
375
+ background-color: rgb(247, 249, 248) !important;
376
+ color: black !important;
377
+ -webkit-text-fill-color: black !important;
378
+ }
379
+
380
+ #component-13 {
381
+ display: flex;
382
+ flex-direction: column;
383
+ align-items: center;
384
+ }
385
+
386
+ """
387
+
388
+
389
+ with gr.Blocks(js=scripts, css=css, theme='gstaff/xkcd') as demo:
390
+ gr.HTML("<h1 class='black-text' style='text-align: center;'>Open Vocabulary Scene Sketch Semantic Understanding</div>")
391
+ gr.HTML("<div class='black-text'></div>")
392
+ # gr.HTML("<div class='black-text' style='text-align: center;'><a href='https://ahmedbourouis.github.io/ahmed-bourouis/'>Ahmed Bourouis</a>,<a href='https://profiles.stanford.edu/judith-fan'>Judith Ellen Fan</a>, <a href='https://yulia.gryaditskaya.com/'>Yulia Gryaditskaya</a></div>")
393
+ gr.HTML("<div class='black-text' style='text-align: center;'>Ahmed Bourouis, Judith Ellen Fan, Yulia Gryaditskaya</div>")
394
+ gr.HTML("<div class='black-text' style='text-align: center;' >CVPR, 2024</p>")
395
+ gr.HTML("<div style='text-align: center;'><p><a href='https://ahmedbourouis.github.io/Scene_Sketch_Segmentation/'>Project page</a></p></div>")
396
+
397
+
398
+ # gr.Markdown( "Scene Sketch Semantic Segmentation.", elem_classes=["black-txt" , "h1"] )
399
+ # gr.Markdown( "Open Vocabulary Scene Sketch Semantic Understanding", elem_classes=["black-txt" , "p"] )
400
+ # gr.Markdown( "Open Vocabulary Scene Sketch Semantic Understanding", elem_classes=["black-txt" , "p"] )
401
+ # gr.Markdown( "")
402
+
403
+
404
+ with gr.Row():
405
+ with gr.Column():
406
+ # in_image = gr.Image( label="Sketch", type="pil", sources="upload" , height=512 )
407
+ in_canvas_image = gr.Sketchpad(
408
+ # value=Image.new('RGB', (512, 512), color=(255, 255, 255)),
409
+ brush=gr.Brush(colors=["#000000"], color_mode="fixed" , default_size=2),
410
+ image_mode="RGBA",elem_classes=["white-bg", "gray-border" , "border-radius" ,"own-shadow" ] ,
411
+ label="Sketch" , canvas_size=(512,512) ,sources=['upload'],
412
+ interactive=True , layers= False, transforms=[]
413
+ )
414
+ query_selector = 'button[aria-label="Upload button"]'
415
+
416
+ # with gr.Row():
417
+ # segment_btn.click(fn=run, inputs=[in_image, in_textbox, in_slider], outputs=[out_image])
418
+ upload_draw_btn = gr.HTML(f"""
419
+ <div id="upload_draw_group" class="svelte-15lo0d8 stretch">
420
+ <button class="sm black-text white-bg gray-border border-radius own-shadow svelte-cmf5ev bold" id="upload_btn" onclick="return document.querySelector('.source-wrap button').click()"> Upload a new sketch</button>
421
+ <button class="sm black-text white-bg gray-border border-radius own-shadow svelte-cmf5ev bold" id="draw_btn" onclick="return document.querySelector('.controls-wrap button:nth-child(3)').click()"> Draw a new sketch</button>
422
+ </div>
423
+ """)
424
+
425
+ # in_textbox = gr.Textbox( lines=2, elem_classes=["white-bg", "gray-border" , "border-radius" ,"own-shadow" ] ,label="Caption your Sketch!", placeholder="Include the categories that you want the AI to segment. \n e.g. 'giraffe, clouds' or 'a boy flying a kite' ")
426
+
427
+ with gr.Column():
428
+ out_image = gr.Image( value=Image.new('RGB', (512, 512), color=(255, 255, 255)),
429
+ elem_classes=["white-bg", "gray-border" , "border-radius" ,"own-shadow" ] ,
430
+ type="pil", label="Segmented Sketch" ) #, height=512, width=512)
431
+
432
+ # # gr.HTML("<h3 class='black-text'> <span class='black-text underline'>Confidence:</span> Adjust AI agent confidence in guessing categories </div>")
433
+ # in_slider = gr.Slider(elem_classes=["white-bg", "gray-border" , "border-radius" ,"own-shadow" ] ,
434
+ # info="Adjust AI agent confidence in guessing categories",
435
+ # label="Confidence:",
436
+ # value=0.5 , interactive=True, step=0.05, minimum=0, maximum=1)
437
+
438
+ with gr.Row():
439
+ with gr.Column():
440
+ in_textbox = gr.Textbox( lines=2, elem_classes=["white-bg", "gray-border" , "border-radius" ,"own-shadow" ] ,label="Caption your Sketch!", placeholder="Include the categories that you want the AI to segment. \n e.g. 'giraffe, clouds' or 'a boy flying a kite' ")
441
+
442
+ with gr.Column():
443
+ # gr.HTML("<h3 class='black-text'> <span class='black-text underline'>Confidence:</span> Adjust AI agent confidence in guessing categories </div>")
444
+ in_slider = gr.Slider(elem_classes=["white-bg", "gray-border" , "border-radius" ,"own-shadow" ] ,
445
+ info="Adjust AI agent confidence in guessing categories",
446
+ label="Confidence:",
447
+ value=0.5 , interactive=True, step=0.05, minimum=0, maximum=1)
448
+
449
+ with gr.Row():
450
+ segment_btn = gr.Button( 'Segment it<sup>*</sup> !' , elem_classes=["white-bg", "gray-border" , "border-radius" ,"own-shadow" , 'bold' , 'mt-mb-1' ] , size="sm")
451
+ segment_btn.click(fn=run, inputs=[in_canvas_image , in_textbox , in_slider ], outputs=[out_image])
452
+ gallery_label = gr.HTML("<h3 class='black-text'> <span class='black-text underline'>Gallery:</span> <span style='color: grey;'>you can click on any of the example sketches below to start segmenting them (or even drawing over them)</span> </h3>")
453
+
454
+ gallery= gr.HTML(f"""
455
+ <div>
456
+ {gr.Image( elem_classes=["image_gallery"] , label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/sketch_1.png', height=200, width=200)}
457
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/sketch_2.png', height=200, width=200)}
458
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/sketch_3.png', height=200, width=200)}
459
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000004068.png', height=200, width=200)}
460
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000004546.png', height=200, width=200)}
461
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000005076.png', height=200, width=200)}
462
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000006336.png', height=200, width=200)}
463
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000011766.png', height=200, width=200)}
464
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000024458.png', height=200, width=200)}
465
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000024931.png', height=200, width=200)}
466
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000034214.png', height=200, width=200)}
467
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000260974.png', height=200, width=200)}
468
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000268340.png', height=200, width=200)}
469
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000305414.png', height=200, width=200)}
470
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000484246.png', height=200, width=200)}
471
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000549338.png', height=200, width=200)}
472
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000038116.png', height=200, width=200)}
473
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000221509.png', height=200, width=200)}
474
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000246066.png', height=200, width=200)}
475
+ {gr.Image( elem_classes=["image_gallery"] ,label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000001611.png', height=200, width=200)}
476
+ </div>
477
+ """)
478
+
479
+ examples = gr.Examples(
480
+ examples_per_page=30,
481
+ examples=[
482
+ ['demo/sketch_1.png', 'giraffe looking at you', 0.6],
483
+ ['demo/sketch_2.png', 'a kite flying in the sky', 0.6],
484
+ ['demo/sketch_3.png', 'a girl playing', 0.6],
485
+ ['demo/000000004068.png', 'car going so fast', 0.6],
486
+ ['demo/000000004546.png', 'mountains in the background', 0.6],
487
+ ['demo/000000005076.png', 'huge tree', 0.6],
488
+ ['demo/000000006336.png', 'nice three sheeps', 0.6],
489
+ ['demo/000000011766.png', 'bird minding its own business', 0.6],
490
+ ['demo/000000024458.png', 'horse with a mask on', 0.6],
491
+ ['demo/000000024931.png', 'some random person', 0.6],
492
+ ['demo/000000034214.png', 'a cool kid on a skateboard', 0.6],
493
+ ['demo/000000260974.png', 'the chair on the left', 0.6],
494
+ ['demo/000000268340.png', 'stop sign', 0.6],
495
+ ['demo/000000305414.png', 'a lonely elephant roaming around', 0.6],
496
+ ['demo/000000484246.png', 'giraffe with a loong neck', 0.6],
497
+ ['demo/000000549338.png', 'two donkeys trying to be smart', 0.6],
498
+ ['demo/000000038116.png', 'a bat next to a kid', 0.6],
499
+ ['demo/000000221509.png', 'funny looking cow', 0.6],
500
+ ['demo/000000246066.png', 'bench in the park', 0.6],
501
+ ['demo/000000001611.png', 'trees in the background', 0.6]
502
+ ],
503
+ inputs=[in_canvas_image, in_textbox , in_slider],
504
+ fn=run,
505
+ # cache_examples=True,
506
+ )
507
+
508
+ gr.HTML("<h5 class='black-text' style='text-align: left;'>*This demo runs on a basic 2 vCPU. For instant segmentation, use a commercial Nvidia RTX 3090 GPU (t</h5>")
509
+ gr.HTML("<h5 class='black-text' style='text-align: left;'>*We compare the entire caption to the scene sketch and threshold most similar pixels, without extracting individual classes.</h5>")
510
+ demo.launch(share=False)
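
A quick way to sanity-check run() outside the Gradio UI (a minimal sketch, assuming it is executed in the same interpreter session as app.py so that run() can see the loaded model, that the demo/ images are present locally, and that gr.Sketchpad hands run() a dict whose 'composite' entry is an (H, W, 4) RGBA uint8 array):

import numpy as np
from PIL import Image

# Mimic the Sketchpad payload with one of the bundled demo sketches.
sketch = np.array(Image.open('demo/sketch_1.png').convert('RGBA'))

# 0.6 mirrors the Confidence slider; seed is currently unused by run().
result = run({'composite': sketch}, 'giraffe looking at you', 0.6)
result.save('smoke_test.png')

This exercises the same code path as the Segment button, which likewise passes only the sketch, the caption, and the slider value.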