pcuenq (HF staff) committed
Commit d9a4d76
1 Parent(s): 5cc174c

Workaround for scaling bug in transformers

Files changed (1): app.py (+16 -3)
app.py CHANGED

@@ -1,9 +1,9 @@
 import gradio as gr
-import os
 import torch
 from transformers import FuyuForCausalLM, AutoTokenizer
 from transformers.models.fuyu.processing_fuyu import FuyuProcessor
 from transformers.models.fuyu.image_processing_fuyu import FuyuImageProcessor
+from PIL import Image
 
 model_id = "adept/fuyu-8b"
 revision = "refs/pr/3"
@@ -16,8 +16,21 @@ processor = FuyuProcessor(image_processor=FuyuImageProcessor(), tokeni
 
 caption_prompt = "Generate a coco-style caption.\n"
 
+def resize_to_max(image, max_width=1920, max_height=1080):
+    width, height = image.size
+    if width <= max_width and height <= max_height:
+        return image
+
+    scale = min(max_width/width, max_height/height)
+    width = int(width*scale)
+    height = int(height*scale)
+
+    return image.resize((width, height), Image.LANCZOS)
+
 def predict(image, prompt):
     # image = image.convert('RGB')
+    image = resize_to_max(image)
+
     model_inputs = processor(text=prompt, images=[image])
     model_inputs = {k: v.to(dtype=dtype if torch.is_floating_point(v) else v.dtype, device=device) for k,v in model_inputs.items()}
 
@@ -57,7 +70,7 @@ with gr.Blocks(css=css) as demo:
     with gr.Tab("Visual Question Answering"):
         with gr.Row():
             with gr.Column():
-                image_input = gr.Image(label="Upload your Image")
+                image_input = gr.Image(label="Upload your Image", type="pil")
                 text_input = gr.Textbox(label="Ask a Question")
             vqa_output = gr.Textbox(label="Output")
 
@@ -75,7 +88,7 @@ with gr.Blocks(css=css) as demo:
 
    with gr.Tab("Image Captioning"):
        with gr.Row():
-           captioning_input = gr.Image(label="Upload your Image")
+           captioning_input = gr.Image(label="Upload your Image", type="pil")
            captioning_output = gr.Textbox(label="Output")
            captioning_btn = gr.Button("Generate Caption")
 
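The gist of the workaround: resize_to_max downscales any upload so it fits within 1920x1080 while keeping its aspect ratio, and the new type="pil" arguments make both gr.Image inputs hand the callback a PIL.Image (instead of Gradio's default numpy array), which is what the helper's .size and .resize calls expect. A minimal standalone sketch of the helper's behavior, using a synthetic image in place of a real upload:

from PIL import Image

def resize_to_max(image, max_width=1920, max_height=1080):
    # Same logic as the helper added in this commit.
    width, height = image.size
    if width <= max_width and height <= max_height:
        return image  # already within bounds; leave untouched
    # Scale by the tighter of the two constraints so both limits hold.
    scale = min(max_width / width, max_height / height)
    return image.resize((int(width * scale), int(height * scale)), Image.LANCZOS)

big = Image.new("RGB", (4000, 3000))  # synthetic oversized "upload"
print(resize_to_max(big).size)        # (1440, 1080): height is the binding side, ratio preserved

Capping the input size presumably keeps large images out of the transformers code path where the scaling bug bites; the diff itself only guarantees the size bound.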