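"""Gradio demo for referring-expression grounding with a VLM-R1 Qwen2.5-VL checkpoint.

Given an image and a natural-language description, the model emits its reasoning in
<think> tags and a bounding box in <answer> tags; the box is parsed, drawn on the
image, and returned alongside the reasoning.
"""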
import re
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from PIL import Image, ImageDraw

def draw_bbox(image, bbox):
    """Draw a red [x1, y1, x2, y2] bounding box on the image (in place) and return it."""
    x1, y1, x2, y2 = bbox
    draw = ImageDraw.Draw(image)
    draw.rectangle((x1, y1, x2, y2), outline="red", width=5)
    return image

def extract_bbox_answer(content):
    """Parse the first four-integer [x1, y1, x2, y2] box found inside a JSON-like {...} block of the model output."""
    bbox_pattern = r'\{.*\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)]\s*.*\}'
    bbox_match = re.search(bbox_pattern, content)
    if bbox_match:
        bbox = [int(bbox_match.group(1)), int(bbox_match.group(2)), int(bbox_match.group(3)), int(bbox_match.group(4))]
        return bbox
    # Fall back to a degenerate box when the answer cannot be parsed
    return [0, 0, 0, 0]
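
# Note: the regex above expects the model's <answer> block to contain a JSON-like
# object holding a four-integer list, e.g. {"bbox": [12, 34, 256, 300]}
# (illustrative only; the exact key name is model-dependent).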

def process_image_and_text(image, text):
    """Process image and text input, return thinking process and bbox"""
    question = f"Please provide the bounding box coordinate of the region this sentence describes: {text}."
    QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format."
    
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": QUESTION_TEMPLATE.format(Question=question)},
            ],
        }
    ]
    
    # Render the chat template into the prompt string (kept separate from the `text` argument)
    chat_text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    inputs = processor(
        text=[chat_text],
        images=image,
        return_tensors="pt",
        padding=True,
        padding_side="left",
        add_special_tokens=False,
    )

    inputs = inputs.to("cuda")

    with torch.no_grad():
        generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
        # Keep only the newly generated tokens (drop the prompt portion)
        generated_ids_trimmed = [
            out_ids[len(inputs.input_ids[0]):] for out_ids in generated_ids
        ]
    
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True
    )[0]
    print("output_text: ", output_text)

    # Extract thinking process
    think_match = re.search(r'<think>(.*?)</think>', output_text, re.DOTALL)
    thinking_process = think_match.group(1).strip() if think_match else "No thinking process found"
    
    # Get bbox and draw
    bbox = extract_bbox_answer(output_text)
    
    # Draw bbox on the image
    result_image = image.copy()
    result_image = draw_bbox(result_image, bbox)
    
    return thinking_process, result_image
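
# Example usage without the Gradio UI (a sketch; assumes the model and processor
# defined in the __main__ block below have been loaded, and that a local image
# path such as "examples/image1.jpg" exists):
#   img = Image.open("examples/image1.jpg").convert("RGB")
#   thinking, boxed = process_image_and_text(img, "person with blue shirt")
#   boxed.save("result_with_bbox.jpg")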

if __name__ == "__main__":
    import gradio as gr
    
    # model_path = "/data/shz/project/vlm-r1/VLM-R1/output/Qwen2.5-VL-3B-GRPO-REC/checkpoint-500"
    model_path = "SZhanZ/Qwen2.5VL-VLM-R1-REC-step500"
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16, device_map="cuda")
    processor = AutoProcessor.from_pretrained(model_path)
    
    def gradio_interface(image, text):
        thinking, result_image = process_image_and_text(image, text)
        return thinking, result_image
    
    demo = gr.Interface(
        fn=gradio_interface,
        inputs=[
            gr.Image(type="pil", label="Input Image"),
            gr.Textbox(label="Description Text")
        ],
        outputs=[
            gr.Textbox(label="Thinking Process"),
            gr.Image(type="pil", label="Result with Bbox")
        ],
        title="Visual Referring Expression Demo",
        description="Upload an image and input description text, the system will return the thinking process and region annotation. \n\nOur GitHub: [VLM-R1](https://github.com/om-ai-lab/VLM-R1/tree/main)",
        examples=[
            ["examples/image1.jpg", "person with blue shirt"],
            ["examples/image2.jpg", "food with the highest protein"],
            ["examples/image3.jpg", "the cheapest Apple laptop"],
        ],
        cache_examples=False,
        examples_per_page=10
    )
    
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)