Spaces:
Running
Running
multi-categories
Browse files- __pycache__/utils.cpython-38.pyc +0 -0
- app.py +18 -12
- output.png +0 -0
- utils.py +40 -8
__pycache__/utils.cpython-38.pyc
CHANGED
Binary files a/__pycache__/utils.cpython-38.pyc and b/__pycache__/utils.cpython-38.pyc differ
|
|
app.py
CHANGED
@@ -4,12 +4,14 @@ import torch
|
|
4 |
from torchvision.transforms import InterpolationMode
|
5 |
|
6 |
BICUBIC = InterpolationMode.BICUBIC
|
7 |
-
from utils import setup, get_similarity_map, display_segmented_sketch
|
8 |
from vpt.launch import default_argument_parser
|
9 |
from collections import OrderedDict
|
10 |
import numpy as np
|
11 |
import matplotlib.pyplot as plt
|
12 |
import models
|
|
|
|
|
13 |
import torchvision
|
14 |
|
15 |
args = default_argument_parser().parse_args()
|
@@ -31,10 +33,18 @@ print("Model loaded successfully")
|
|
31 |
|
32 |
def run(sketch, caption, threshold, seed):
|
33 |
# set the condidate classes here
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
sketch2 = sketch['composite']
|
40 |
|
@@ -414,7 +424,7 @@ with gr.Blocks(js=scripts, css=css, theme='gstaff/xkcd') as demo:
|
|
414 |
examples_per_page=30,
|
415 |
examples=[
|
416 |
['demo/sketch_1.png', 'giraffe looking at you', 0.6],
|
417 |
-
['demo/sketch_2.png', '
|
418 |
['demo/sketch_3.png', 'a girl playing', 0.6],
|
419 |
['demo/000000004068.png', 'car going so fast', 0.6],
|
420 |
['demo/000000004546.png', 'mountains in the background', 0.6],
|
@@ -429,7 +439,7 @@ with gr.Blocks(js=scripts, css=css, theme='gstaff/xkcd') as demo:
|
|
429 |
['demo/000000305414.png', 'a lonely elephant roaming around', 0.6],
|
430 |
['demo/000000484246.png', 'giraffe with a loong neck', 0.6],
|
431 |
['demo/000000549338.png', 'two donkeys trying to be smart', 0.6],
|
432 |
-
['demo/000000038116.png', 'a bat
|
433 |
['demo/000000221509.png', 'funny looking cow', 0.6],
|
434 |
['demo/000000246066.png', 'bench in the park', 0.6],
|
435 |
['demo/000000001611.png', 'trees in the background', 0.6]
|
@@ -439,8 +449,4 @@ with gr.Blocks(js=scripts, css=css, theme='gstaff/xkcd') as demo:
|
|
439 |
# cache_examples=True,
|
440 |
)
|
441 |
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
demo.launch(share=False, )
|
|
|
4 |
from torchvision.transforms import InterpolationMode
|
5 |
|
6 |
BICUBIC = InterpolationMode.BICUBIC
|
7 |
+
from utils import setup, get_similarity_map, display_segmented_sketch,get_noun_phrase
|
8 |
from vpt.launch import default_argument_parser
|
9 |
from collections import OrderedDict
|
10 |
import numpy as np
|
11 |
import matplotlib.pyplot as plt
|
12 |
import models
|
13 |
+
import string
|
14 |
+
import nltk
|
15 |
import torchvision
|
16 |
|
17 |
args = default_argument_parser().parse_args()
|
|
|
33 |
|
34 |
def run(sketch, caption, threshold, seed):
|
35 |
# set the condidate classes here
|
36 |
+
caption = caption.replace('\n',' ')
|
37 |
+
translator = str.maketrans('', '', string.punctuation)
|
38 |
+
caption = caption.translate(translator).lower()
|
39 |
+
words = nltk.word_tokenize(caption)
|
40 |
+
classes = get_noun_phrase(words)
|
41 |
+
if len(classes) ==0:
|
42 |
+
classes = [caption]
|
43 |
+
|
44 |
+
# print(classes)
|
45 |
+
|
46 |
+
colors = plt.get_cmap("Set1").colors
|
47 |
+
classes_colors = colors[:len(classes)]
|
48 |
|
49 |
sketch2 = sketch['composite']
|
50 |
|
|
|
424 |
examples_per_page=30,
|
425 |
examples=[
|
426 |
['demo/sketch_1.png', 'giraffe looking at you', 0.6],
|
427 |
+
['demo/sketch_2.png', 'a kite flying in the sky', 0.6],
|
428 |
['demo/sketch_3.png', 'a girl playing', 0.6],
|
429 |
['demo/000000004068.png', 'car going so fast', 0.6],
|
430 |
['demo/000000004546.png', 'mountains in the background', 0.6],
|
|
|
439 |
['demo/000000305414.png', 'a lonely elephant roaming around', 0.6],
|
440 |
['demo/000000484246.png', 'giraffe with a loong neck', 0.6],
|
441 |
['demo/000000549338.png', 'two donkeys trying to be smart', 0.6],
|
442 |
+
['demo/000000038116.png', 'a bat next to a kid', 0.6],
|
443 |
['demo/000000221509.png', 'funny looking cow', 0.6],
|
444 |
['demo/000000246066.png', 'bench in the park', 0.6],
|
445 |
['demo/000000001611.png', 'trees in the background', 0.6]
|
|
|
449 |
# cache_examples=True,
|
450 |
)
|
451 |
|
452 |
+
demo.launch(share=False)
|
|
|
|
|
|
|
|
output.png
CHANGED
utils.py
CHANGED
@@ -10,11 +10,41 @@ from vpt.src.utils.file_io import PathManager
|
|
10 |
import matplotlib.pyplot as plt
|
11 |
from matplotlib.colors import rgb_to_hsv, hsv_to_rgb
|
12 |
import warnings
|
|
|
13 |
|
14 |
|
15 |
warnings.filterwarnings("ignore")
|
16 |
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
def setup(args):
|
19 |
"""
|
20 |
Create configs and perform basic setups.
|
@@ -90,14 +120,16 @@ def display_segmented_sketch(pixel_similarity_array,binary_sketch,classes,classe
|
|
90 |
# Convert the HSV image back to RGB to display and save
|
91 |
rgb_image = hsv_to_rgb(hsv_image)
|
92 |
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
|
|
|
|
101 |
|
102 |
|
103 |
# Display the image with class names
|
|
|
10 |
import matplotlib.pyplot as plt
|
11 |
from matplotlib.colors import rgb_to_hsv, hsv_to_rgb
|
12 |
import warnings
|
13 |
+
import nltk
|
14 |
|
15 |
|
16 |
warnings.filterwarnings("ignore")
|
17 |
|
18 |
|
19 |
+
def get_noun_phrase(tokenized):
|
20 |
+
# Taken from Su Nam Kim Paper...
|
21 |
+
grammar = r"""
|
22 |
+
NBAR:
|
23 |
+
{<NN.*|JJ>*<NN.*>} # Nouns and Adjectives, terminated with Nouns
|
24 |
+
NP:
|
25 |
+
{<NBAR>}
|
26 |
+
{<NBAR><IN><NBAR>} # Above, connected with in/of/etc...
|
27 |
+
"""
|
28 |
+
chunker = nltk.RegexpParser(grammar)
|
29 |
+
|
30 |
+
chunked = chunker.parse(nltk.pos_tag(tokenized))
|
31 |
+
continuous_chunk = []
|
32 |
+
current_chunk = []
|
33 |
+
|
34 |
+
for subtree in chunked:
|
35 |
+
if isinstance(subtree, nltk.Tree):
|
36 |
+
current_chunk.append(' '.join([token for token, pos in subtree.leaves()]))
|
37 |
+
elif current_chunk:
|
38 |
+
named_entity = ' '.join(current_chunk)
|
39 |
+
if named_entity not in continuous_chunk:
|
40 |
+
continuous_chunk.append(named_entity)
|
41 |
+
current_chunk = []
|
42 |
+
else:
|
43 |
+
continue
|
44 |
+
|
45 |
+
return continuous_chunk
|
46 |
+
|
47 |
+
|
48 |
def setup(args):
|
49 |
"""
|
50 |
Create configs and perform basic setups.
|
|
|
120 |
# Convert the HSV image back to RGB to display and save
|
121 |
rgb_image = hsv_to_rgb(hsv_image)
|
122 |
|
123 |
+
|
124 |
+
if len(classes) > 1:
|
125 |
+
# Calculate centroids and render class names
|
126 |
+
for i, class_name in enumerate(classes):
|
127 |
+
mask = class_indices == i
|
128 |
+
if np.any(mask):
|
129 |
+
y, x = np.nonzero(mask)
|
130 |
+
centroid_x, centroid_y = np.mean(x), np.mean(y)
|
131 |
+
plt.text(centroid_x, centroid_y, class_name, color=classes_colors[i], ha='center', va='center',fontsize=10, # color=classes_colors[i]
|
132 |
+
bbox=dict(facecolor='lightgrey', edgecolor='none', boxstyle='round,pad=0.2', alpha=0.8))
|
133 |
|
134 |
|
135 |
# Display the image with class names
|