ahmedbrs commited on
Commit
37b5ba0
1 Parent(s): c8da19b

multi-categories

Browse files
Files changed (4) hide show
  1. __pycache__/utils.cpython-38.pyc +0 -0
  2. app.py +18 -12
  3. output.png +0 -0
  4. utils.py +40 -8
__pycache__/utils.cpython-38.pyc CHANGED
Binary files a/__pycache__/utils.cpython-38.pyc and b/__pycache__/utils.cpython-38.pyc differ
 
app.py CHANGED
@@ -4,12 +4,14 @@ import torch
4
  from torchvision.transforms import InterpolationMode
5
 
6
  BICUBIC = InterpolationMode.BICUBIC
7
- from utils import setup, get_similarity_map, display_segmented_sketch
8
  from vpt.launch import default_argument_parser
9
  from collections import OrderedDict
10
  import numpy as np
11
  import matplotlib.pyplot as plt
12
  import models
 
 
13
  import torchvision
14
 
15
  args = default_argument_parser().parse_args()
@@ -31,10 +33,18 @@ print("Model loaded successfully")
31
 
32
  def run(sketch, caption, threshold, seed):
33
  # set the condidate classes here
34
- classes = [caption]
35
-
36
- colors = plt.get_cmap("tab10").colors
37
- classes_colors = colors[3:len(classes) + 3]
 
 
 
 
 
 
 
 
38
 
39
  sketch2 = sketch['composite']
40
 
@@ -414,7 +424,7 @@ with gr.Blocks(js=scripts, css=css, theme='gstaff/xkcd') as demo:
414
  examples_per_page=30,
415
  examples=[
416
  ['demo/sketch_1.png', 'giraffe looking at you', 0.6],
417
- ['demo/sketch_2.png', 'tree on the right', 0.6],
418
  ['demo/sketch_3.png', 'a girl playing', 0.6],
419
  ['demo/000000004068.png', 'car going so fast', 0.6],
420
  ['demo/000000004546.png', 'mountains in the background', 0.6],
@@ -429,7 +439,7 @@ with gr.Blocks(js=scripts, css=css, theme='gstaff/xkcd') as demo:
429
  ['demo/000000305414.png', 'a lonely elephant roaming around', 0.6],
430
  ['demo/000000484246.png', 'giraffe with a loong neck', 0.6],
431
  ['demo/000000549338.png', 'two donkeys trying to be smart', 0.6],
432
- ['demo/000000038116.png', 'a bat on the left', 0.6],
433
  ['demo/000000221509.png', 'funny looking cow', 0.6],
434
  ['demo/000000246066.png', 'bench in the park', 0.6],
435
  ['demo/000000001611.png', 'trees in the background', 0.6]
@@ -439,8 +449,4 @@ with gr.Blocks(js=scripts, css=css, theme='gstaff/xkcd') as demo:
439
  # cache_examples=True,
440
  )
441
 
442
-
443
-
444
-
445
-
446
- demo.launch(share=False, )
 
4
  from torchvision.transforms import InterpolationMode
5
 
6
  BICUBIC = InterpolationMode.BICUBIC
7
+ from utils import setup, get_similarity_map, display_segmented_sketch,get_noun_phrase
8
  from vpt.launch import default_argument_parser
9
  from collections import OrderedDict
10
  import numpy as np
11
  import matplotlib.pyplot as plt
12
  import models
13
+ import string
14
+ import nltk
15
  import torchvision
16
 
17
  args = default_argument_parser().parse_args()
 
33
 
34
  def run(sketch, caption, threshold, seed):
35
  # set the condidate classes here
36
+ caption = caption.replace('\n',' ')
37
+ translator = str.maketrans('', '', string.punctuation)
38
+ caption = caption.translate(translator).lower()
39
+ words = nltk.word_tokenize(caption)
40
+ classes = get_noun_phrase(words)
41
+ if len(classes) ==0:
42
+ classes = [caption]
43
+
44
+ # print(classes)
45
+
46
+ colors = plt.get_cmap("Set1").colors
47
+ classes_colors = colors[:len(classes)]
48
 
49
  sketch2 = sketch['composite']
50
 
 
424
  examples_per_page=30,
425
  examples=[
426
  ['demo/sketch_1.png', 'giraffe looking at you', 0.6],
427
+ ['demo/sketch_2.png', 'a kite flying in the sky', 0.6],
428
  ['demo/sketch_3.png', 'a girl playing', 0.6],
429
  ['demo/000000004068.png', 'car going so fast', 0.6],
430
  ['demo/000000004546.png', 'mountains in the background', 0.6],
 
439
  ['demo/000000305414.png', 'a lonely elephant roaming around', 0.6],
440
  ['demo/000000484246.png', 'giraffe with a loong neck', 0.6],
441
  ['demo/000000549338.png', 'two donkeys trying to be smart', 0.6],
442
+ ['demo/000000038116.png', 'a bat next to a kid', 0.6],
443
  ['demo/000000221509.png', 'funny looking cow', 0.6],
444
  ['demo/000000246066.png', 'bench in the park', 0.6],
445
  ['demo/000000001611.png', 'trees in the background', 0.6]
 
449
  # cache_examples=True,
450
  )
451
 
452
+ demo.launch(share=False)
 
 
 
 
output.png CHANGED
utils.py CHANGED
@@ -10,11 +10,41 @@ from vpt.src.utils.file_io import PathManager
10
  import matplotlib.pyplot as plt
11
  from matplotlib.colors import rgb_to_hsv, hsv_to_rgb
12
  import warnings
 
13
 
14
 
15
  warnings.filterwarnings("ignore")
16
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def setup(args):
19
  """
20
  Create configs and perform basic setups.
@@ -90,14 +120,16 @@ def display_segmented_sketch(pixel_similarity_array,binary_sketch,classes,classe
90
  # Convert the HSV image back to RGB to display and save
91
  rgb_image = hsv_to_rgb(hsv_image)
92
 
93
- # # Calculate centroids and render class names
94
- # for i, class_name in enumerate(classes):
95
- # mask = class_indices == i
96
- # if np.any(mask):
97
- # y, x = np.nonzero(mask)
98
- # centroid_x, centroid_y = np.mean(x), np.mean(y)
99
- # plt.text(centroid_x, centroid_y, class_name, color=classes_colors[i], ha='center', va='center',fontsize=14, # color=classes_colors[i]
100
- # bbox=dict(facecolor='lightgrey', edgecolor='none', boxstyle='round,pad=0.2', alpha=0.8))
 
 
101
 
102
 
103
  # Display the image with class names
 
10
  import matplotlib.pyplot as plt
11
  from matplotlib.colors import rgb_to_hsv, hsv_to_rgb
12
  import warnings
13
+ import nltk
14
 
15
 
16
  warnings.filterwarnings("ignore")
17
 
18
 
19
+ def get_noun_phrase(tokenized):
20
+ # Taken from Su Nam Kim Paper...
21
+ grammar = r"""
22
+ NBAR:
23
+ {<NN.*|JJ>*<NN.*>} # Nouns and Adjectives, terminated with Nouns
24
+ NP:
25
+ {<NBAR>}
26
+ {<NBAR><IN><NBAR>} # Above, connected with in/of/etc...
27
+ """
28
+ chunker = nltk.RegexpParser(grammar)
29
+
30
+ chunked = chunker.parse(nltk.pos_tag(tokenized))
31
+ continuous_chunk = []
32
+ current_chunk = []
33
+
34
+ for subtree in chunked:
35
+ if isinstance(subtree, nltk.Tree):
36
+ current_chunk.append(' '.join([token for token, pos in subtree.leaves()]))
37
+ elif current_chunk:
38
+ named_entity = ' '.join(current_chunk)
39
+ if named_entity not in continuous_chunk:
40
+ continuous_chunk.append(named_entity)
41
+ current_chunk = []
42
+ else:
43
+ continue
44
+
45
+ return continuous_chunk
46
+
47
+
48
  def setup(args):
49
  """
50
  Create configs and perform basic setups.
 
120
  # Convert the HSV image back to RGB to display and save
121
  rgb_image = hsv_to_rgb(hsv_image)
122
 
123
+
124
+ if len(classes) > 1:
125
+ # Calculate centroids and render class names
126
+ for i, class_name in enumerate(classes):
127
+ mask = class_indices == i
128
+ if np.any(mask):
129
+ y, x = np.nonzero(mask)
130
+ centroid_x, centroid_y = np.mean(x), np.mean(y)
131
+ plt.text(centroid_x, centroid_y, class_name, color=classes_colors[i], ha='center', va='center',fontsize=10, # color=classes_colors[i]
132
+ bbox=dict(facecolor='lightgrey', edgecolor='none', boxstyle='round,pad=0.2', alpha=0.8))
133
 
134
 
135
  # Display the image with class names