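# app.py — Gradio Space that pairs Qwen2-VL (OCR / visual Q&A) with Qwen2.5-Math:
# text is read from an uploaded image and, for math tasks, passed to the math model,
# which reasons step by step and returns a \boxed{} final answer.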
import gradio as gr
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoModelForCausalLM, AutoTokenizer
import torch
# Load the OCR model and processor
ocr_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    torch_dtype="auto",
    device_map="auto",
)
ocr_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
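# Optional: the Qwen2-VL processor also accepts min_pixels/max_pixels in from_pretrained
# to bound the number of visual tokens per image (a tuning knob, not required here).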
# Load the Math model and tokenizer
math_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-Math-72B-Instruct",
    torch_dtype="auto",
    device_map="auto",
)
math_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Math-72B-Instruct")
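# Note: the 72B checkpoint needs multiple high-memory GPUs even with device_map="auto";
# Qwen/Qwen2.5-Math-7B-Instruct is a smaller drop-in alternative when resources are limited.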
# OCR extraction function
def ocr_and_query(image, question):
    # Prepare image for the model
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": question},
            ],
        }
    ]
    # Process image and text prompt
    text_prompt = ocr_processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = ocr_processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt")
    # Run the model to generate OCR results
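    # Assumes a CUDA GPU is available (device_map="auto" placed the model weights there);
    # on other hardware, inputs.to(ocr_model.device) is a safer choice.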
    inputs = inputs.to("cuda")
    output_ids = ocr_model.generate(**inputs, max_new_tokens=1024)
    # Decode the generated text
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(inputs.input_ids, output_ids)
    ]
    output_text = ocr_processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )[0]
    return output_text
# Math problem solving function
def solve_math_problem(prompt):
    # CoT (Chain of Thought)
    messages = [
        {"role": "system", "content": "Please reason step by step, and put your final answer within \\boxed{}."},
        {"role": "user", "content": prompt},
    ]
    text = math_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = math_tokenizer([text], return_tensors="pt").to("cuda")
    generated_ids = math_model.generate(
        **model_inputs,
        max_new_tokens=512,
    )
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = math_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response
# Function to clear inputs and output
def clear_inputs():
    return None, "", ""
# Gradio interface setup
def gradio_app(image, question, task):
    if task == "OCR and Query":
        if image is None:
            return image, question, "Please upload an image."
        return image, question, ocr_and_query(image, question)
    elif task == "Solve Math Problem from Image":
        if image is None:
            return image, question, "Please upload an image."
        extracted_text = ocr_and_query(image, "")
        math_solution = solve_math_problem(extracted_text)
        return image, extracted_text, math_solution
    elif task == "Solve Math Problem from Text":
        if question.strip() == "":
            return image, question, "Please enter a math problem."
        math_solution = solve_math_problem(question)
        return image, question, math_solution
    else:
        return image, question, "Please select a task."
# Gradio interface
with gr.Blocks() as app:
    gr.Markdown("# Image OCR and Math Solver")
    gr.Markdown("Upload an image, enter your question or math problem, and select the appropriate task.")
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Image")
        text_input = gr.Textbox(lines=2, placeholder="Enter your question or math problem here...", label="Input")
    with gr.Row():
        task_radio = gr.Radio(["OCR and Query", "Solve Math Problem from Image", "Solve Math Problem from Text"], label="Task")
    with gr.Row():
        complete_button = gr.Button("Complete")
        clear_button = gr.Button("Clear")
    output = gr.Markdown(label="Output")
    # Event listeners
    complete_button.click(fn=gradio_app, inputs=[image_input, text_input, task_radio], outputs=[image_input, text_input, output])
    clear_button.click(fn=clear_inputs, outputs=[image_input, text_input, output])
# Launch the app
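# share=True creates a temporary public link when the script runs locally; a deployed
# Hugging Face Space already serves the app publicly, so the flag has no effect there.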
app.launch(share=True)