File size: 2,147 Bytes
afbc5a8
e1ddabe
 
afbc5a8
e1ddabe
b6c16c3
 
e1ddabe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
afbc5a8
 
0259384
 
e1ddabe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import gradio as gr
import torch
from transformers import GPT2TokenizerFast, ViTImageProcessor, VisionEncoderDecoderModel 

# Setup device, model, tokenizer, and feature extractor.
# Use the GPU when one is present; otherwise fall back to CPU
# (the original hardcoded 'cpu', which also works everywhere).
device = "cuda" if torch.cuda.is_available() else "cpu"

# The four fine-tuned vision-encoder-decoder checkpoints, keyed by the
# 1-based model id used throughout the app (see predict()).
_checkpoints = {
    1: "Stoneman/IG-caption-generator-vit-gpt2-last-block",
    2: "Stoneman/IG-caption-generator-vit-gpt2-all",
    3: "Stoneman/IG-caption-generator-nlpconnect-last-block",
    4: "Stoneman/IG-caption-generator-nlpconnect-all",
}
model_checkpoint1 = _checkpoints[1]
model_checkpoint2 = _checkpoints[2]
model_checkpoint3 = _checkpoints[3]
model_checkpoint4 = _checkpoints[4]

# All four models share the same preprocessing, so a single image
# processor and tokenizer (from checkpoint 1) serves every model.
feature_extractor1 = ViTImageProcessor.from_pretrained(model_checkpoint1)
tokenizer1 = GPT2TokenizerFast.from_pretrained(model_checkpoint1)

# Load each model once and index it by id.
models = {
    i: VisionEncoderDecoderModel.from_pretrained(ckpt).to(device)
    for i, ckpt in _checkpoints.items()
}
model1 = models[1]
model2 = models[2]
model3 = models[3]
model4 = models[4]

# Prediction function
def predict(image, max_length=128):
    """Caption *image* with each of the four loaded models.

    Args:
        image: a PIL image (as supplied by the Gradio image input).
        max_length: maximum token length of each generated caption.

    Returns:
        One string containing a ``Model i: <caption>`` line per model,
        separated by blank lines.
    """
    captions = {}

    # Normalize to 3-channel RGB; all models share one preprocessor.
    image = image.convert('RGB')
    pixel_values = feature_extractor1(images=image, return_tensors="pt").pixel_values.to(device)

    # Inference only: disable autograd to avoid tracking gradients,
    # which saves memory and time during generation.
    with torch.no_grad():
        for i in range(1, 5):
            caption_ids = models[i].generate(pixel_values, max_length=max_length)[0]
            captions[i] = tokenizer1.decode(caption_ids, skip_special_tokens=True)

    # Return a single string with all captions
    return '\n\n'.join(f'Model {i}: {caption}' for i, caption in captions.items())


# Example image(s) displayed under the upload widget.
# examples = [f"example{i}.JPG" for i in range(1, 10)]
examples = ['example1.JPG']

# UI widgets: image upload in, generated captions out.
input_component = gr.components.Image(label="Upload any Image", type="pil")
output_component = gr.components.Textbox(label="Captions")

# Page metadata shown above the interface.
title = "IG-caption-generator"
description = "Made by: Jiayu Shi"

# Wire everything into a Gradio Interface around predict().
interface = gr.Interface(
    fn=predict,
    inputs=input_component,
    outputs=output_component,
    title=title,
    description=description,
    examples=examples,
)

# Start the web app (debug=True surfaces errors in the console).
interface.launch(debug=True)