# Spaces: Sleeping (page header captured together with the source; not part of the program)
import gradio as gr | |
import os | |
import tempfile | |
from pathlib import Path | |
import secrets | |
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoModelForCausalLM, AutoTokenizer | |
from PIL import Image | |
import torch | |
# Set up models and processors.
# Two checkpoints are loaded at import time: a vision-language model used to
# describe the uploaded image, and a math-tuned language model used to answer.
OCR_MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"
MATH_MODEL_ID = "Qwen/Qwen2.5-Math-7B-Instruct"

# Both models defer dtype and device placement to transformers.
_MODEL_LOAD_KWARGS = {"torch_dtype": "auto", "device_map": "auto"}

ocr_model = Qwen2VLForConditionalGeneration.from_pretrained(
    OCR_MODEL_ID, **_MODEL_LOAD_KWARGS
)
ocr_processor = AutoProcessor.from_pretrained(OCR_MODEL_ID)

math_model = AutoModelForCausalLM.from_pretrained(
    MATH_MODEL_ID, **_MODEL_LOAD_KWARGS
)
math_tokenizer = AutoTokenizer.from_pretrained(MATH_MODEL_ID)

# Module-level chat history for the math model. It is cleared by
# process_image() and rebuilt by get_math_response().
math_messages = []
def process_image(image, should_convert=False):
    """Describe the math-related content of an image using Qwen2-VL.

    Side effect: resets the module-level ``math_messages`` history, because a
    new image starts a new conversation.

    Args:
        image: The picture to describe. A ``PIL.Image.Image`` is expected;
            any other object is passed through ``Image.fromarray``.
        should_convert: When true, the image is flattened onto a white
            background using its own alpha channel as the paste mask
            (used for sketches drawn on a transparent canvas).

    Returns:
        The text generated by the OCR model (prompt tokens are stripped).
    """
    global math_messages
    math_messages = []  # Reset when uploading a new image

    # Accept array input as well as PIL images.
    # NOTE(review): assumes a uint8 H x W x C array when not a PIL image -
    # confirm against the Gradio component configuration.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    if should_convert:
        # Flatten transparency onto white. Converting to RGBA first guarantees
        # a valid alpha mask even if the input has no alpha channel.
        rgba = image.convert("RGBA")
        flattened = Image.new("RGB", size=rgba.size, color=(255, 255, 255))
        flattened.paste(rgba, (0, 0), mask=rgba)
        image = flattened
    elif image.mode != "RGB":
        # JPEG cannot store modes such as RGBA or P; saving them would raise.
        image = image.convert("RGB")

    uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str(
        Path(tempfile.gettempdir()) / "gradio"
    )
    os.makedirs(uploaded_file_dir, exist_ok=True)
    filename = os.path.join(uploaded_file_dir, f"tmp{secrets.token_hex(20)}.jpg")

    try:
        image.save(filename)

        # Prepare OCR input
        messages = [
            {
                'role': 'system',
                'content': [{'text': 'You are a helpful assistant.'}],
            },
            {
                'role': 'user',
                'content': [
                    {'image': f'file://{filename}'},
                    {'text': 'Please describe the math-related content in this image, ensuring that any LaTeX formulas are correctly transcribed. Non-mathematical details do not need to be described.'},
                ],
            },
        ]

        # Generate OCR output
        text_prompt = ocr_processor.apply_chat_template(
            messages, add_generation_prompt=True
        )
        inputs = ocr_processor(
            text=[text_prompt], images=[image], padding=True, return_tensors="pt"
        )
        # The model is placed by device_map="auto"; inputs must follow it
        # rather than being pinned to the CPU.
        inputs = inputs.to(ocr_model.device)
        output_ids = ocr_model.generate(**inputs, max_new_tokens=1024)

        # generate() returns prompt + continuation; keep only the new tokens so
        # the description does not echo the system/user prompt.
        prompt_length = inputs["input_ids"].shape[1]
        generated_ids = output_ids[:, prompt_length:]
        output_text = ocr_processor.batch_decode(
            generated_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )[0]
    finally:
        # Always remove the temporary file, even if saving or inference failed.
        Path(filename).unlink(missing_ok=True)

    return output_text
def get_math_response(image_description, user_question):
    """Ask the math model to answer a question, optionally about an image.

    The module-level ``math_messages`` list is rebuilt on every call: it is
    trimmed to the system prompt, then the new user turn and the model's
    reply are appended.

    Args:
        image_description: Text describing the image, or ``None`` when no
            image was provided.
        user_question: The question typed by the user.

    Returns:
        The text generated by the math model (prompt tokens are stripped).
    """
    global math_messages

    # Initialize the math assistant role
    if not math_messages:
        math_messages.append(
            {'role': 'system', 'content': 'You are a helpful math assistant.'}
        )
    math_messages = math_messages[:1]  # Retain only the system prompt

    # Format the input question
    if image_description is not None:
        content = f'Image description: {image_description}\n\n'
    else:
        content = ''
    query = f"{content}User question: {user_question}"
    math_messages.append({'role': 'user', 'content': query})

    # Render the conversation with the model's chat template so the system
    # prompt and role markers actually reach the instruct model; the raw
    # query alone would bypass both.
    prompt = math_tokenizer.apply_chat_template(
        math_messages, tokenize=False, add_generation_prompt=True
    )
    # The model is placed by device_map="auto"; inputs must follow it rather
    # than being pinned to the CPU.
    inputs = math_tokenizer([prompt], return_tensors="pt").to(math_model.device)

    # Generate the math reasoning response
    output_ids = math_model.generate(
        **inputs,
        max_new_tokens=1024,
        pad_token_id=math_tokenizer.pad_token_id,
    )

    # generate() returns prompt + continuation; decode only the continuation
    # so the answer does not repeat the question.
    prompt_length = inputs["input_ids"].shape[1]
    response = math_tokenizer.batch_decode(
        output_ids[:, prompt_length:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )[0]

    math_messages.append({'role': 'assistant', 'content': response})
    return response
def math_chat_bot(image, sketchpad, question, state):
    """Run image description (if any) followed by math reasoning.

    This is a generator: it yields the math model's response once.

    Args:
        image: Image from the "Upload" tab, or ``None``.
        sketchpad: Value of the "Sketch" tab; a mapping whose ``"composite"``
            entry holds the drawn image, or a falsy value when empty.
        question: The user's question text.
        state: Session state mapping; ``"tab_index"`` selects which input is
            used (0 = upload, 1 = sketch). A missing state or key falls back
            to the upload tab.

    Yields:
        The response text returned by ``get_math_response``.
    """
    current_tab_index = (state or {}).get("tab_index", 0)
    image_description = None

    if current_tab_index == 0:
        # Upload tab
        if image is not None:
            image_description = process_image(image)
    elif current_tab_index == 1:
        # Sketch tab. Compare against None explicitly: taking the truth value
        # of a NumPy array composite raises ValueError.
        composite = sketchpad.get("composite") if sketchpad else None
        if composite is not None:
            image_description = process_image(composite, True)

    yield get_math_response(image_description, question)
# Stylesheet handed to gr.Blocks: makes KaTeX display-mode output inside the
# element with id "qwen-md" render inline instead of as separate blocks.
css = (
    "\n"
    "#qwen-md .katex-display { display: inline; }\n"
    "#qwen-md .katex-display>.katex { display: inline; }\n"
    "#qwen-md .katex-display>.katex>.katex-html { display: inline; }\n"
)
def tabs_select(e: gr.SelectData, _state):
    """Store the index of the selected tab in the session state.

    ``_state`` is updated in place under the ``"tab_index"`` key; nothing is
    returned.
    """
    _state.update(tab_index=e.index)
# Create Gradio interface
with gr.Blocks(css=css) as demo:
    gr.HTML(
        """<center><h1>Qwen2-Math Demo</h1><p>Use either uploaded images or sketches for math-related problems.</p></center>"""
    )
    # Per-session record of which input tab is active (0 = Upload, 1 = Sketch).
    state = gr.State({"tab_index": 0})

    def _on_tab_select(event: gr.SelectData, current_state):
        """Record the selected tab and hand the state back to Gradio.

        Gradio only injects the selection event into parameters annotated
        with ``gr.SelectData``; an unannotated lambda with no inputs is
        called without arguments and fails.
        """
        tabs_select(event, current_state)
        return current_state

    with gr.Row():
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.Tab("Upload"):
                    input_image = gr.Image(type="pil", label="Upload Image")
                with gr.Tab("Sketch"):
                    # type="pil" so the composite is a PIL image, matching the
                    # upload tab and what process_image works with.
                    input_sketchpad = gr.Sketchpad(type="pil", label="Sketch Pad")
            input_tabs.select(fn=_on_tab_select, inputs=[state], outputs=state)
            input_text = gr.Textbox(label="Your Question")
            submit_btn = gr.Button("Submit")
        with gr.Column():
            output_md = gr.Markdown(
                label="Answer",
                # Every delimiter pair is rendered in display mode; the CSS
                # passed to gr.Blocks then lays the result out inline.
                latex_delimiters=[
                    {"left": left, "right": right, "display": True}
                    for left, right in (
                        ("\\(", "\\)"),
                        ("\\begin{equation}", "\\end{equation}"),
                        ("\\begin{align}", "\\end{align}"),
                        ("\\begin{alignat}", "\\end{alignat}"),
                        ("\\begin{gather}", "\\end{gather}"),
                        ("\\begin{CD}", "\\end{CD}"),
                        ("\\[", "\\]"),
                    )
                ],
                elem_id="qwen-md",
            )
    submit_btn.click(
        fn=math_chat_bot,
        inputs=[input_image, input_sketchpad, input_text, state],
        outputs=output_md,
    )
demo.launch()