Rialbox committed on
Commit
48e94a3
·
verified ·
1 Parent(s): 02e80c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +160 -91
app.py CHANGED
@@ -1,128 +1,197 @@
1
  import gradio as gr
 
 
 
 
2
  from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoModelForCausalLM, AutoTokenizer
 
3
  import torch
4
 
5
- # Load the OCR model and processor
6
  ocr_model = Qwen2VLForConditionalGeneration.from_pretrained(
7
  "Qwen/Qwen2-VL-7B-Instruct",
8
  torch_dtype="auto",
9
  device_map="auto",
10
  )
11
-
12
  ocr_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
13
 
14
- # Load the Math model and tokenizer
15
  math_model = AutoModelForCausalLM.from_pretrained(
16
- "Qwen/Qwen2.5-Math-72B-Instruct",
17
  torch_dtype="auto",
18
- device_map="auto"
19
  )
 
20
 
21
- math_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Math-72B-Instruct")
22
 
23
- # OCR extraction function
24
- def ocr_and_query(image, question):
25
- # Prepare image for the model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  messages = [
27
  {
28
- "role": "user",
29
- "content": [
30
- {"type": "image"},
31
- {
32
- "type": "text",
33
- "text": question
34
- },
35
- ],
 
36
  }
37
  ]
38
 
39
- # Process image and text prompt
40
  text_prompt = ocr_processor.apply_chat_template(messages, add_generation_prompt=True)
41
  inputs = ocr_processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt")
42
-
43
- # Run the model to generate OCR results
44
- inputs = inputs.to("cuda")
45
  output_ids = ocr_model.generate(**inputs, max_new_tokens=1024)
 
46
 
47
- # Decode the generated text
48
- generated_ids = [
49
- output_ids[len(input_ids):]
50
- for input_ids, output_ids in zip(inputs.input_ids, output_ids)
51
- ]
52
- output_text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
53
-
54
  return output_text
55
 
56
- # Math problem solving function
57
- def solve_math_problem(prompt):
58
- # CoT (Chain of Thought)
59
- messages = [
60
- {"role": "system", "content": "Please reason step by step, and put your final answer within \\boxed{}."},
61
- {"role": "user", "content": prompt}
62
- ]
63
 
64
- text = math_tokenizer.apply_chat_template(
65
- messages,
66
- tokenize=False,
67
- add_generation_prompt=True
68
- )
69
- model_inputs = math_tokenizer([text], return_tensors="pt").to("cuda")
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
- generated_ids = math_model.generate(
72
- **model_inputs,
73
- max_new_tokens=512
 
 
74
  )
75
- generated_ids = [
76
- output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
77
- ]
78
 
79
- response = math_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
80
-
81
  return response
82
 
83
- # Function to clear inputs and output
84
- def clear_inputs():
85
- return None, "", ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- # Gradio interface setup
88
- def gradio_app(image, question, task):
89
- if task == "OCR and Query":
90
- return image, question, ocr_and_query(image, question)
91
- elif task == "Solve Math Problem from Image":
92
- if image is None:
93
- return image, question, "Please upload an image."
94
- extracted_text = ocr_and_query(image, "")
95
- math_solution = solve_math_problem(extracted_text)
96
- return image, extracted_text, math_solution
97
- elif task == "Solve Math Problem from Text":
98
- if question.strip() == "":
99
- return image, question, "Please enter a math problem."
100
- math_solution = solve_math_problem(question)
101
- return image, question, math_solution
102
- else:
103
- return image, question, "Please select a task."
104
 
105
- # Gradio interface
106
- with gr.Blocks() as app:
107
- gr.Markdown("# Image OCR and Math Solver")
108
- gr.Markdown("Upload an image, enter your question or math problem, and select the appropriate task.")
109
-
110
- with gr.Row():
111
- image_input = gr.Image(type="pil", label="Upload Image")
112
- text_input = gr.Textbox(lines=2, placeholder="Enter your question or math problem here...", label="Input")
113
-
114
- with gr.Row():
115
- task_radio = gr.Radio(["OCR and Query", "Solve Math Problem from Image", "Solve Math Problem from Text"], label="Task")
116
-
117
  with gr.Row():
118
- complete_button = gr.Button("Complete")
119
- clear_button = gr.Button("Clear")
120
-
121
- output = gr.Markdown(label="Output")
122
-
123
- # Event listeners
124
- complete_button.click(fn=gradio_app, inputs=[image_input, text_input, task_radio], outputs=[image_input, text_input, output])
125
- clear_button.click(fn=clear_inputs, outputs=[image_input, text_input, output])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
- # Launch the app
128
- app.launch(share=True)
 
1
  import gradio as gr
2
+ import os
3
+ import tempfile
4
+ from pathlib import Path
5
+ import secrets
6
  from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoModelForCausalLM, AutoTokenizer
7
+ from PIL import Image
8
  import torch
9
 
10
# Hugging Face Hub identifiers of the two checkpoints used by the app.
OCR_MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"
MATH_MODEL_ID = "Qwen/Qwen2.5-Math-7B-Instruct"

# Both models are loaded with the same dtype / device-placement options.
_MODEL_LOAD_KWARGS = {"torch_dtype": "auto", "device_map": "auto"}

# Vision-language model (image -> text description) and its processor.
ocr_model = Qwen2VLForConditionalGeneration.from_pretrained(
    OCR_MODEL_ID, **_MODEL_LOAD_KWARGS
)
ocr_processor = AutoProcessor.from_pretrained(OCR_MODEL_ID)

# Math reasoning model and its tokenizer.
math_model = AutoModelForCausalLM.from_pretrained(
    MATH_MODEL_ID, **_MODEL_LOAD_KWARGS
)
math_tokenizer = AutoTokenizer.from_pretrained(MATH_MODEL_ID)

# Module-level conversation buffer for the math model; the request handlers
# below read and rebind it through ``global math_messages``.
math_messages = []
26
 
27
def _flatten_to_rgb(image):
    """Return an RGB copy of *image* composited onto a white background."""
    rgba = image.convert('RGBA')
    canvas = Image.new('RGB', size=rgba.size, color=(255, 255, 255))
    # The alpha band of the RGBA copy acts as the paste mask, so transparent
    # pixels show the white canvas instead of whatever colour lies under them.
    canvas.paste(rgba, (0, 0), mask=rgba)
    return canvas


def process_image(image, should_convert=False):
    """
    Processes the uploaded image and extracts math-related content using Qwen2-VL.

    Args:
        image: PIL image to describe.
        should_convert: When True the image is always flattened onto a white
            RGB canvas (used for sketchpad composites). Images whose mode is
            not ``RGB`` are flattened regardless, because the temporary copy
            is written as JPEG, which cannot store alpha or palette data.

    Returns:
        The text generated by the OCR model, without the echoed prompt.
    """
    global math_messages
    math_messages = []  # Reset when uploading a new image

    if should_convert or image.mode != 'RGB':
        image = _flatten_to_rgb(image)

    uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str(
        Path(tempfile.gettempdir()) / "gradio"
    )
    os.makedirs(uploaded_file_dir, exist_ok=True)
    filename = os.path.join(uploaded_file_dir, f"tmp{secrets.token_hex(20)}.jpg")

    try:
        image.save(filename)

        # Prepare OCR input
        messages = [
            {
                'role': 'system',
                'content': [{'text': 'You are a helpful assistant.'}]
            },
            {
                'role': 'user',
                'content': [
                    {'image': f'file://{filename}'},
                    {'text': 'Please describe the math-related content in this image, ensuring that any LaTeX formulas are correctly transcribed. Non-mathematical details do not need to be described.'}
                ]
            }
        ]

        # Generate OCR output
        text_prompt = ocr_processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = ocr_processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt")
        # Follow the device the model was actually placed on instead of
        # assuming CUDA is present.
        inputs = inputs.to(ocr_model.device)

        output_ids = ocr_model.generate(**inputs, max_new_tokens=1024)
        # generate() returns prompt + completion; keep only the new tokens so
        # the chat-template text is not echoed into the description.
        prompt_length = inputs.input_ids.shape[1]
        generated_ids = output_ids[:, prompt_length:]
        output_text = ocr_processor.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )[0]
    finally:
        # Always remove the temporary file, even if saving or generation failed.
        Path(filename).unlink(missing_ok=True)

    return output_text
72
 
73
def get_math_response(image_description, user_question):
    """
    Sends the OCR output and user question to Qwen2-Math and retrieves the solution.

    Args:
        image_description: Text describing the image, or None when no image
            was supplied.
        user_question: The question typed by the user.

    Returns:
        The text generated by the math model, without the echoed prompt.
    """
    # NOTE: math_messages is a module-level list, so it is shared by every
    # caller in this process.
    global math_messages

    # Initialize the math assistant role
    if not math_messages:
        math_messages.append({'role': 'system', 'content': 'You are a helpful math assistant.'})
    math_messages = math_messages[:1]  # Retain only the system prompt

    # Format the input question
    if image_description is not None:
        content = f'Image description: {image_description}\n\n'
    else:
        content = ''
    query = f"{content}User question: {user_question}"
    math_messages.append({'role': 'user', 'content': query})

    # Render the whole conversation (system + user) through the chat template
    # so the instruct model sees its expected prompt format; tokenizing the
    # bare query would drop the system prompt and the role markers.
    prompt = math_tokenizer.apply_chat_template(
        math_messages,
        tokenize=False,
        add_generation_prompt=True
    )
    inputs = math_tokenizer([prompt], return_tensors="pt").to(math_model.device)

    # Generate the math reasoning response
    output_ids = math_model.generate(
        **inputs,
        max_new_tokens=1024,
        pad_token_id=math_tokenizer.pad_token_id
    )
    # generate() returns prompt + completion; decode only the completion.
    generated_ids = output_ids[:, inputs.input_ids.shape[1]:]
    response = math_tokenizer.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )[0]
    math_messages.append({'role': 'assistant', 'content': response})  # Append assistant response

    return response
109
 
110
def math_chat_bot(image, sketchpad, question, state):
    """
    Run OCR on the active tab's image (if any) and yield the math model's answer.

    ``state["tab_index"]`` selects the source: 0 uses ``image`` (Upload tab),
    1 uses ``sketchpad["composite"]`` (Sketch tab). When the selected source
    is empty, the question is answered without an image description.
    """
    active_tab = state["tab_index"]
    description = None

    if active_tab == 0 and image is not None:
        # Upload tab
        description = process_image(image)
    elif active_tab == 1 and sketchpad and sketchpad["composite"]:
        # Sketch tab: the composite is flattened before OCR
        description = process_image(sketchpad["composite"], True)

    yield get_math_response(description, question)
127
 
128
# Custom stylesheet passed to gr.Blocks(css=...). It sets KaTeX display-mode
# elements under the element with id "qwen-md" (the answer Markdown
# component) to `display: inline`.
css = """
#qwen-md .katex-display { display: inline; }
#qwen-md .katex-display>.katex { display: inline; }
#qwen-md .katex-display>.katex>.katex-html { display: inline; }
"""
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
def tabs_select(e: gr.SelectData, _state):
    """Store the index of the newly selected input tab in the state dict."""
    _state.update(tab_index=e.index)
136
+
137
# Create Gradio interface
with gr.Blocks(css=css) as demo:
    gr.HTML("""\
<p align="center"><img src="https://modelscope.oss-cn-beijing.aliyuncs.com/resource/qwen.png" style="height: 60px"/><p>"""
            """<center><font size=8>📖 Qwen2-Math Demo</center>"""
            """\
<center><font size=3>This WebUI is based on Qwen2-VL for OCR and Qwen2-Math for mathematical reasoning. You can input either images or texts of mathematical or arithmetic problems.</center>"""
            )
    # Holds the index of the active input tab; updated by tabs_select.
    state = gr.State({"tab_index": 0})
    with gr.Row():
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.Tab("Upload"):
                    input_image = gr.Image(type="pil", label="Upload")
                with gr.Tab("Sketch"):
                    input_sketchpad = gr.Sketchpad(type="pil", label="Sketch", layers=False)
            input_tabs.select(fn=tabs_select, inputs=[state])
            input_text = gr.Textbox(label="Input your question")
            with gr.Row():
                with gr.Column():
                    clear_btn = gr.ClearButton([input_image, input_sketchpad, input_text])
                with gr.Column():
                    submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            # Delimiters are matched literally, so braces must NOT be
            # backslash-escaped: "\\begin\{equation\}" would only match the
            # text `\begin\{equation\}` and is an invalid Python escape.
            # Every delimiter, including inline `\(...\)`, is rendered in
            # display mode; the "#qwen-md" CSS rules then set those elements
            # to `display: inline`.
            output_md = gr.Markdown(
                label="answer",
                latex_delimiters=[
                    {"left": "\\(", "right": "\\)", "display": True},
                    {"left": "\\begin{equation}", "right": "\\end{equation}", "display": True},
                    {"left": "\\begin{align}", "right": "\\end{align}", "display": True},
                    {"left": "\\begin{alignat}", "right": "\\end{alignat}", "display": True},
                    {"left": "\\begin{gather}", "right": "\\end{gather}", "display": True},
                    {"left": "\\begin{CD}", "right": "\\end{CD}", "display": True},
                    {"left": "\\[", "right": "\\]", "display": True},
                ],
                elem_id="qwen-md",
            )
    submit_btn.click(
        fn=math_chat_bot,
        inputs=[input_image, input_sketchpad, input_text, state],
        outputs=output_md,
    )

demo.launch(share=True)