ahmedbrs committed on
Commit
53e6d40
1 Parent(s): d9aaca7

add multi-class

Browse files
Files changed (3) hide show
  1. app.py +28 -5
  2. output.png +0 -0
  3. requirements.txt +2 -1
app.py CHANGED
@@ -17,10 +17,31 @@ nltk.download('averaged_perceptron_tagger')
17
  from nltk.tokenize import word_tokenize
18
  import torchvision
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  args = default_argument_parser().parse_args()
21
  cfg = setup(args)
22
 
23
- multi_classes = False
24
 
25
  device = "cuda" if torch.cuda.is_available() else "cpu"
26
  Ours, preprocess = models.load("CS-ViT-B/16", device=device, cfg=cfg, train_bool=False)
@@ -42,10 +63,12 @@ def run(sketch, caption, threshold, seed):
42
 
43
  # set the candidate classes here
44
  caption = caption.replace('\n',' ')
45
- translator = str.maketrans('', '', string.punctuation)
46
- caption = caption.translate(translator).lower()
47
- words = word_tokenize(caption)
48
- classes = get_noun_phrase(words)
 
 
49
  if len(classes) ==0 or multi_classes == False:
50
  classes = [caption]
51
 
 
17
  from nltk.tokenize import word_tokenize
18
  import torchvision
19
 
20
+ import spacy
21
+
22
# Load the spaCy English pipeline used by the noun extraction below.
# Try the installed copy first and download only when it is missing, so the
# model package is not re-fetched on every application start-up.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # spacy.load raises OSError when the model package cannot be found.
    spacy.cli.download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")
27
+
28
def extract_objects(prompt):
    """Return the distinct object words mentioned in *prompt*.

    The prompt is parsed with the module-level spaCy pipeline ``nlp``.
    A token's text is collected when the token is tagged ``NOUN`` or
    ``PROPN`` or carries a named-entity type. For a token whose dependency
    label is ``compound``, the text of its head token is collected as well.

    Args:
        prompt: Caption text to analyse.

    Returns:
        A list of unique token texts, in the order they were first collected.
    """
    # A dict keeps insertion order, whereas iterating a set of strings can
    # differ between interpreter runs (hash randomization), which made the
    # order of the returned classes non-reproducible.
    objects = {}
    for token in nlp(prompt):
        # Nouns, proper nouns, and anything inside a named entity.
        if token.pos_ in {"NOUN", "PROPN"} or token.ent_type_:
            objects[token.text] = None
        # For a compound modifier, also keep the word it modifies.
        if token.dep_ == "compound":
            objects[token.head.text] = None
    return list(objects)
40
+
41
  args = default_argument_parser().parse_args()
42
  cfg = setup(args)
43
 
44
+ multi_classes = True
45
 
46
  device = "cuda" if torch.cuda.is_available() else "cpu"
47
  Ours, preprocess = models.load("CS-ViT-B/16", device=device, cfg=cfg, train_bool=False)
 
63
 
64
  # set the candidate classes here
65
  caption = caption.replace('\n',' ')
66
+ classes = extract_objects(caption)
67
+ # translator = str.maketrans('', '', string.punctuation)
68
+ # caption = caption.translate(translator).lower()
69
+ # words = word_tokenize(caption)
70
+ # classes = get_noun_phrase(words)
71
+ # print(classes)
72
  if len(classes) ==0 or multi_classes == False:
73
  classes = [caption]
74
 
output.png CHANGED
requirements.txt CHANGED
@@ -10,4 +10,5 @@ iopath
10
  ftfy
11
  fvcore
12
  regex
13
- nltk
 
 
10
  ftfy
11
  fvcore
12
  regex
13
+ nltk
14
+ spacy