pcuenq committed
Commit 35feaa0
1 Parent(s): 0909501

Screenshot text location: pad images

Files changed (1)
  1. app.py +30 -4
app.py CHANGED
@@ -15,6 +15,27 @@ processor = FuyuProcessor(image_processor=FuyuImageProcessor(), tokenizer=tokeni
 CAPTION_PROMPT = "Generate a coco-style caption.\n"
 DETAILED_CAPTION_PROMPT = "What is happening in this image?"
 
+def resize_to_max(image, max_width=1920, max_height=1080):
+    width, height = image.size
+    if width <= max_width and height <= max_height:
+        return image
+
+    scale = min(max_width/width, max_height/height)
+    width = int(width*scale)
+    height = int(height*scale)
+
+    return image.resize((width, height), Image.LANCZOS)
+
+def pad_to_size(image, canvas_width=1920, canvas_height=1080):
+    width, height = image.size
+    if width >= canvas_width and height >= canvas_height:
+        return image
+
+    # Paste at (0, 0)
+    canvas = Image.new("RGB", (canvas_width, canvas_height))
+    canvas.paste(image)
+    return canvas
+
 def predict(image, prompt):
     # image = image.convert('RGB')
     model_inputs = processor(text=prompt, images=[image])
@@ -84,8 +105,13 @@ def coords_from_response(response):
         gr.Error("The string is malformed or does not match the expected pattern.")
 
 def localize(image, query):
-    prompt= f"When presented with a box, perform OCR to extract text contained within it. If provided with text, generate the corresponding bounding box.\n{query}"
-    model_inputs = processor(text=prompt, images=[image])
+    prompt = f"When presented with a box, perform OCR to extract text contained within it. If provided with text, generate the corresponding bounding box.\n{query}"
+
+    # Downscale and/or pad to 1920x1080
+    padded = resize_to_max(image)
+    padded = pad_to_size(padded)
+
+    model_inputs = processor(text=prompt, images=[padded])
     model_inputs = {k: v.to(dtype=dtype if torch.is_floating_point(v) else v.dtype, device=device) for k,v in model_inputs.items()}
 
     generation_output = model.generate(**model_inputs, max_new_tokens=40)
@@ -159,7 +185,6 @@ with gr.Blocks(css=css) as demo:
     vqa_btn.click(fn=predict, inputs=[image_input, text_input], outputs=vqa_output)
 
     with gr.Tab("Find Text in Screenshots"):
-        gr.Markdown("This demo is designed to locate text in desktop screenshots. Please, ensure to upload images of 1920x1080 for best results!")
         with gr.Row():
             with gr.Column():
                 localization_input = gr.Image(label="Upload your Image", type="pil")
@@ -170,7 +195,8 @@ with gr.Blocks(css=css) as demo:
             localization_output = gr.AnnotatedImage(label="Text Position")
 
         gr.Examples(
-            [["assets/localization_example_1.jpeg", "Share your repair"]],
+            [["assets/localization_example_1.jpeg", "Share your repair"],
+             ["assets/screen2words_ui_example.png", "statistics"]],
             inputs = [localization_input, query_input],
             outputs = [localization_output],
             fn=localize,
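
For reviewers who want to sanity-check the new preprocessing outside the Space, here is a minimal standalone sketch. It assumes only Pillow is installed; resize_to_max and pad_to_size are copied from the diff above, and the two screenshot sizes are hypothetical. The checks confirm that an oversized screenshot is first downscaled to fit within 1920x1080 and that anything smaller is pasted at (0, 0) onto a black 1920x1080 canvas, so every image handed to the Fuyu processor has the same resolution.

from PIL import Image

# Helpers copied from the diff above.
def resize_to_max(image, max_width=1920, max_height=1080):
    width, height = image.size
    if width <= max_width and height <= max_height:
        return image
    scale = min(max_width/width, max_height/height)
    return image.resize((int(width*scale), int(height*scale)), Image.LANCZOS)

def pad_to_size(image, canvas_width=1920, canvas_height=1080):
    width, height = image.size
    if width >= canvas_width and height >= canvas_height:
        return image
    canvas = Image.new("RGB", (canvas_width, canvas_height))  # black canvas
    canvas.paste(image)                                       # paste at (0, 0)
    return canvas

# Hypothetical inputs: a 4K desktop screenshot and a portrait phone screenshot.
big = Image.new("RGB", (3840, 2160))
small = Image.new("RGB", (1080, 2220))

# Both end up at exactly 1920x1080 before reaching the Fuyu processor.
assert pad_to_size(resize_to_max(big)).size == (1920, 1080)
assert pad_to_size(resize_to_max(small)).size == (1920, 1080)

Pasting at (0, 0) instead of centering keeps the returned box coordinates aligned with the original image region, so no offset correction is needed; with the resolution handled in code, the gr.Markdown note asking users for 1920x1080 uploads could be dropped.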