import os
from typing import Optional

import gradio as gr
import torch
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Set up device: prefer the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the fine-tuned model.
checkpoint_path = './checkpoint-2070'  # Path to your fine-tuned model checkpoint
model = VisionEncoderDecoderModel.from_pretrained(checkpoint_path).to(device)

# Use the original model's processor (tokenizer and feature extractor).
# NOTE(review): assumes the fine-tuned checkpoint kept the base model's
# vocabulary and image preprocessing — confirm against the training setup.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-handwritten")


def ocr_image(image: Optional[Image.Image]) -> str:
    """
    Perform OCR on a single image.

    :param image: PIL Image object, or None when the user submits the form
        without uploading an image.
    :return: Extracted text from the image, or an empty string if no image
        was provided.
    """
    if image is None:
        # Gradio passes None for an empty input; the processor would raise on it.
        return ""
    if image.mode != "RGB":
        # The processor normalizes three-channel input; RGBA or greyscale
        # uploads (e.g. PNGs with transparency) may otherwise fail.
        image = image.convert("RGB")
    pixel_values = processor(image, return_tensors='pt').pixel_values.to(device)
    generated_ids = model.generate(pixel_values)
    # batch_decode returns one string per sequence; we submitted a single image.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text


# Define the Gradio interface.
# gr.Image replaces the legacy gr.inputs.Image namespace, which was deprecated
# in Gradio 3.x and removed in Gradio 4.
interface = gr.Interface(
    fn=ocr_image,  # Function to call for prediction
    inputs=gr.Image(type="pil"),  # Accept an image as input, delivered as a PIL Image
    outputs="text",  # Return extracted text
    title="OCR with TrOCR",
    description="Upload an image, and the fine-tuned TrOCR model will extract the text for you.",
)

# Launch the Gradio app only when this file is run as a script.
if __name__ == "__main__":
    interface.launch()