SZhanZ committed
Commit 4b25987 · 1 Parent(s): 6c162e9

init commit

Files changed (6)
  1. .gitattributes +1 -0
  2. README.md +0 -0
  3. app.py +95 -55
  4. examples/image1.jpg +3 -0
  5. examples/image2.jpg +0 -0
  6. requirements.txt +5 -1
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ examples/image1.jpg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
File without changes
app.py CHANGED
@@ -1,64 +1,104 @@
- import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
-
-
+ import re
+ import torch
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
+ from PIL import Image, ImageDraw
+
+ def draw_bbox(image, bbox):
+     x1, y1, x2, y2 = bbox
+     draw = ImageDraw.Draw(image)
+     draw.rectangle((x1, y1, x2, y2), outline="red", width=5)
+     return image
+
+ def extract_bbox_answer(content):
+     bbox_pattern = r'\{.*\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)]\s*.*\}'
+     bbox_match = re.search(bbox_pattern, content)
+     if bbox_match:
+         bbox = [int(bbox_match.group(1)), int(bbox_match.group(2)), int(bbox_match.group(3)), int(bbox_match.group(4))]
+         return bbox
+     return [0, 0, 0, 0]
+
+ def process_image_and_text(image, text):
+     """Process image and text input, return thinking process and bbox"""
+     question = f"Please provide the bounding box coordinate of the region this sentence describes: {text}."
+     QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format."
+
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image"},
+                 {"type": "text", "text": QUESTION_TEMPLATE.format(Question=question)},
+             ],
+         }
+     ]
+
+     text = processor.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )
+
+     inputs = processor(
+         text=[text],
+         images=image,
+         return_tensors="pt",
+         padding=True,
+         padding_side="left",
+         add_special_tokens=False,
+     )
+
+     # inputs = inputs
+
+     with torch.no_grad():
+         generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
+         generated_ids_trimmed = [
+             out_ids[len(inputs.input_ids[0]):] for out_ids in generated_ids
+         ]
+
+         output_text = processor.batch_decode(
+             generated_ids_trimmed, skip_special_tokens=True
+         )[0]
+         print("output_text: ", output_text)
+
+     # Extract thinking process
+     think_match = re.search(r'<think>(.*?)</think>', output_text, re.DOTALL)
+     thinking_process = think_match.group(1).strip() if think_match else "No thinking process found"
+
+     # Get bbox and draw
+     bbox = extract_bbox_answer(output_text)
+
+     # Draw bbox on the image
+     result_image = image.copy()
+     result_image = draw_bbox(result_image, bbox)
+
+     return thinking_process, result_image
+
  if __name__ == "__main__":
-     demo.launch()
+     import gradio as gr
+
+     # model_path = "/data/shz/project/vlm-r1/VLM-R1/output/Qwen2.5-VL-3B-GRPO-REC/checkpoint-500"
+     model_path = "SZhanZ/Qwen2.5VL-VLM-R1-REC-step500"
+     model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path)
+     processor = AutoProcessor.from_pretrained(model_path)
+
+     def gradio_interface(image, text):
+         thinking, result_image = process_image_and_text(image, text)
+         return thinking, result_image
+
+     demo = gr.Interface(
+         fn=gradio_interface,
+         inputs=[
+             gr.Image(type="pil", label="Input Image"),
+             gr.Textbox(label="Description Text")
+         ],
+         outputs=[
+             gr.Textbox(label="Thinking Process"),
+             gr.Image(type="pil", label="Result with Bbox")
+         ],
+         title="Visual Referring Expression Demo",
+         description="Upload an image and input description text, the system will return the thinking process and region annotation",
+         examples=[
+             ["examples/image1.jpg", "food with the highest protein"],
+             ["examples/image2.jpg", "the cheapest laptop"],
+         ]
+     )
+
+     demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
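For reference, the bbox parsing added above can be checked on its own. The snippet below is a minimal sketch, not part of the commit: it copies `extract_bbox_answer` from app.py and runs it on a hypothetical completion whose exact wording is only assumed from the prompt template and the regex.

    import re

    def extract_bbox_answer(content):
        # Same regex as app.py: first [x1, y1, x2, y2] list found inside a {...} block.
        bbox_pattern = r'\{.*\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)]\s*.*\}'
        bbox_match = re.search(bbox_pattern, content)
        if bbox_match:
            return [int(bbox_match.group(i)) for i in range(1, 5)]
        return [0, 0, 0, 0]

    # Hypothetical model output in the <think>...</think><answer>{...}</answer> format the prompt asks for.
    sample = '<think>The sentence refers to the plate of eggs.</think><answer>{"bbox": [120, 45, 380, 260]}</answer>'
    print(extract_bbox_answer(sample))  # [120, 45, 380, 260]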
examples/image1.jpg ADDED

Git LFS Details

  • SHA256: e779913142b5db662be50e6e5e8d9b598913dc3a1c2c27abfbbd1dd44630cdd9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.24 MB
examples/image2.jpg ADDED
requirements.txt CHANGED
@@ -1 +1,5 @@
- huggingface_hub==0.25.2
+ torch>=2.0.0
+ git+https://github.com/huggingface/transformers
+ Pillow>=10.0.0
+ httpx[socks]
+ accelerate>=0.26.0
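The dependency swap mirrors the new app.py imports; `transformers` is pulled from the GitHub main branch, presumably because `Qwen2_5_VLForConditionalGeneration` was not yet available in a tagged release at the time of this commit. A minimal post-install sanity check (a sketch, not part of the Space):

    # Assumes the requirements above are installed; verifies the classes app.py needs are importable.
    import torch
    from PIL import Image, ImageDraw
    from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

    print("torch version:", torch.__version__)
    print("Qwen2.5-VL classes importable")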