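"""Gradio demo for Florence-2 fine-tuned on medieval manuscript zone detection.

Runs the <OD> task on an uploaded image (or example URL) and displays the
predicted zones as labelled bounding boxes.
"""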
import gradio as gr
from transformers import AutoConfig, AutoProcessor, AutoModelForCausalLM
import spaces
from PIL import Image 
import subprocess
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import requests
from io import BytesIO
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
import os

# Install flash-attn at runtime; skipping the CUDA build avoids compiling kernels on the build machine
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

# Zone-detection checkpoint; a line-detection variant is available at
# 'medieval-data/florence2-medieval-bbox-line-detection'.
model_dir = "medieval-data/florence2-medieval-bbox-zone-detection"

def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Workaround for https://huggingface.co./microsoft/phi-1_5/discussions/72."""
    if not str(filename).endswith("/modeling_florence2.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    # Drop flash_attn so the remote modeling code can be imported without flash-attn installed
    if "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports
    
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    # Load the configuration
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    
    # Modify the vision configuration
    if hasattr(config, 'vision_config'):
        config.vision_config.model_type = 'davit'
    
    print("Modified vision configuration:")
    print(config.vision_config)
    
    # Try to load the model with the modified configuration
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            config=config,
            trust_remote_code=True
        )
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Failed to load model: {str(e)}")

    # Load the processor without specifying a revision
    try:
        processor = AutoProcessor.from_pretrained(
            model_dir, 
            trust_remote_code=True
        )
        print("Processor loaded successfully!")
    except Exception as e:
        print(f"Failed to load processor: {str(e)}")
TITLE = "# [Florence-2 Medieval Zone Detection Demo](https://huggingface.co./medieval-data/florence2-medieval-bbox-zone-detection)"
DESCRIPTION = "A demo for Florence-2 fine-tuned for bounding-box zone detection on medieval manuscripts. You can find the fine-tuning notebook [here](https://colab.research.google.com/drive/1hKDrJ5AH_o7I95PtZ9__VlCTNAo1Gjpf?usp=sharing). Read more about Florence-2 fine-tuning [here](https://huggingface.co./blog/finetune-florence2)."

# Define a color map for different labels
colormap = plt.get_cmap('tab20')  # plt.cm.get_cmap is deprecated in recent Matplotlib releases

@spaces.GPU
def process_image(image, text_input=None):
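    """Resize the image to fit within max_size, run the Florence-2 <OD> task, and return the parsed result and resized image."""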
    max_size = 1000
    prompt = "<OD>"

    # Calculate the scaling factor
    original_width, original_height = image.size
    scale = min(max_size / original_width, max_size / original_height)
    new_width = int(original_width * scale)
    new_height = int(original_height * scale)

    # Resize the image
    image = image.resize((new_width, new_height))

    # Preprocess the prompt and image into model inputs
    inputs = processor(text=prompt, images=image, return_tensors="pt")

    # Run beam-search generation for the detection task
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        do_sample=False,
        num_beams=3
    )

    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    result = processor.post_process_generation(generated_text, task="<OD>", image_size=(image.width, image.height))
    
    return result, image

def visualize_bboxes(result, image):
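    """Plot the image with its predicted bounding boxes and labels; return the Matplotlib figure."""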
    fig, ax = plt.subplots(1, figsize=(15, 15))
    ax.imshow(image)

    # Create a set of unique labels
    unique_labels = set(result['<OD>']['labels'])

    # Create a dictionary to map labels to colors
    color_dict = {label: colormap(i/len(unique_labels)) for i, label in enumerate(unique_labels)}

    # Add bounding boxes and labels to the plot
    for bbox, label in zip(result['<OD>']['bboxes'], result['<OD>']['labels']):
        x, y, width, height = bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]
        rect = patches.Rectangle((x, y), width, height, linewidth=2, edgecolor=color_dict[label], facecolor='none')
        ax.add_patch(rect)
        plt.text(x, y, label, fontsize=12, bbox=dict(facecolor=color_dict[label], alpha=0.5))

    plt.axis('off')
    return fig

def run_example(image, text_input=None):
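    """Accept an image as a URL, numpy array, or PIL image; return the annotated detection result as a PIL image."""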
    if isinstance(image, str):  # If image is a URL
        response = requests.get(image)
        image = Image.open(BytesIO(response.content))
    elif isinstance(image, np.ndarray):  # If image is a numpy array
        image = Image.fromarray(image)
    
    result, processed_image = process_image(image, text_input)
    fig = visualize_bboxes(result, processed_image)
    
    # Convert matplotlib figure to image
    img_buf = BytesIO()
    fig.savefig(img_buf, format='png')
    plt.close(fig)  # close the figure so repeated calls don't leak memory
    img_buf.seek(0)
    output_image = Image.open(img_buf)
    
    return output_image

css = """
  #output {
    height: 500px; 
    overflow: auto; 
    border: 1px solid #ccc; 
  }
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(TITLE)
    gr.Markdown(DESCRIPTION)
    with gr.Tab(label="Florence-2 Image Processing"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Picture")
                text_input = gr.Textbox(label="Text Input (optional)")
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                output_img = gr.Image(label="Output Image with Bounding Boxes", elem_id="output")
        gr.Examples(
            examples=[
                ["https://huggingface.co./datasets/CATMuS/medieval-segmentation/resolve/main/data/dev/london-british-library-egerton-821/page-002-of-004.jpg", None],
                ["https://huggingface.co./datasets/CATMuS/medieval-segmentation/resolve/main/data/dev/paris-bnf-lat-12449/page-002-of-003.jpg", None],
                ["https://huggingface.co./datasets/CATMuS/medieval-segmentation/resolve/main/data/dev/paris-bnf-nal-1909/page-009-of-012.jpg", None],
                ["https://huggingface.co./datasets/CATMuS/medieval-segmentation/resolve/main/data/test/paris-bnf-fr-574/page-001-of-003.jpg", None]
            ],
            inputs=[input_img, text_input],
            outputs=[output_img],
            fn=run_example,
            cache_examples=True,
            label='Try the examples below'
        )
        submit_btn.click(run_example, [input_img, text_input], [output_img])

demo.launch(debug=True)