Tonic committed
Commit e528291 · unverified · 1 Parent(s): 6b93795

add quad boxes

Files changed (2)
  1. app.py +32 -16
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,6 +1,8 @@
 import gradio as gr
 import torch
-from PIL import Image, ImageDraw
+import cv2
+import numpy as np
+from matplotlib import pyplot as plt
+from PIL import Image, ImageDraw
 from transformers import AutoProcessor
 from modeling_florence2 import Florence2ForConditionalGeneration
 import io
@@ -115,14 +117,15 @@ def fig_to_pil(fig):
 def plot_bbox(image, data, use_quad_boxes=False):
     fig, ax = plt.subplots()
     ax.imshow(image)
-
+
     # Handle both 'bboxes' and 'quad_boxes'
     if use_quad_boxes:
         for quad_box, label in zip(data.get('quad_boxes', []), data.get('labels', [])):
             quad_box = np.array(quad_box).reshape(-1, 2)
-            poly = patches.Polygon(quad_box, linewidth=1, edgecolor='r', facecolor='none')
+            poly = Polygon(quad_box, linewidth=1, edgecolor='r', facecolor='none')
             ax.add_patch(poly)
-            plt.text(quad_box[0][0], quad_box[0][1], label, color='white', fontsize=8, bbox=dict(facecolor='red', alpha=0.5))
+            plt.text(quad_box[0][0], quad_box[0][1], label, color='white', fontsize=8,
+                     bbox=dict(facecolor='red', alpha=0.5))
     else:
         bboxes = data.get('bboxes', [])
         labels = data.get('labels', [])
@@ -149,49 +152,60 @@ def draw_ocr_bboxes(image, prediction):
                     fill=color)
     return image
 
+def draw_bounding_boxes(image, quad_boxes, labels, color=(0, 255, 0), thickness=2):
+    """
+    Draws quadrilateral bounding boxes on the image.
+
+    Args:
+        image: The original image where the bounding boxes will be drawn.
+        quad_boxes: List of quadrilateral bounding box points. Each bounding box contains four points.
+        labels: List of labels corresponding to each bounding box.
+        color: Color of the bounding box. Default is green.
+        thickness: Thickness of the bounding box lines. Default is 2.
+    """
+    for i, quad in enumerate(quad_boxes):
+        points = np.array(quad, dtype=np.int32).reshape((-1, 1, 2))  # Reshape the quad points for drawing
+        image = cv2.polylines(image, [points], isClosed=True, color=color, thickness=thickness)
+        # Add label text near the top-left point of the bounding box
+        label_pos = (int(quad[0]), int(quad[1]) - 10)  # Position the label slightly above the bounding box
+        cv2.putText(image, labels[i], label_pos, cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, thickness)
+
+    return image
+
 def process_image(image, task):
     prompt = TASK_PROMPTS[task]
-
     # Print the inputs for debugging
     print(f"\n--- Processing Task: {task} ---")
     print(f"Prompt: {prompt}")
-
     inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
-
     # Print the input tensors for debugging
     print(f"Model Input: {inputs}")
-
     generated_ids = model.generate(
         **inputs,
         max_new_tokens=1024,
         num_beams=3,
         do_sample=False
     )
-
+
     generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
-
     # Print the raw generated output for debugging
     print(f"Raw Model Output: {generated_text}")
-
     parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
-
     # Print the parsed answer for debugging
     print(f"Parsed Answer: {parsed_answer}")
-
     return parsed_answer
 
+
 def main_process(image, task):
     result = process_image(image, task)
 
     if task in IMAGE_TASKS:
-        if task == "OCR with Region":
+        if task == "📸✍🏻OCR with Region":
             fig = plot_bbox(image, result.get('<OCR_WITH_REGION>', {}), use_quad_boxes=True)
             output_image = fig_to_pil(fig)
             text_output = result.get('<OCR_WITH_REGION>', {}).get('recognized_text', 'No text found')
-
             # Debugging: Print the recognized text
             print(f"Recognized Text: {text_output}")
-
             return output_image, gr.update(visible=True), text_output, gr.update(visible=True)
         else:
             fig = plot_bbox(image, result.get(TASK_PROMPTS[task], {}))
@@ -219,6 +233,8 @@ with gr.Blocks(title="PLeIAs/📸📈✍🏻Florence-PDF") as iface:
     output_image = gr.Image(label="PLeIAs/📸📈✍🏻Florence-PDF", visible=False)
     output_text = gr.Textbox(label="PLeIAs/📸📈✍🏻Florence-PDF", visible=True)
 
+    gr.Markdown(model_presentation)
+
     def process_and_update(image, task):
         if image is None:
             return None, gr.update(visible=False), "Please upload an image first.", gr.update(visible=True)
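
For context, here is a minimal standalone sketch of the quad-box drawing path that plot_bbox takes when use_quad_boxes=True. The sample image, quad_boxes, and labels are illustrative placeholders (in the app they come from Florence-2's <OCR_WITH_REGION> output), and the explicit matplotlib.patches.Polygon import is an assumption for the sketch, not part of this commit:

# Illustrative sketch only: mirrors the use_quad_boxes branch of plot_bbox.
# The image and the quad_boxes/labels values below are made-up placeholders.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from PIL import Image

image = Image.new("RGB", (640, 480), "white")              # stand-in for the uploaded image
data = {
    "quad_boxes": [[50, 50, 200, 60, 195, 110, 45, 100]],  # one region: x1, y1, ..., x4, y4
    "labels": ["sample text"],
}

fig, ax = plt.subplots()
ax.imshow(image)
for quad_box, label in zip(data.get("quad_boxes", []), data.get("labels", [])):
    pts = np.array(quad_box).reshape(-1, 2)                 # four corner points as a (4, 2) array
    ax.add_patch(Polygon(pts, linewidth=1, edgecolor="r", facecolor="none"))
    ax.text(pts[0][0], pts[0][1], label, color="white", fontsize=8,
            bbox=dict(facecolor="red", alpha=0.5))
ax.axis("off")
fig.savefig("quad_boxes_preview.png")

The sketch uses ax.text where the commit's plot_bbox uses plt.text; both place the label at the first corner of the quadrilateral.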
requirements.txt CHANGED
@@ -3,4 +3,5 @@ transformers
 accelerate
 pillow
 einops
-timm
+timm
+opencv-python
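
opencv-python is pulled in for the new draw_bounding_boxes helper in app.py. Below is a minimal usage sketch of the same cv2 calls; the blank canvas and the quad/label values are illustrative assumptions, not model output:

# Illustrative sketch of the cv2 drawing done by the new draw_bounding_boxes helper.
import cv2
import numpy as np

image = np.full((480, 640, 3), 255, dtype=np.uint8)        # blank white canvas
quad_boxes = [[50, 50, 200, 60, 195, 110, 45, 100]]        # one quadrilateral: x1, y1, ..., x4, y4
labels = ["sample text"]

for quad, label in zip(quad_boxes, labels):
    points = np.array(quad, dtype=np.int32).reshape((-1, 1, 2))   # shape (4, 1, 2) expected by polylines
    image = cv2.polylines(image, [points], isClosed=True, color=(0, 255, 0), thickness=2)
    label_pos = (int(quad[0]), int(quad[1]) - 10)                  # label just above the first corner
    cv2.putText(image, label, label_pos, cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

cv2.imwrite("quad_boxes_cv2.png", image)

Unlike the matplotlib path in plot_bbox, this helper draws directly on the numpy image array and returns an image rather than a figure.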