Stoneman committed on
Commit
e1ddabe
1 Parent(s): 1e1a274

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -8
app.py CHANGED
@@ -1,13 +1,62 @@
1
  import gradio as gr
2
- from transformers import pipeline
 
3
 
4
- pipe = pipeline(task="image-to-text",
5
- model="Stoneman/IG-caption-generator-nlpconnect-last-block")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  # Example images
8
  examples = [f"example{i}.JPG" for i in range(1, 10)]
9
- gr.Interface.from_pipeline(pipe,
10
- title="IG-caption-generator",
11
- description="IG caption generator using ViT and GPT2.",
12
- examples=examples
13
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ from transformers import GPT2TokenizerFast, ViTImageProcessor, VisionEncoderDecoderModel
4
 
5
+ # Setup device, model, tokenizer, and feature extractor
6
+ model_checkpoint1 = "Stoneman/IG-caption-generator-vit-gpt2-last-block"
7
+ feature_extractor1 = ViTImageProcessor.from_pretrained(model_checkpoint1)
8
+ tokenizer1 = GPT2TokenizerFast.from_pretrained(model_checkpoint1)
9
+ model1 = VisionEncoderDecoderModel.from_pretrained(model_checkpoint1).to(device)
10
+
11
+ model_checkpoint2 = "Stoneman/IG-caption-generator-vit-gpt2-all"
12
+ model2 = VisionEncoderDecoderModel.from_pretrained(model_checkpoint2).to(device)
13
+
14
+ model_checkpoint3 = "Stoneman/IG-caption-generator-nlpconnect-last-block"
15
+ model3 = VisionEncoderDecoderModel.from_pretrained(model_checkpoint3).to(device)
16
+
17
+ model_checkpoint4 = "Stoneman/IG-caption-generator-nlpconnect-all"
18
+ model4 = VisionEncoderDecoderModel.from_pretrained(model_checkpoint4).to(device)
19
+
20
+ models = {
21
+ 1: model1,
22
+ 2: model2,
23
+ 3: model3,
24
+ 4: model4
25
+ }
26
+
27
# Prediction function
def predict(image, max_length=128):
    """Caption *image* with each of the four fine-tuned models.

    The image is converted to RGB, encoded once with the shared feature
    extractor, then captioned by every model in `models`. Returns a single
    string of the form "Model 1: ...\n\nModel 2: ..." and so on.
    """
    rgb = image.convert('RGB')
    pixel_values = feature_extractor1(images=rgb, return_tensors="pt").pixel_values.to(device)

    captions = {}
    for model_id in range(1, 5):
        output_ids = models[model_id].generate(pixel_values, max_length=max_length)[0]
        captions[model_id] = tokenizer1.decode(output_ids, skip_special_tokens=True)

    # Return a single string with all captions
    return '\n\n'.join(f'Model {model_id}: {text}' for model_id, text in captions.items())
39
+
40
+
41
# Define input and output components
input_component = gr.components.Image(label="Upload any Image", type="pil")
output_component = gr.components.Textbox(label="Captions")

# Example images bundled with the Space: example1.JPG .. example9.JPG.
examples = ["example{}.JPG".format(n) for n in range(1, 10)]

# Interface
title = "IG-caption-generator"
description = "Made by: Jiayu Shi"
interface = gr.Interface(
    fn=predict,
    title=title,
    description=description,
    inputs=input_component,
    outputs=output_component,
    examples=examples,
)

# Launch interface (debug=True surfaces server-side errors in the UI/logs).
interface.launch(debug=True)