Spaces:
Running
Running
add multi-class
Browse files- app.py +28 -5
- output.png +0 -0
- requirements.txt +2 -1
app.py
CHANGED
@@ -17,10 +17,31 @@ nltk.download('averaged_perceptron_tagger')
|
|
17 |
from nltk.tokenize import word_tokenize
|
18 |
import torchvision
|
19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
args = default_argument_parser().parse_args()
|
21 |
cfg = setup(args)
|
22 |
|
23 |
-
multi_classes =
|
24 |
|
25 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
26 |
Ours, preprocess = models.load("CS-ViT-B/16", device=device, cfg=cfg, train_bool=False)
|
@@ -42,10 +63,12 @@ def run(sketch, caption, threshold, seed):
|
|
42 |
|
43 |
# set the condidate classes here
|
44 |
caption = caption.replace('\n',' ')
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
|
|
49 |
if len(classes) ==0 or multi_classes == False:
|
50 |
classes = [caption]
|
51 |
|
|
|
17 |
from nltk.tokenize import word_tokenize
|
18 |
import torchvision
|
19 |
|
20 |
+
import spacy
|
21 |
+
|
22 |
+
# download the model
|
23 |
+
spacy.cli.download("en_core_web_sm")
|
24 |
+
|
25 |
+
# Load spaCy model
|
26 |
+
nlp = spacy.load("en_core_web_sm")
|
27 |
+
|
28 |
+
def extract_objects(prompt):
|
29 |
+
doc = nlp(prompt)
|
30 |
+
# Extract object nouns (including proper nouns and compound nouns)
|
31 |
+
objects = set()
|
32 |
+
for token in doc:
|
33 |
+
# Check if the token is a noun or part of a named entity
|
34 |
+
if token.pos_ in {"NOUN", "PROPN"} or token.ent_type_:
|
35 |
+
objects.add(token.text)
|
36 |
+
# Check if the token is part of a compound noun
|
37 |
+
if token.dep_ in {"compound"}:
|
38 |
+
objects.add(token.head.text)
|
39 |
+
return list(objects)
|
40 |
+
|
41 |
args = default_argument_parser().parse_args()
|
42 |
cfg = setup(args)
|
43 |
|
44 |
+
multi_classes = True
|
45 |
|
46 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
47 |
Ours, preprocess = models.load("CS-ViT-B/16", device=device, cfg=cfg, train_bool=False)
|
|
|
63 |
|
64 |
# set the condidate classes here
|
65 |
caption = caption.replace('\n',' ')
|
66 |
+
classes = extract_objects(caption)
|
67 |
+
# translator = str.maketrans('', '', string.punctuation)
|
68 |
+
# caption = caption.translate(translator).lower()
|
69 |
+
# words = word_tokenize(caption)
|
70 |
+
# classes = get_noun_phrase(words)
|
71 |
+
# print(classes)
|
72 |
if len(classes) ==0 or multi_classes == False:
|
73 |
classes = [caption]
|
74 |
|
output.png
CHANGED
requirements.txt
CHANGED
@@ -10,4 +10,5 @@ iopath
|
|
10 |
ftfy
|
11 |
fvcore
|
12 |
regex
|
13 |
-
nltk
|
|
|
|
10 |
ftfy
|
11 |
fvcore
|
12 |
regex
|
13 |
+
nltk
|
14 |
+
spacy
|