import gradio as gr
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoModelForCausalLM, AutoTokenizer
import torch
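# NOTE: loading both checkpoints requires substantial GPU memory (the math
# model alone has 72B parameters); a smaller checkpoint such as
# "Qwen/Qwen2.5-Math-7B-Instruct" can be substituted below if resources are tight.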

# Load the OCR model and processor
ocr_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    torch_dtype="auto",
    device_map="auto",
)

ocr_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
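# (AutoProcessor for Qwen2-VL bundles the image preprocessor and the chat
# template, so a single call below turns an (image, text) pair into tensors.)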

# Load the Math model and tokenizer
math_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-Math-72B-Instruct",
    torch_dtype="auto",
    device_map="auto"
)

math_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Math-72B-Instruct")

# OCR extraction function
def ocr_and_query(image, question):
    # Build a chat message with an image placeholder; the processor pairs it
    # with the actual PIL image passed in below
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": question},
            ],
        }
    ]

    # Render the chat template and tokenize text and image together
    text_prompt = ocr_processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = ocr_processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt")

    # Run the model to generate OCR results; send inputs to the model's device
    # instead of hard-coding "cuda" (device_map="auto" may place layers elsewhere)
    inputs = inputs.to(ocr_model.device)
    output_ids = ocr_model.generate(**inputs, max_new_tokens=1024)

    # Strip the prompt tokens from each sequence before decoding
    generated_ids = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, output_ids)
    ]
    output_text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]

    return output_text
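
# Example call (hypothetical image file):
#   from PIL import Image
#   ocr_and_query(Image.open("receipt.png"), "What is the total amount?")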

# Math problem solving function
def solve_math_problem(prompt):
    # Chain-of-Thought (CoT) system prompt: have the model reason step by step
    # and put the final answer inside \boxed{}
    messages = [
        {"role": "system", "content": "Please reason step by step, and put your final answer within \\boxed{}."},
        {"role": "user", "content": prompt}
    ]

    text = math_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    # Move inputs to the model's device instead of hard-coding "cuda"
    model_inputs = math_tokenizer([text], return_tensors="pt").to(math_model.device)

    generated_ids = math_model.generate(
        **model_inputs,
        max_new_tokens=512
    )
    # Strip the prompt tokens from each sequence before decoding
    generated_ids = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = math_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    return response
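
# Optional helper (a minimal sketch, not wired into the app): pull the final
# answer out of the model's \boxed{...} marker. Assumes a single boxed span
# with no nested braces.
import re

def extract_boxed_answer(response):
    match = re.search(r"\\boxed\{([^{}]*)\}", response)
    return match.group(1) if match else response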

# Function to clear inputs and output
def clear_inputs():
    return None, "", ""

# Task dispatcher: routes the inputs to the appropriate pipeline
def gradio_app(image, question, task):
    if task == "OCR and Query":
        if image is None:
            return image, question, "Please upload an image."
        return image, question, ocr_and_query(image, question)
    elif task == "Solve Math Problem from Image":
        if image is None:
            return image, question, "Please upload an image."
        # Prompt the OCR model explicitly rather than passing an empty question
        extracted_text = ocr_and_query(image, "Extract the math problem from this image.")
        math_solution = solve_math_problem(extracted_text)
        return image, extracted_text, math_solution
    elif task == "Solve Math Problem from Text":
        if question.strip() == "":
            return image, question, "Please enter a math problem."
        math_solution = solve_math_problem(question)
        return image, question, math_solution
    else:
        return image, question, "Please select a task."

# Gradio interface
with gr.Blocks() as app:
    gr.Markdown("# Image OCR and Math Solver")
    gr.Markdown("Upload an image, enter your question or math problem, and select the appropriate task.")
    
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Image")
        text_input = gr.Textbox(lines=2, placeholder="Enter your question or math problem here...", label="Input")
    
    with gr.Row():
        task_radio = gr.Radio(["OCR and Query", "Solve Math Problem from Image", "Solve Math Problem from Text"], label="Task")
    
    with gr.Row():
        complete_button = gr.Button("Complete")
        clear_button = gr.Button("Clear")
    
    output = gr.Markdown(label="Output")
    
    # Event listeners
    complete_button.click(fn=gradio_app, inputs=[image_input, text_input, task_radio], outputs=[image_input, text_input, output])
    clear_button.click(fn=clear_inputs, outputs=[image_input, text_input, output])

# Launch the app; share=True also creates a temporary public Gradio URL
app.launch(share=True)