Sketch / app.py
Jangai's picture
Update app.py
49a29d4 verified
raw
history blame
1.62 kB
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import tempfile
from PIL import Image
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# Checkpoint shared by the processor and the model; both are fetched once at
# import time so that request handlers can reuse the loaded objects.
MODEL_NAME = "microsoft/trocr-large-handwritten"

processor = TrOCRProcessor.from_pretrained(MODEL_NAME)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
def display_sketch(sketch):
    """Render the sketchpad drawing to a PNG file and return its path.

    Args:
        sketch: Value emitted by the sketchpad. It is indexed with
            ``'composite'`` and the result is handed to ``plt.imshow``.

    Returns:
        Path of the PNG written for this call, or ``None`` when there is
        nothing to draw (``sketch`` or its composite image is ``None``).
    """
    if sketch is None:
        return None
    image_data = sketch['composite']
    if image_data is None:
        return None

    # Use a unique file per call instead of a fixed "/mnt/data/output.png":
    # the fixed path fails when that directory does not exist and lets
    # concurrent requests overwrite each other's image. The file is kept
    # (delete=False) because the caller reads it after this function returns.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        temp_file_path = tmp.name

    try:
        plt.imshow(image_data)
        plt.axis('off')
        plt.savefig(temp_file_path, bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the pyplot figure, even if rendering failed.
        plt.close()
    return temp_file_path
def recognize_text(image_path):
    """Run handwriting OCR on an image and return the decoded text.

    Args:
        image_path: Path (or file object) of the image to read. A numpy
            array is also accepted and converted with ``Image.fromarray``.
            ``None`` means there is no image to process.

    Returns:
        The first decoded string produced by the model, or ``""`` when
        ``image_path`` is ``None``.
    """
    if image_path is None:
        return ""

    if isinstance(image_path, np.ndarray):
        source = Image.fromarray(image_path)
    else:
        source = Image.open(image_path)

    # The context manager closes the underlying file once the pixels have
    # been copied into the grayscale image.
    with source:
        grayscale = source.convert("L")

    # Resize image to 256x256
    resized = grayscale.resize((256, 256))
    # Binarize image: pixels brighter than 128 become white, the rest black
    binarized = resized.point(lambda p: 255 if p > 128 else 0)
    # The documented TrOCR usage feeds RGB images to the processor, so expand
    # the single-channel result back to three channels before preprocessing.
    rgb_image = binarized.convert("RGB")

    # Preprocess the image
    pixel_values = processor(images=rgb_image, return_tensors="pt").pixel_values
    # Generate prediction
    generated_ids = model.generate(pixel_values)
    decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)
    return decoded[0]
# UI: draw on the sketchpad, press the button to render the drawing to an
# image, then run text recognition on that rendered image.
with gr.Blocks() as demo:
    sketchpad = gr.Sketchpad(label="Draw Something")
    recognize_button = gr.Button("Recognize")
    # type="filepath" makes this component hand a file path to the next step;
    # with the default type the handler would receive a numpy array, which
    # Image.open() cannot read.
    sketchpad_output = gr.Image(label="Your Sketch", type="filepath")
    recognized_text = gr.Textbox(label="Recognized Text")

    # The sketchpad component exposes no `submit` event, so the pipeline is
    # triggered by an explicit button click. This also avoids running the
    # large model on every stroke, as a `change` listener would.
    recognize_button.click(
        display_sketch, inputs=sketchpad, outputs=sketchpad_output
    ).then(
        recognize_text, inputs=sketchpad_output, outputs=recognized_text
    )

demo.launch()