Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from PIL import Image, ImageDraw | |
from transformers import AutoProcessor | |
from modeling_florence2 import Florence2ForConditionalGeneration | |
import io | |
import matplotlib.pyplot as plt | |
import matplotlib.patches as patches | |
import numpy as np | |
import random | |
# Pick the execution target once: CUDA with half precision when a GPU is
# available, otherwise CPU with single precision.
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

# Load the checkpoint weights in the selected dtype and move them to the
# selected device. This happens at import time.
model = Florence2ForConditionalGeneration.from_pretrained(
    "PleIAs/Florence-PDF",
    torch_dtype=torch_dtype,
    trust_remote_code=True,
)
model = model.to(device)

# The processor is loaded from the same checkpoint; the rest of the file uses
# it to build model inputs, decode generated ids and post-process the output.
processor = AutoProcessor.from_pretrained(
    "PleIAs/Florence-PDF",
    trust_remote_code=True,
)
# Human-readable task name (shown in the dropdown) -> prompt token sent to the
# model for that task.
TASK_PROMPTS = {
    "Caption": "<CAPTION>",
    "Detailed Caption": "<DETAILED_CAPTION>",
    "More Detailed Caption": "<MORE_DETAILED_CAPTION>",
    "Object Detection": "<OD>",
    "Dense Region Caption": "<DENSE_REGION_CAPTION>",
    "OCR": "<OCR>",
    "OCR with Region": "<OCR_WITH_REGION>",
    "Region Proposal": "<REGION_PROPOSAL>",
}

# Tasks whose result is rendered as an annotated image.
IMAGE_TASKS = [
    "Object Detection",
    "Dense Region Caption",
    "Region Proposal",
    "OCR with Region",
]

# Tasks whose result is shown as text.
TEXT_TASKS = [
    "Caption",
    "Detailed Caption",
    "More Detailed Caption",
    "OCR",
]
# Outline/text colours; draw_ocr_bboxes picks one at random per region.
colormap = [
    "blue", "orange", "green", "purple", "brown",
    "pink", "gray", "olive", "cyan", "red",
    "lime", "indigo", "violet", "aqua", "magenta",
    "coral", "gold", "tan", "skyblue",
]
def fig_to_pil(fig):
    """Render a figure to an in-memory PNG and open it as a PIL image.

    Args:
        fig: Object exposing ``savefig(file, format=...)``; in this file it
            is the Matplotlib figure returned by ``plot_bbox``.

    Returns:
        The image produced by ``Image.open`` on the in-memory PNG buffer.
    """
    png_buffer = io.BytesIO()
    fig.savefig(png_buffer, format='png')
    # Rewind so that Image.open starts reading at the beginning of the PNG.
    png_buffer.seek(0)
    rendered = Image.open(png_buffer)
    return rendered
def plot_bbox(image, data):
    """Draw labelled bounding boxes over an image on a new Matplotlib figure.

    Args:
        image: Image handed to ``Axes.imshow``.
        data: Mapping with a ``'bboxes'`` entry (iterable of
            ``(x1, y1, x2, y2)`` boxes) and a ``'labels'`` entry (one label
            per box). Boxes and labels are paired with ``zip``, so extra
            items in the longer sequence are ignored.

    Returns:
        The newly created figure. It is not closed here; the caller owns it.
    """
    fig, ax = plt.subplots()
    ax.imshow(image)
    for (x1, y1, x2, y2), label in zip(data['bboxes'], data['labels']):
        outline = patches.Rectangle(
            (x1, y1),
            x2 - x1,
            y2 - y1,
            linewidth=1,
            edgecolor='r',
            facecolor='none',
        )
        ax.add_patch(outline)
        # Annotate through this Axes rather than plt.text: pyplot resolves the
        # "current axes" from global state, which may belong to a different
        # figure when several requests are being rendered at the same time.
        ax.text(x1, y1, label, color='white', fontsize=8, bbox=dict(facecolor='red', alpha=0.5))
    ax.axis('off')
    return fig
def draw_ocr_bboxes(image, prediction):
    """Outline each OCR region on the image and write its label next to it.

    The drawing is done directly on ``image`` (it is modified in place), and
    the same object is returned.

    Args:
        image: PIL image to draw on.
        prediction: Mapping with ``'quad_boxes'`` (one flat coordinate
            sequence per region; the first two values are used as the x/y
            anchor for the label) and ``'labels'`` (one label per region).

    Returns:
        ``image``, after drawing.
    """
    scale = 1  # coordinates are multiplied by this factor before drawing
    canvas = ImageDraw.Draw(image)
    quads = prediction['quad_boxes']
    texts = prediction['labels']
    for quad, text in zip(quads, texts):
        # A random colour per region, shared by its outline and its label.
        region_color = random.choice(colormap)
        points = (np.array(quad) * scale).tolist()
        canvas.polygon(points, width=3, outline=region_color)
        # Offset the label slightly from the first vertex of the region.
        anchor = (points[0] + 8, points[1] + 2)
        canvas.text(anchor, f"{text}", align="right", fill=region_color)
    return image
def process_image(image, task):
    """Run the model on one image for the chosen task and parse its output.

    Args:
        image: Input image; its ``width`` and ``height`` attributes are
            passed to the post-processing step as the image size.
        task: Key of ``TASK_PROMPTS`` selecting the prompt token.

    Returns:
        The value returned by ``processor.post_process_generation``.
        ``main_process`` indexes it by the prompt token.

    Raises:
        KeyError: If ``task`` is not a key of ``TASK_PROMPTS``.
    """
    prompt = TASK_PROMPTS[task]

    # Build the model inputs and move them to the model's device and dtype.
    model_inputs = processor(text=prompt, images=image, return_tensors="pt")
    model_inputs = model_inputs.to(device, torch_dtype)

    # Deterministic beam search (no sampling), capped at 1024 new tokens.
    token_ids = model.generate(
        **model_inputs,
        max_new_tokens=1024,
        num_beams=3,
        do_sample=False,
    )

    # Special tokens are kept in the decoded text that is handed to the
    # post-processing step.
    decoded = processor.batch_decode(token_ids, skip_special_tokens=False)
    first_sequence = decoded[0]
    return processor.post_process_generation(
        first_sequence,
        task=prompt,
        image_size=(image.width, image.height),
    )
def main_process(image, task):
    """Process an image for a task and build the values for the UI outputs.

    Args:
        image: Input PIL image.
        task: Key of ``TASK_PROMPTS``.

    Returns:
        A 4-tuple ``(image_value, image_update, text_value, text_update)``:
        for tasks in ``IMAGE_TASKS`` the annotated image is returned with the
        image output shown and the text output hidden; for every other task
        ``str(result)`` is returned with the text output shown and the image
        output hidden.
    """
    result = process_image(image, task)

    if task not in IMAGE_TASKS:
        return None, gr.update(visible=False), str(result), gr.update(visible=True)

    prediction = result[TASK_PROMPTS[task]]
    if task == "OCR with Region":
        # draw_ocr_bboxes draws in place, so give it a copy of the input.
        output_image = draw_ocr_bboxes(image.copy(), prediction)
    else:
        fig = plot_bbox(image, prediction)
        try:
            output_image = fig_to_pil(fig)
        finally:
            # pyplot keeps every figure created via plt.subplots() alive until
            # it is closed explicitly; without this, each request would leak
            # one figure for the lifetime of the process.
            plt.close(fig)
    return output_image, gr.update(visible=True), None, gr.update(visible=False)
def reset_outputs():
    """Return the values that restore the outputs to their initial state.

    Returns:
        A 4-tuple ``(image_value, image_update, text_value, text_update)``
        that clears both values, hides the image output and shows the text
        output.
    """
    hide_image = gr.update(visible=False)
    show_text = gr.update(visible=True)
    return None, hide_image, None, show_text
# UI definition and event wiring.
# NOTE(review): the nesting of the Column/Row containers below follows the
# conventional layout for this kind of demo (inputs, button row and outputs
# inside one column) - confirm against the deployed app.
with gr.Blocks(title="Florence-2 Demo") as iface:
    gr.Markdown("# Florence-2 Demo")
    gr.Markdown("Upload an image and select a task to process with Florence-2.")

    with gr.Column():
        image_input = gr.Image(type="pil", label="Input Image")
        task_dropdown = gr.Dropdown(list(TASK_PROMPTS), label="Task", value="Caption")

        with gr.Row():
            submit_button = gr.Button("Process")
            reset_button = gr.Button("Reset")

        # The image output starts hidden; the text output starts visible.
        output_image = gr.Image(label="Processed Image", visible=False)
        output_text = gr.Textbox(label="Output", visible=True)

    def process_and_update(image, task):
        """Run the selected task, or ask for an image if none was uploaded."""
        if image is not None:
            return main_process(image, task)
        return None, gr.update(visible=False), "Please upload an image first.", gr.update(visible=True)

    def _toggle_outputs(task):
        """Show the image output for image tasks and the text output for text tasks."""
        show_image = task in IMAGE_TASKS
        show_text = task in TEXT_TASKS
        return gr.update(visible=show_image), gr.update(visible=show_text)

    # Each component is listed twice in `outputs`: the handlers return a value
    # and a visibility update for the image, then the same pair for the text.
    submit_button.click(
        fn=process_and_update,
        inputs=[image_input, task_dropdown],
        outputs=[output_image, output_image, output_text, output_text],
    )

    reset_button.click(
        fn=reset_outputs,
        inputs=[],
        outputs=[output_image, output_image, output_text, output_text],
    )

    # Switch which output is visible as soon as the task selection changes.
    task_dropdown.change(
        fn=_toggle_outputs,
        inputs=[task_dropdown],
        outputs=[output_image, output_text],
    )

iface.launch()