Ayush0804 committed (verified)
Commit 84c45b1
Parent(s): 4ab5238

Upload 2 files

Files changed (2):
  1. app.py +204 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,204 @@
+ import gradio as gr
+ import os
+ import tempfile
+ import secrets
+ from pathlib import Path
+
+ import torch
+ from PIL import Image
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     AutoProcessor,
+     Qwen2VLForConditionalGeneration,
+ )
+
+ # Load the OCR model once at import time; device_map="auto" places the
+ # weights on the available GPU(s).
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
+     "Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
+ )
+ processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
+
+ # Lazily-initialised Qwen2-Math model (see get_math_response) and the
+ # running chat history.
+ math_model = None
+ math_tokenizer = None
+ math_messages = []
+ def process_image(image, shouldConvert=False):
+     """OCR the image with Qwen2-VL and return a text description of its math content."""
+     global math_messages
+     math_messages = []  # reset the chat history whenever a new image is uploaded
+     uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str(
+         Path(tempfile.gettempdir()) / "gradio"
+     )
+     os.makedirs(uploaded_file_dir, exist_ok=True)
+
+     name = f"tmp{secrets.token_hex(20)}.jpg"
+     filename = os.path.join(uploaded_file_dir, name)
+     if shouldConvert:
+         # Flatten the RGBA sketchpad image onto a white background so it can
+         # be saved as JPEG.
+         new_img = Image.new('RGB', size=(image.width, image.height), color=(255, 255, 255))
+         new_img.paste(image, (0, 0), mask=image)
+         image = new_img
+     image.save(filename)
+
+     messages = [{
+         'role': 'system',
+         'content': [{'text': 'You are a helpful assistant.'}]
+     }, {
+         'role': 'user',
+         'content': [
+             {'image': f'file://{filename}'},
+             {'text': 'Please describe the math-related content in this image, ensuring that any LaTeX formulas are correctly transcribed. Non-mathematical details do not need to be described.'}
+         ]
+     }]
+
+     text_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+     inputs = processor(
+         text=[text_prompt],
+         images=[image],
+         padding=True,
+         return_tensors="pt",
+     ).to(model.device)  # move inputs onto the model's device before generating
+
+     output_ids = model.generate(**inputs, max_new_tokens=1024)
+     # Drop the prompt tokens so only the newly generated text is decoded.
+     generated_ids = [
+         out_ids[len(in_ids):]
+         for in_ids, out_ids in zip(inputs.input_ids, output_ids)
+     ]
+     output_text = processor.batch_decode(
+         generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
+     )[0]  # batch_decode returns a list; the batch size here is 1
+
+     os.remove(filename)
+     return output_text
+
+
+ def get_math_response(image_description, user_question):
+     """Answer a math question with Qwen2-Math, optionally grounded in an image description."""
+     global math_messages, math_model, math_tokenizer
+     # Load the 72B math model once on first use rather than on every request.
+     if math_model is None:
+         model_name = "Qwen/Qwen2-Math-72B-Instruct"
+         math_model = AutoModelForCausalLM.from_pretrained(
+             model_name,
+             torch_dtype="auto",
+             device_map="auto"
+         )
+         math_tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+     if not math_messages:
+         math_messages.append({'role': 'system', 'content': 'You are a helpful math assistant.'})
+     math_messages = math_messages[:1]  # keep only the system prompt
+     if image_description is not None:
+         content = f'Image description: {image_description}\n\n'
+     else:
+         content = ''
+     query = f"{content}User question: {user_question}"
+     math_messages.append({'role': 'user', 'content': query})
+
+     text = math_tokenizer.apply_chat_template(
+         math_messages,
+         tokenize=False,
+         add_generation_prompt=True
+     )
+     model_inputs = math_tokenizer([text], return_tensors="pt").to(math_model.device)
+     generated_ids = math_model.generate(
+         **model_inputs,
+         max_new_tokens=512
+     )
+     generated_ids = [
+         out_ids[len(in_ids):] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
+     ]
+     answer = math_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+     # Double the backslashes so LaTeX survives Markdown rendering, and yield
+     # so the Gradio handler can stream the result.
+     yield answer.replace("\\", "\\\\")
+     print(f'query: {query}\nanswer: {answer}')
+     math_messages.append({'role': 'assistant', 'content': answer})
+ def math_chat_bot(image, sketchpad, question, state):
+     """Route the active tab's image (if any) through OCR, then stream the math answer."""
+     current_tab_index = state["tab_index"]
+     image_description = None
+     # Upload tab
+     if current_tab_index == 0:
+         if image is not None:
+             image_description = process_image(image)
+     # Sketch tab
+     elif current_tab_index == 1:
+         if sketchpad and sketchpad["composite"]:
+             image_description = process_image(sketchpad["composite"], True)
+     yield from get_math_response(image_description, question)
+
+
+ # Render KaTeX display math inline inside the answer panel.
+ css = """
+ #qwen-md .katex-display { display: inline; }
+ #qwen-md .katex-display>.katex { display: inline; }
+ #qwen-md .katex-display>.katex>.katex-html { display: inline; }
+ """
+
+
+ def tabs_select(e: gr.SelectData, _state):
+     _state["tab_index"] = e.index
+
+
+ # Build the Gradio interface
+ with gr.Blocks(css=css) as demo:
+     gr.HTML(
+         """<p align="center"><img src="https://modelscope.oss-cn-beijing.aliyuncs.com/resource/qwen.png" style="height: 60px"/></p>"""
+         """<center><font size=8>📖 Qwen2-Math Demo</font></center>"""
+         """<center><font size=3>This WebUI is based on Qwen2-VL for OCR and Qwen2-Math for mathematical reasoning. You can input either images or text of mathematical or arithmetic problems.</font></center>"""
+     )
+     state = gr.State({"tab_index": 0})
+     with gr.Row():
+         with gr.Column():
+             with gr.Tabs() as input_tabs:
+                 with gr.Tab("Upload"):
+                     input_image = gr.Image(type="pil", label="Upload")
+                 with gr.Tab("Sketch"):
+                     input_sketchpad = gr.Sketchpad(type="pil", label="Sketch", layers=False)
+             input_tabs.select(fn=tabs_select, inputs=[state])
+             input_text = gr.Textbox(label="Input your question")
+             with gr.Row():
+                 with gr.Column():
+                     clear_btn = gr.ClearButton(
+                         [input_image, input_sketchpad, input_text])
+                 with gr.Column():
+                     submit_btn = gr.Button("Submit", variant="primary")
+         with gr.Column():
+             output_md = gr.Markdown(
+                 label="answer",
+                 latex_delimiters=[{
+                     "left": "\\(",
+                     "right": "\\)",
+                     "display": True
+                 }, {
+                     "left": "\\begin{equation}",
+                     "right": "\\end{equation}",
+                     "display": True
+                 }, {
+                     "left": "\\begin{align}",
+                     "right": "\\end{align}",
+                     "display": True
+                 }, {
+                     "left": "\\begin{alignat}",
+                     "right": "\\end{alignat}",
+                     "display": True
+                 }, {
+                     "left": "\\begin{gather}",
+                     "right": "\\end{gather}",
+                     "display": True
+                 }, {
+                     "left": "\\begin{CD}",
+                     "right": "\\end{CD}",
+                     "display": True
+                 }, {
+                     "left": "\\[",
+                     "right": "\\]",
+                     "display": True
+                 }],
+                 elem_id="qwen-md")
+     submit_btn.click(
+         fn=math_chat_bot,
+         inputs=[input_image, input_sketchpad, input_text, state],
+         outputs=output_md)
+
+ if __name__ == "__main__":
+     demo.launch()
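
For reference, a minimal sketch of how the two stages can be exercised without the UI. It assumes app.py is importable, enough GPU memory for both checkpoints, and a hypothetical local image file problem.jpg; none of this is part of the commit itself.

# Hypothetical smoke test; "problem.jpg" is an assumed local file.
from PIL import Image

import app  # loads the Qwen2-VL checkpoint at import time

description = app.process_image(Image.open("problem.jpg").convert("RGB"))
for partial in app.get_math_response(description, "Solve the problem shown in the image."):
    print(partial)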
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ gradio
+ numpy==1.24.4
+ Pillow==10.3.0
+ Requests==2.31.0
+ torch
+ torchvision
+ transformers>=4.45.0  # Qwen2VLForConditionalGeneration requires Qwen2-VL support, added in 4.45
+ accelerate==0.30.0
+ qwen-vl-utils
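
To run the demo locally, the dependencies above can be installed with pip install -r requirements.txt, after which python app.py starts the Gradio server.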