# NOTE: removed Hugging Face Spaces page-scrape artifacts (UI labels, file
# size, git blame hashes, and a line-number gutter) that are not part of the
# program and would be syntax errors in a .py file.
import gradio as gr
from transformers import AutoConfig, AutoProcessor, AutoModelForCausalLM
import spaces
from PIL import Image
import subprocess
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import requests
from io import BytesIO
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
import os
# Build flash-attn at runtime (common HF Spaces pattern) while skipping its
# CUDA build step. Extend os.environ instead of replacing it wholesale so the
# shell keeps PATH and the rest of the environment.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)
# The line-detection checkpoint
# ('medieval-data/florence2-medieval-bbox-line-detection') was previously
# assigned here and immediately overwritten — keep only the checkpoint in use.
model_dir = "medieval-data/florence2-medieval-bbox-zone-detection"
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Work around for https://huggingface.co./microsoft/phi-1_5/discussions/72.

    Strips the optional ``flash_attn`` entry from the import list of
    Florence-2's remote modeling file so the model can be loaded on machines
    where flash-attn is not installed.

    Args:
        filename: Path of the dynamic module being scanned by transformers.

    Returns:
        The imports reported by ``get_imports``, minus ``flash_attn`` when
        the file is ``modeling_florence2.py``.
    """
    if not str(filename).endswith("/modeling_florence2.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    # Fix: list.remove raises ValueError if the entry is absent (e.g. a model
    # revision that already dropped the flash_attn import) — guard it.
    if "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports
# Patch transformers' dynamic-module import scanner for the duration of the
# downloads so the (possibly missing) flash_attn import is tolerated. The
# from_pretrained calls MUST stay inside this context for the patch to apply.
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    # Load the configuration
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    # NOTE(review): forces the vision backbone type to 'davit'; presumably the
    # uploaded config mislabels or omits it — confirm against the model repo.
    if hasattr(config, 'vision_config'):
        config.vision_config.model_type = 'davit'
        print("Modified vision configuration:")
        print(config.vision_config)
    # Try to load the model with the modified configuration.
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            config=config,
            trust_remote_code=True
        )
        print("Model loaded successfully!")
    except Exception as e:
        # Best-effort: a failed load is only reported; any later reference to
        # `model` would then raise NameError.
        print(f"Failed to load model: {str(e)}")
    # Load the processor without specifying a revision.
    try:
        processor = AutoProcessor.from_pretrained(
            model_dir,
            trust_remote_code=True
        )
        print("Processor loaded successfully!")
    except Exception as e:
        # Same best-effort pattern as the model load above.
        print(f"Failed to load processor: {str(e)}")
TITLE = "# [Florence-2-DocVQA Demo](https://huggingface.co./HuggingFaceM4/Florence-2-DocVQA)"
DESCRIPTION = "The demo for Florence-2 fine-tuned on DocVQA dataset. You can find the notebook [here](https://colab.research.google.com/drive/1hKDrJ5AH_o7I95PtZ9__VlCTNAo1Gjpf?usp=sharing). Read more about Florence-2 fine-tuning [here](finetune-florence2)."
# Color map used to give each detection label a distinct color.
# Fix: plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# plt.get_cmap is the supported spelling and returns the same colormap.
colormap = plt.get_cmap('tab20')
@spaces.GPU
def process_image(image, text_input=None):
    """Run Florence-2 object detection ("<OD>" task) on a PIL image.

    The image is rescaled so its longer side is at most 1000 px (aspect
    ratio preserved), then fed through the model with beam search.
    ``text_input`` is accepted for interface parity but unused.

    Returns:
        A ``(parsed_result, resized_image)`` tuple, where ``parsed_result``
        is the processor's post-processed detection output.
    """
    max_size = 1000
    prompt = "<OD>"
    # Rescale so the larger dimension becomes max_size.
    orig_w, orig_h = image.size
    ratio = min(max_size / orig_w, max_size / orig_h)
    image = image.resize((int(orig_w * ratio), int(orig_h * ratio)))
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        do_sample=False,
        num_beams=3
    )
    decoded = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed = processor.post_process_generation(
        decoded, task="<OD>", image_size=(image.width, image.height)
    )
    return parsed, image
def visualize_bboxes(result, image):
    """Render the '<OD>' bounding boxes and labels from `result` over `image`.

    Returns the matplotlib figure; the caller is responsible for closing it.
    """
    detections = result['<OD>']
    fig, ax = plt.subplots(1, figsize=(15, 15))
    ax.imshow(image)
    # Assign one stable color per distinct label from the shared colormap.
    unique_labels = set(detections['labels'])
    label_colors = {
        lbl: colormap(idx / len(unique_labels))
        for idx, lbl in enumerate(unique_labels)
    }
    # Draw each box (x0, y0, x1, y1) plus a semi-transparent label tag.
    for (x0, y0, x1, y1), lbl in zip(detections['bboxes'], detections['labels']):
        ax.add_patch(patches.Rectangle(
            (x0, y0), x1 - x0, y1 - y0,
            linewidth=2, edgecolor=label_colors[lbl], facecolor='none'
        ))
        plt.text(x0, y0, lbl, fontsize=12,
                 bbox=dict(facecolor=label_colors[lbl], alpha=0.5))
    plt.axis('off')
    return fig
def run_example(image, text_input=None):
    """Gradio handler: normalize the input, run detection, render the boxes.

    Args:
        image: A URL string, a numpy array, or a PIL image.
        text_input: Unused; forwarded to ``process_image`` for parity.

    Returns:
        A PIL image of the annotated matplotlib figure.
    """
    if isinstance(image, str):  # If image is a URL
        response = requests.get(image)
        # Fail loudly on HTTP errors instead of trying to decode an error page.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
    elif isinstance(image, np.ndarray):  # If image is a numpy array
        image = Image.fromarray(image)
    result, processed_image = process_image(image, text_input)
    fig = visualize_bboxes(result, processed_image)
    # Convert the matplotlib figure to an in-memory PNG.
    img_buf = BytesIO()
    fig.savefig(img_buf, format='png')
    # Fix: close the figure — matplotlib keeps every open figure alive, so
    # repeated requests leaked memory (and trip its "more than 20 figures"
    # warning).
    plt.close(fig)
    img_buf.seek(0)
    output_image = Image.open(img_buf)
    return output_image
# Inline CSS: constrain the output pane so large figures scroll instead of
# stretching the page.
css = """
#output {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
}
"""
# Build the Gradio UI: one tab with the image + optional text inputs on the
# left, the annotated output image on the right, and cached example pages
# from the CATMuS medieval-segmentation dataset.
# Fix: removed a stray trailing '|' scrape artifact after demo.launch(...)
# that made the file a syntax error.
with gr.Blocks(css=css) as demo:
    gr.Markdown(TITLE)
    gr.Markdown(DESCRIPTION)
    with gr.Tab(label="Florence-2 Image Processing"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Picture")
                text_input = gr.Textbox(label="Text Input (optional)")
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                output_img = gr.Image(label="Output Image with Bounding Boxes")
        gr.Examples(
            examples=[
                ["https://huggingface.co./datasets/CATMuS/medieval-segmentation/resolve/main/data/dev/london-british-library-egerton-821/page-002-of-004.jpg", None],
                ["https://huggingface.co./datasets/CATMuS/medieval-segmentation/resolve/main/data/dev/paris-bnf-lat-12449/page-002-of-003.jpg", None],
                ["https://huggingface.co./datasets/CATMuS/medieval-segmentation/resolve/main/data/dev/paris-bnf-nal-1909/page-009-of-012.jpg", None],
                ["https://huggingface.co./datasets/CATMuS/medieval-segmentation/resolve/main/data/test/paris-bnf-fr-574/page-001-of-003.jpg", None]
            ],
            inputs=[input_img, text_input],
            outputs=[output_img],
            fn=run_example,
            cache_examples=True,
            label='Try the examples below'
        )
        submit_btn.click(run_example, [input_img, text_input], [output_img])

demo.launch(debug=True)