Spaces:
Sleeping
Sleeping
File size: 6,806 Bytes
f4ed285 48e94a3 f4ed285 48e94a3 f4ed285 48e94a3 f4ed285 48e94a3 f4ed285 48e94a3 f4ed285 48e94a3 f4ed285 48e94a3 f4ed285 48e94a3 f4ed285 48e94a3 f4ed285 48e94a3 f4ed285 ce6a793 f4ed285 48e94a3 f4ed285 48e94a3 f4ed285 48e94a3 f4ed285 48e94a3 1200067 f4ed285 48e94a3 f4ed285 48e94a3 f4ed285 48e94a3 f4ed285 48e94a3 f4ed285 48e94a3 1e979ad 48e94a3 f4ed285 48e94a3 1e979ad 48e94a3 1e979ad 48e94a3 1e979ad 48e94a3 1e979ad f4ed285 289fc49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 |
import gradio as gr
import os
import tempfile
from pathlib import Path
import secrets
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import torch
# Set up models and processors.
# Both checkpoints are loaded once, at import time, with the same loading
# options: dtype taken from the checkpoint ("auto") and device placement
# delegated to transformers/accelerate (device_map="auto").
_OCR_MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"
_MATH_MODEL_ID = "Qwen/Qwen2.5-Math-7B-Instruct"
_LOAD_KWARGS = dict(torch_dtype="auto", device_map="auto")

# Vision-language model + processor used by process_image() to transcribe
# the math content of an image.
ocr_model = Qwen2VLForConditionalGeneration.from_pretrained(_OCR_MODEL_ID, **_LOAD_KWARGS)
ocr_processor = AutoProcessor.from_pretrained(_OCR_MODEL_ID)

# Text-only math model + tokenizer used by get_math_response().
math_model = AutoModelForCausalLM.from_pretrained(_MATH_MODEL_ID, **_LOAD_KWARGS)
math_tokenizer = AutoTokenizer.from_pretrained(_MATH_MODEL_ID)

# Module-level chat history for the math model; process_image() resets it
# and get_math_response() rebuilds it on every call.
math_messages = []
def _flatten_on_white(image):
    """Composite ``image`` over an opaque white RGB canvas using its alpha channel."""
    # Converting to RGBA first guarantees a usable mask for any input mode;
    # using a plain RGB image as its own mask raises "bad transparency mask".
    rgba = image.convert("RGBA")
    canvas = Image.new("RGB", size=rgba.size, color=(255, 255, 255))
    canvas.paste(rgba, (0, 0), mask=rgba)
    return canvas


def process_image(image, should_convert=False):
    """
    Processes the uploaded image and extracts math-related content using Qwen2-VL.

    Side effect: resets the global ``math_messages`` history, because a new
    image starts a new problem.

    Args:
        image: A PIL image. Array-like input (e.g. a Sketchpad composite) is
            converted with ``Image.fromarray``.
        should_convert: If True, flatten the image onto a white RGB canvas
            using its alpha channel (for transparent sketches).

    Returns:
        The OCR model's description of the image, without the echoed prompt.
    """
    global math_messages
    math_messages = []  # Reset when uploading a new image

    if not isinstance(image, Image.Image):
        # NOTE(review): Gradio's Sketchpad may hand over its composite as an
        # array rather than a PIL image; normalize it here.
        image = Image.fromarray(image)
    if should_convert:
        image = _flatten_on_white(image)

    uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str(
        Path(tempfile.gettempdir()) / "gradio"
    )
    os.makedirs(uploaded_file_dir, exist_ok=True)
    filename = os.path.join(uploaded_file_dir, f"tmp{secrets.token_hex(20)}.jpg")

    # Prepare OCR input
    messages = [
        {
            'role': 'system',
            'content': [{'text': 'You are a helpful assistant.'}]
        },
        {
            'role': 'user',
            'content': [
                {'image': f'file://{filename}'},
                {'text': 'Please describe the math-related content in this image, ensuring that any LaTeX formulas are correctly transcribed. Non-mathematical details do not need to be described.'}
            ]
        }
    ]

    try:
        # JPEG cannot store alpha or palette modes, so write an RGB copy; the
        # model itself still receives the in-memory ``image`` below.
        image.convert("RGB").save(filename)

        # Generate OCR output
        text_prompt = ocr_processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = ocr_processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt")
        # Match the model's device: device_map="auto" may have put it on a
        # GPU, in which case CPU-resident inputs would fail in generate().
        inputs = inputs.to(ocr_model.device)
        output_ids = ocr_model.generate(**inputs, max_new_tokens=1024)
        # generate() returns prompt + completion; keep only the new tokens.
        generated_ids = output_ids[:, inputs["input_ids"].shape[1]:]
        output_text = ocr_processor.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )[0]
    finally:
        # Always remove the temp file, even if saving or generation raised.
        if os.path.exists(filename):
            os.remove(filename)
    return output_text
def get_math_response(image_description, user_question):
    """
    Sends the OCR output and user question to Qwen2-Math and retrieves the solution.

    Each call starts a fresh exchange: the global ``math_messages`` history is
    cut back to the system prompt, then the new user turn and the model's
    reply are appended to it.

    Args:
        image_description: OCR text for the image, or None if there is no image.
        user_question: The user's question text.

    Returns:
        The math model's reply, without the echoed prompt.
    """
    global math_messages
    # Initialize the math assistant role, then keep only the system prompt.
    if not math_messages:
        math_messages.append({'role': 'system', 'content': 'You are a helpful math assistant.'})
    math_messages = math_messages[:1]

    # Format the input question
    content = f'Image description: {image_description}\n\n' if image_description is not None else ''
    query = f"{content}User question: {user_question}"
    math_messages.append({'role': 'user', 'content': query})

    # Render the conversation with the model's chat template so the instruct
    # model actually sees the system prompt and role markers; the raw query
    # alone bypassed both.
    prompt = math_tokenizer.apply_chat_template(
        math_messages, tokenize=False, add_generation_prompt=True
    )
    inputs = math_tokenizer(
        [prompt],
        padding=True,
        return_tensors="pt"
    ).to(math_model.device)  # follow wherever device_map="auto" placed the model

    # Fall back to EOS if the tokenizer defines no pad token.
    pad_token_id = math_tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = math_tokenizer.eos_token_id

    # Generate the math reasoning response
    output_ids = math_model.generate(
        **inputs,
        max_new_tokens=1024,
        pad_token_id=pad_token_id
    )
    # generate() returns prompt + completion; decode only the completion.
    generated_ids = output_ids[:, inputs["input_ids"].shape[1]:]
    response = math_tokenizer.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )[0]
    math_messages.append({'role': 'assistant', 'content': response})  # Append assistant response
    return response
def math_chat_bot(image, sketchpad, question, state):
    """
    Orchestrates the OCR (image processing) and math reasoning based on user input.

    Which image is used depends on ``state["tab_index"]``:
    0 selects the uploaded ``image``, and 1 selects the sketchpad's
    ``"composite"`` entry. With no usable image, the question is answered
    without an image description.

    Yields:
        The single math-model response string.
    """
    current_tab_index = state["tab_index"]
    image_description = None

    if current_tab_index == 0:  # Upload tab
        if image is not None:
            image_description = process_image(image)
    elif current_tab_index == 1:  # Sketch tab
        composite = sketchpad.get("composite") if sketchpad else None
        # Use an explicit None check: the composite may be an array, and
        # `if array:` raises "truth value of an array is ambiguous".
        if composite is not None:
            image_description = process_image(composite, True)

    yield get_math_response(image_description, question)
# Custom CSS for the answer pane (elem_id "qwen-md"): forces KaTeX display-mode
# math and its inner wrappers back into inline flow.
css = (
    "\n"
    "#qwen-md .katex-display { display: inline; }\n"
    "#qwen-md .katex-display>.katex { display: inline; }\n"
    "#qwen-md .katex-display>.katex>.katex-html { display: inline; }\n"
)
def tabs_select(e: gr.SelectData, _state):
    """
    Record the index of the selected input tab in the app state.

    Args:
        e: Gradio select event; ``e.index`` is stored as the tab index.
        _state: The state dict; its ``"tab_index"`` key is updated in place.

    Returns:
        The same (mutated) state dict, so the function can also be wired as
        an event handler with ``outputs=state``. Previously it returned None.
    """
    _state["tab_index"] = e.index
    return _state
# Create Gradio interface.
# All delimiters use display mode; the module's css collapses KaTeX display
# blocks back to inline flow inside the answer pane.
_LATEX_DELIMITERS = [
    {"left": left, "right": right, "display": True}
    for left, right in (
        ("\\(", "\\)"),
        ("\\begin{equation}", "\\end{equation}"),
        ("\\begin{align}", "\\end{align}"),
        ("\\begin{alignat}", "\\end{alignat}"),
        ("\\begin{gather}", "\\end{gather}"),
        ("\\begin{CD}", "\\end{CD}"),
        ("\\[", "\\]"),
    )
]


def _on_input_tab_select(evt: gr.SelectData):
    """Return fresh app state recording the index of the selected input tab."""
    # Gradio injects event data only into parameters annotated with an
    # EventData type. The previous unannotated lambda was called with no
    # arguments and failed, so the active tab was never recorded.
    return {"tab_index": evt.index}


with gr.Blocks(css=css) as demo:
    gr.HTML(
        """<center><h1>Qwen2-Math Demo</h1><p>Use either uploaded images or sketches for math-related problems.</p></center>"""
    )
    # Tracks which input tab is active (0 = Upload, 1 = Sketch).
    state = gr.State({"tab_index": 0})
    with gr.Row():
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.Tab("Upload"):
                    input_image = gr.Image(type="pil", label="Upload Image")
                with gr.Tab("Sketch"):
                    input_sketchpad = gr.Sketchpad(label="Sketch Pad")
            input_tabs.select(fn=_on_input_tab_select, inputs=[], outputs=state)
            input_text = gr.Textbox(label="Your Question")
            submit_btn = gr.Button("Submit")
        with gr.Column():
            output_md = gr.Markdown(
                label="Answer",
                latex_delimiters=_LATEX_DELIMITERS,
                elem_id="qwen-md"
            )
    submit_btn.click(
        fn=math_chat_bot,
        inputs=[input_image, input_sketchpad, input_text, state],
        outputs=output_md,
    )
demo.launch()
|