# florence-pdf / app.py
# Last commit 3bdebf9 (unverified) by Tonic: "add interface logic" (4.96 kB)
import gradio as gr
import torch
from PIL import Image, ImageDraw
from transformers import AutoProcessor
from modeling_florence2 import Florence2ForConditionalGeneration
import io
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import random
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = Florence2ForConditionalGeneration.from_pretrained("PleIAs/Florence-PDF", torch_dtype=torch_dtype, trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained("PleIAs/Florence-PDF", trust_remote_code=True)
# Human-readable task name (shown in the dropdown) -> Florence-2 task token
# that is sent to the model as the text prompt.
TASK_PROMPTS = {
    "Caption": "<CAPTION>",
    "Detailed Caption": "<DETAILED_CAPTION>",
    "More Detailed Caption": "<MORE_DETAILED_CAPTION>",
    "Object Detection": "<OD>",
    "Dense Region Caption": "<DENSE_REGION_CAPTION>",
    "OCR": "<OCR>",
    "OCR with Region": "<OCR_WITH_REGION>",
    "Region Proposal": "<REGION_PROPOSAL>",
}

# Tasks whose result is rendered as an annotated image.
IMAGE_TASKS = [
    "Object Detection",
    "Dense Region Caption",
    "Region Proposal",
    "OCR with Region",
]

# Every other task produces plain text. Deriving the list keeps the two
# groups in sync and preserves the dropdown order.
TEXT_TASKS = [name for name in TASK_PROMPTS if name not in IMAGE_TASKS]

# Outline colours picked at random for OCR region boxes.
colormap = [
    'blue', 'orange', 'green', 'purple', 'brown', 'pink', 'gray', 'olive',
    'cyan', 'red', 'lime', 'indigo', 'violet', 'aqua', 'magenta', 'coral',
    'gold', 'tan', 'skyblue',
]
def fig_to_pil(fig, close=True):
    """Render a matplotlib figure to an in-memory PNG and return it as a PIL image.

    Args:
        fig: The matplotlib figure to render.
        close: When True (the default), release the figure with
            ``plt.close`` once it has been rendered. pyplot keeps every
            figure it creates alive until it is explicitly closed, so a
            long-running app that never closes them accumulates memory.
            Pass False if the caller still needs the figure afterwards.

    Returns:
        A fully decoded ``PIL.Image.Image``.
    """
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    if close:
        plt.close(fig)
    buf.seek(0)
    image = Image.open(buf)
    # Image.open is lazy; decode now so the pixels are in memory before we return.
    image.load()
    return image
def plot_bbox(image, data):
    """Draw labelled bounding boxes over *image* on a new matplotlib figure.

    Args:
        image: The image to display (anything ``Axes.imshow`` accepts).
        data: One task's parsed result, with parallel ``'bboxes'`` entries of
            ``(x1, y1, x2, y2)`` corner coordinates and ``'labels'`` strings.

    Returns:
        The newly created (still open) matplotlib figure.
    """
    fig, ax = plt.subplots()
    ax.imshow(image)
    for bbox, label in zip(data['bboxes'], data['labels']):
        x1, y1, x2, y2 = bbox
        rect = patches.Rectangle(
            (x1, y1), x2 - x1, y2 - y1,
            linewidth=1, edgecolor='r', facecolor='none',
        )
        ax.add_patch(rect)
        # Some tasks (presumably Region Proposal) yield empty-string labels;
        # skip the tag rather than drawing an empty red box at each corner.
        if label:
            # Draw on this figure's own axes. plt.text targets pyplot's global
            # "current axes", which can belong to a different figure if another
            # request creates one in the meantime.
            ax.text(
                x1, y1, label,
                color='white', fontsize=8,
                bbox=dict(facecolor='red', alpha=0.5),
            )
    ax.axis('off')
    return fig
def draw_ocr_bboxes(image, prediction, scale=1):
    """Outline each OCR region on *image* and write its recognised text beside it.

    *image* is drawn on in place and returned. ``main_process`` passes a copy
    so the uploaded image is left untouched.

    Args:
        image: PIL image to annotate.
        prediction: Parsed OCR-with-region result with parallel
            ``'quad_boxes'`` and ``'labels'`` lists. Each box is a flat
            coordinate sequence usable by ``ImageDraw.polygon``; its first
            (x, y) pair anchors the label.
        scale: Factor applied to every box coordinate before drawing, e.g.
            when *image* was resized relative to the one the model saw.
            Defaults to 1, meaning no scaling.

    Returns:
        The same *image* object, annotated.
    """
    draw = ImageDraw.Draw(image)
    for box, label in zip(prediction['quad_boxes'], prediction['labels']):
        color = random.choice(colormap)
        scaled_box = (np.array(box) * scale).tolist()
        draw.polygon(scaled_box, width=3, outline=color)
        # Offset the text slightly inside the first corner of the polygon.
        draw.text(
            (scaled_box[0] + 8, scaled_box[1] + 2),
            str(label),
            align="right",
            fill=color,
        )
    return image
def process_image(image, task):
    """Run the Florence model on *image* for the named *task* and return the parsed output.

    Args:
        image: PIL image to analyse (its width/height are passed to the
            post-processor).
        task: A key of ``TASK_PROMPTS``; the mapped token is the text prompt.

    Returns:
        Whatever ``processor.post_process_generation`` returns. Callers index
        it with the task token (see ``main_process``).
    """
    task_token = TASK_PROMPTS[task]
    model_inputs = processor(text=task_token, images=image, return_tensors="pt")
    model_inputs = model_inputs.to(device, torch_dtype)

    # Deterministic beam search with a generous token budget.
    output_ids = model.generate(
        **model_inputs,
        max_new_tokens=1024,
        num_beams=3,
        do_sample=False,
    )

    # Special tokens are kept when decoding (skip_special_tokens=False),
    # presumably because the post-processor parses them; confirm before changing.
    decoded = processor.batch_decode(output_ids, skip_special_tokens=False)
    first_sequence = decoded[0]
    return processor.post_process_generation(
        first_sequence,
        task=task_token,
        image_size=(image.width, image.height),
    )
def main_process(image, task):
    """Run *task* on *image* and build the values for the four Gradio outputs.

    The returned 4-tuple lines up with
    ``[output_image, output_image, output_text, output_text]``:
    (image value, image visibility update, text value, text visibility update).
    Tasks in ``IMAGE_TASKS`` render an annotated image and hide the textbox.
    All other tasks show the model's text and hide the image pane.
    """
    prompt = TASK_PROMPTS[task]
    result = process_image(image, task)

    if task not in IMAGE_TASKS:
        # The parsed result is keyed by the task token, as the image branch
        # below relies on. Show just the answer, not the dict's repr
        # (e.g. "{'<CAPTION>': ...}"). Fall back to the whole result if the
        # key is absent.
        if isinstance(result, dict) and prompt in result:
            answer = result[prompt]
        else:
            answer = result
        return None, gr.update(visible=False), str(answer), gr.update(visible=True)

    prediction = result[prompt]
    if task == "OCR with Region":
        # Draw on a copy so the user's uploaded image is not modified.
        output_image = draw_ocr_bboxes(image.copy(), prediction)
    else:
        fig = plot_bbox(image, prediction)
        output_image = fig_to_pil(fig)
    return output_image, gr.update(visible=True), None, gr.update(visible=False)
def reset_outputs():
    """Clear both outputs: hide the image pane and show an empty textbox.

    Returns values for ``[output_image, output_image, output_text, output_text]``.
    """
    hide = gr.update(visible=False)
    show = gr.update(visible=True)
    cleared_image = None
    cleared_text = None
    return (cleared_image, hide, cleared_text, show)
with gr.Blocks(title="Florence-2 Demo") as iface:
gr.Markdown("# Florence-2 Demo")
gr.Markdown("Upload an image and select a task to process with Florence-2.")
with gr.Column():
image_input = gr.Image(type="pil", label="Input Image")
task_dropdown = gr.Dropdown(list(TASK_PROMPTS.keys()), label="Task", value="Caption")
with gr.Row():
submit_button = gr.Button("Process")
reset_button = gr.Button("Reset")
output_image = gr.Image(label="Processed Image", visible=False)
output_text = gr.Textbox(label="Output", visible=True)
def process_and_update(image, task):
if image is None:
return None, gr.update(visible=False), "Please upload an image first.", gr.update(visible=True)
return main_process(image, task)
submit_button.click(
fn=process_and_update,
inputs=[image_input, task_dropdown],
outputs=[output_image, output_image, output_text, output_text]
)
reset_button.click(
fn=reset_outputs,
inputs=[],
outputs=[output_image, output_image, output_text, output_text]
)
task_dropdown.change(
fn=lambda task: (gr.update(visible=task in IMAGE_TASKS), gr.update(visible=task in TEXT_TASKS)),
inputs=[task_dropdown],
outputs=[output_image, output_text]
)
iface.launch()