Update app.py
app.py CHANGED
@@ -4,18 +4,42 @@ from transformers import GPT2TokenizerFast, ViTImageProcessor, VisionEncoderDecoderModel
 
 # Setup device, model, tokenizer, and feature extractor
 device = 'cpu'
-
-
-
-
+
+
+model_checkpoint1 = "Stoneman/IG-caption-generator-vit-gpt2-last-block"
+feature_extractor1 = ViTImageProcessor.from_pretrained(model_checkpoint1)
+tokenizer1 = GPT2TokenizerFast.from_pretrained(model_checkpoint1)
+model1 = VisionEncoderDecoderModel.from_pretrained(model_checkpoint1).to(device)
+
+model_checkpoint2 = "Stoneman/IG-caption-generator-vit-gpt2-all"
+model2 = VisionEncoderDecoderModel.from_pretrained(model_checkpoint2).to(device)
+
+model_checkpoint3 = "Stoneman/IG-caption-generator-nlpconnect-last-block"
+model3 = VisionEncoderDecoderModel.from_pretrained(model_checkpoint3).to(device)
+
+model_checkpoint4 = "Stoneman/IG-caption-generator-nlpconnect-all"
+model4 = VisionEncoderDecoderModel.from_pretrained(model_checkpoint4).to(device)
+
+models = {
+    1: model1,
+    2: model2,
+    3: model3,
+    4: model4
+}
 
 # Prediction function
 def predict(image, max_length=128):
+    captions = {}
+
     image = image.convert('RGB')
-    pixel_values =
-
-
-
+    pixel_values = feature_extractor1(images=image, return_tensors="pt").pixel_values.to(device)
+    for i in range(1,5):
+        caption_ids = models[i].generate(pixel_values, max_length=max_length)[0]
+        caption_text = tokenizer1.decode(caption_ids, skip_special_tokens=True)
+        captions[i] = caption_text
+    # Return a single string with all captions
+    return '\n\n'.join(f'Model {i}: {caption}' for i, caption in captions.items())
+
 
 # Define input and output components
 input_component = gr.components.Image(label="Upload any Image", type="pil")
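
The hunk stops at the image input component, so the output component and the gr.Interface wiring sit outside this diff. For reference, here is a minimal sketch of how the updated predict could be hooked up at the bottom of app.py, assuming a plain textbox output and a default launch; the output label and Interface call below are assumptions, not part of this commit's visible hunk:

# Sketch only (not in this commit's hunk): assumed tail of app.py.
# predict and input_component are defined in the diff above.
output_component = gr.components.Textbox(label="Generated captions")  # assumed output widget

demo = gr.Interface(
    fn=predict,              # returns one string with all four models' captions
    inputs=input_component,  # PIL image input from the diff
    outputs=output_component,
)
demo.launch()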