import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import tempfile
from PIL import Image
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Load model and processor
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-large-handwritten")

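# Render the sketchpad's composite layer with matplotlib and save it as a PNG
# so the recognizer can read it back from disk.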
def display_sketch(sketch):
    image_data = sketch['composite']
    plt.imshow(image_data)
    plt.axis('off')

    # Save the sketch to a temporary file (tempfile is already imported) instead of a hardcoded path
    temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    plt.savefig(temp_file.name, bbox_inches='tight', pad_inches=0)
    plt.close()

    return temp_file.name

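# Preprocess the saved sketch and run TrOCR to decode the handwritten text.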
def recognize_text(image_path):
    # Open image and convert to grayscale
    image = Image.open(image_path).convert("L")
    # Resize image to 256x256
    image = image.resize((256, 256))
    # Binarize image (convert to black and white)
    image = image.point(lambda p: 255 if p > 128 else 0)
    # TrOCR's processor expects a 3-channel image, so convert back to RGB
    image = image.convert("RGB")

    # Preprocess the image
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    # Generate prediction
    generated_ids = model.generate(pixel_values)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text

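# Wire up the UI: update the sketch preview first, then feed the saved image to the recognizer.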
with gr.Blocks() as demo:
    sketchpad = gr.Sketchpad(label="Draw Something")
    sketchpad_output = gr.Image(label="Your Sketch", type="filepath")
    recognized_text = gr.Textbox(label="Recognized Text")
    
    sketchpad.change(display_sketch, inputs=sketchpad, outputs=sketchpad_output).then(
        recognize_text, inputs=sketchpad_output, outputs=recognized_text
    )

demo.launch()