|
import gradio as gr |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import tempfile |
|
from PIL import Image |
|
import torch |
|
from transformers import TrOCRProcessor, VisionEncoderDecoderModel |
|
|
|
|
|
# Single Hugging Face checkpoint shared by the processor and the model.
_TROCR_CHECKPOINT = "microsoft/trocr-large-handwritten"

# Loaded once at import time; recognize_text reuses these module-level
# objects for every call instead of reloading the weights per request.
processor = TrOCRProcessor.from_pretrained(_TROCR_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_TROCR_CHECKPOINT)
|
|
|
def display_sketch(sketch):
    """Render the sketchpad drawing to a temporary PNG file.

    Args:
        sketch: Value emitted by ``gr.Sketchpad``; the drawn image is read
            from its ``"composite"`` entry. ``None`` (or a missing/``None``
            composite, e.g. after the canvas is cleared) is accepted.

    Returns:
        Path of a newly created PNG file containing the rendered drawing,
        or ``None`` when there is nothing to render.
    """
    if sketch is None:
        return None
    image_data = sketch.get("composite")
    if image_data is None:
        return None

    # Local import: the object-oriented Figure API does not touch pyplot's
    # global "current figure", so concurrent handler calls cannot draw into
    # each other's figures, and no plt.close() bookkeeping is needed.
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.subplots()
    # Rendering through matplotlib (rather than saving the raw array) puts
    # the drawing on the figure's background before it is written out.
    ax.imshow(image_data)
    ax.axis("off")

    # A unique file per call: a fixed path only exists in some sandboxes
    # and would be overwritten by concurrent users. The handle is closed
    # before savefig reopens the path, which keeps this portable to Windows.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        temp_file_path = tmp.name
    fig.savefig(temp_file_path, bbox_inches="tight", pad_inches=0)

    return temp_file_path
|
|
|
def recognize_text(image_path):
    """Run TrOCR handwriting recognition on a single image.

    Args:
        image_path: Path to an image file. A numpy array or a PIL image is
            also accepted, because ``gr.Image`` passes its value to event
            handlers as a numpy array unless ``type="filepath"`` is set.
            ``None`` (no image) is allowed.

    Returns:
        The decoded text of the first (only) generated sequence, or ``""``
        when no image was supplied.
    """
    if image_path is None:
        return ""

    if isinstance(image_path, Image.Image):
        image = image_path.convert("L")
    elif isinstance(image_path, np.ndarray):
        image = Image.fromarray(image_path).convert("L")
    else:
        # The context manager closes the file handle; convert() returns an
        # independent copy that remains valid after the file is closed.
        with Image.open(image_path) as source:
            image = source.convert("L")

    image = image.resize((256, 256))

    # Binarize: pixels brighter than 128 become white, everything else black.
    image = image.point(lambda p: 255 if p > 128 else 0)

    # TrOCR's vision encoder is fed 3-channel RGB images (the upstream
    # examples call .convert("RGB")); a single-channel "L" image does not
    # match that, so expand the binarized result back to RGB.
    image = image.convert("RGB")

    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
|
|
with gr.Blocks() as demo:
    sketchpad = gr.Sketchpad(label="Draw Something")
    # type="filepath": this component is the input of recognize_text, which
    # opens its argument as an image file; the default type ("numpy") would
    # hand it an array instead of a path.
    sketchpad_output = gr.Image(label="Your Sketch", type="filepath")
    recognized_text = gr.Textbox(label="Recognized Text")

    # Each canvas change re-renders the preview image; OCR is chained with
    # .success so it only runs when rendering succeeded, never on a stale
    # preview left over from an earlier edit.
    sketchpad.change(
        display_sketch, inputs=sketchpad, outputs=sketchpad_output
    ).success(recognize_text, inputs=sketchpad_output, outputs=recognized_text)

demo.launch()
|
|