Stoneman committed on
Commit
ed7be52
1 Parent(s): d7e34bd

Update app.py

Files changed (1)
  1. app.py +32 -8
app.py CHANGED
@@ -4,18 +4,42 @@ from transformers import GPT2TokenizerFast, ViTImageProcessor, VisionEncoderDecoderModel
 
 # Setup device, model, tokenizer, and feature extractor
 device = 'cpu'
-model_checkpoint = "Stoneman/IG-caption-generator-nlpconnect-all"
-feature_extractor = ViTImageProcessor.from_pretrained(model_checkpoint)
-tokenizer = GPT2TokenizerFast.from_pretrained(model_checkpoint)
-model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
+
+
+model_checkpoint1 = "Stoneman/IG-caption-generator-vit-gpt2-last-block"
+feature_extractor1 = ViTImageProcessor.from_pretrained(model_checkpoint1)
+tokenizer1 = GPT2TokenizerFast.from_pretrained(model_checkpoint1)
+model1 = VisionEncoderDecoderModel.from_pretrained(model_checkpoint1).to(device)
+
+model_checkpoint2 = "Stoneman/IG-caption-generator-vit-gpt2-all"
+model2 = VisionEncoderDecoderModel.from_pretrained(model_checkpoint2).to(device)
+
+model_checkpoint3 = "Stoneman/IG-caption-generator-nlpconnect-last-block"
+model3 = VisionEncoderDecoderModel.from_pretrained(model_checkpoint3).to(device)
+
+model_checkpoint4 = "Stoneman/IG-caption-generator-nlpconnect-all"
+model4 = VisionEncoderDecoderModel.from_pretrained(model_checkpoint4).to(device)
+
+models = {
+    1: model1,
+    2: model2,
+    3: model3,
+    4: model4
+}
 
 # Prediction function
 def predict(image, max_length=128):
+    captions = {}
+
     image = image.convert('RGB')
-    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
-    caption_ids = model.generate(pixel_values, max_length=max_length)[0]
-    caption_text = tokenizer.decode(caption_ids, skip_special_tokens=True)
-    return caption_text
+    pixel_values = feature_extractor1(images=image, return_tensors="pt").pixel_values.to(device)
+    for i in range(1, 5):
+        caption_ids = models[i].generate(pixel_values, max_length=max_length)[0]
+        caption_text = tokenizer1.decode(caption_ids, skip_special_tokens=True)
+        captions[i] = caption_text
+    # Return a single string with all captions
+    return '\n\n'.join(f'Model {i}: {caption}' for i, caption in captions.items())
+
 
 # Define input and output components
 input_component = gr.components.Image(label="Upload any Image", type="pil")
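For context, app.py goes on to build a Gradio app around predict. A minimal sketch of that wiring, assuming a Textbox output component and a standard gr.Interface call (neither appears in this diff), might look like:

import gradio as gr

# Hypothetical wiring, not part of this commit: output_component and
# launch() are assumptions; only input_component appears in the diff above.
output_component = gr.components.Textbox(label="Generated captions")

interface = gr.Interface(
    fn=predict,
    inputs=input_component,
    outputs=output_component,
)
interface.launch()

Note the design choice in the new code: all four checkpoints reuse feature_extractor1 and tokenizer1, which works presumably because each fine-tune shares the same ViT image preprocessing and GPT-2 vocabulary as the first checkpoint.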