# medieval-htr / app.py
# Last commit: "Update app.py" (dabac75) by wjbmattingly
import gradio as gr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import torch
import spaces
# Dropdown label -> Hugging Face Hub checkpoint ID. Insertion order matters:
# it is the order in which the choices are listed in the model dropdown.
MODEL_OPTIONS = dict(
    [
        ("Microsoft Handwritten", "microsoft/trocr-base-handwritten"),
        ("Medieval Base", "medieval-data/trocr-medieval-base"),
        ("Medieval Latin Caroline", "medieval-data/trocr-medieval-latin-caroline"),
        ("Medieval Castilian Hybrida", "medieval-data/trocr-medieval-castilian-hybrida"),
        ("Medieval Humanistica", "medieval-data/trocr-medieval-humanistica"),
        ("Medieval Textualis", "medieval-data/trocr-medieval-textualis"),
        ("Medieval Cursiva", "medieval-data/trocr-medieval-cursiva"),
        ("Medieval Semitextualis", "medieval-data/trocr-medieval-semitextualis"),
        ("Medieval Praegothica", "medieval-data/trocr-medieval-praegothica"),
        ("Medieval Semihybrida", "medieval-data/trocr-medieval-semihybrida"),
        ("Medieval Print", "medieval-data/trocr-medieval-print"),
    ]
)
# Single-entry model cache, filled by load_model(): the loaded model, its
# processor, and the dropdown name they were loaded for. All start empty.
current_model = current_processor = current_model_name = None
def load_model(model_name):
    """Return ``(processor, model)`` for *model_name*, loading it only when the name changes.

    The most recently requested checkpoint is cached in the module-level
    globals ``current_processor`` / ``current_model`` / ``current_model_name``.
    Repeated calls with the same name reuse the cached pair; a different name
    loads the checkpoint whose Hub ID is ``MODEL_OPTIONS[model_name]`` and
    replaces the cache.

    Args:
        model_name: Display name; must be a key of ``MODEL_OPTIONS``.

    Returns:
        Tuple ``(processor, model)``, with the model moved to ``"cuda"``.

    Raises:
        KeyError: If *model_name* is not a key of ``MODEL_OPTIONS`` (this
            includes ``None``, e.g. when no model is selected).
    """
    global current_model, current_processor, current_model_name
    # Validate first: otherwise a None name on the very first call would skip
    # loading (None == initial current_model_name) and fail on None.to(...).
    if model_name not in MODEL_OPTIONS:
        raise KeyError(f"Unknown model name: {model_name!r}")
    if model_name != current_model_name:
        model_id = MODEL_OPTIONS[model_name]
        # Load into locals and publish all three globals together, so a failed
        # download/load cannot leave a new processor cached next to the
        # previous model and name.
        processor = TrOCRProcessor.from_pretrained(model_id)
        model = VisionEncoderDecoderModel.from_pretrained(model_id)
        current_processor, current_model, current_model_name = processor, model, model_name
    # Move model to GPU; effectively a no-op when its parameters are already there.
    current_model = current_model.to("cuda")
    return current_processor, current_model
@spaces.GPU
def process_image(image, model_name):
    """Transcribe one line of text from *image* with the selected TrOCR model.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when nothing has been uploaded.
        model_name: Display name chosen in the dropdown (a key of ``MODEL_OPTIONS``).

    Returns:
        The decoded text of the first generated sequence, with special tokens removed.

    Raises:
        gr.Error: If no image or no model was provided, so the UI shows a
            readable message instead of a stack trace.
    """
    if image is None:
        raise gr.Error("Please upload an image to transcribe.")
    if not model_name:
        raise gr.Error("Please select a model.")

    processor, model = load_model(model_name)

    # The TrOCR image processor expects 3-channel input; convert in case the
    # image arrives as RGBA (e.g. PNG with alpha) or grayscale.
    if image.mode != "RGB":
        image = image.convert("RGB")
    pixel_values = processor(image, return_tensors="pt").pixel_values
    # Put the inputs on whatever device load_model placed the model on.
    pixel_values = pixel_values.to(model.device)

    # Decoding settings (number of beams, max length) are not overridden here;
    # presumably they come from the checkpoint's generation config — confirm per model.
    with torch.no_grad():
        generated_ids = model.generate(pixel_values)

    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
# Base URL for the example images stored in the medieval-data/trocr-medieval-base
# repo. Canonical host without a trailing dot after ".co": a trailing-dot host is
# nonstandard and can be rejected by some HTTP/TLS clients and CDNs.
base_url = "https://huggingface.co/medieval-data/trocr-medieval-base/resolve/main/images/"

# (file stem, dropdown model name, number of samples "<stem>-1.png" .. "<stem>-N.png")
_EXAMPLE_SETS = [
    ("caroline", "Medieval Latin Caroline", 2),
    ("cursiva", "Medieval Cursiva", 3),
    ("humanistica", "Medieval Humanistica", 3),
    ("hybrida", "Medieval Castilian Hybrida", 3),
    ("praegothica", "Medieval Praegothica", 3),
    ("print", "Medieval Print", 3),
    ("semihybrida", "Medieval Semihybrida", 3),
    ("semitextualis", "Medieval Semitextualis", 3),
    ("textualis", "Medieval Textualis", 3),
]

# [image URL, model name] rows for gr.Examples, grouped by script in table order.
examples = [
    [f"{base_url}{stem}-{i}.png", model_name]
    for stem, model_name, count in _EXAMPLE_SETS
    for i in range(1, count + 1)
]
# Custom CSS so the uploaded image fills the width of its container.
# Every rule is scoped to the gr.Image component created with
# elem_id="image_upload":
#  - the component itself is stretched to full width;
#  - its first child <div> is stretched as well (presumably a Gradio-internal
#    wrapper, so this selector may need updating across Gradio versions);
#  - the <img> inside scales to full width, with height:auto keeping the
#    aspect ratio.
custom_css = """
#image_upload {
max-width: 100% !important;
width: 100% !important;
height: auto !important;
}
#image_upload > div:first-child {
width: 100% !important;
}
#image_upload img {
max-width: 100% !important;
width: 100% !important;
height: auto !important;
}
"""
# Gradio interface: an image and a model choice in, a transcription out.
with gr.Blocks(css=custom_css) as iface:
    gr.Markdown("# Medieval TrOCR Model Switcher")
    gr.Markdown("Upload an image of medieval text and select a model to transcribe it. Note: This tool is designed to work on a single line of text at a time for optimal results.")
    with gr.Row():
        # elem_id ties this component to the rules in custom_css.
        input_image = gr.Image(type="pil", label="Input Image", elem_id="image_upload")
        model_dropdown = gr.Dropdown(choices=list(MODEL_OPTIONS), label="Select Model", value="Medieval Base")
    transcription_output = gr.Textbox(label="Transcription")
    submit_button = gr.Button("Transcribe")
    submit_button.click(fn=process_image, inputs=[input_image, model_dropdown], outputs=transcription_output)
    # No fn is passed, so clicking an example only fills the image and model
    # inputs (nothing is precomputed); the user then presses "Transcribe".
    gr.Examples(examples, inputs=[input_image, model_dropdown], outputs=transcription_output)

# Only start the server when run as a script, so importing this module
# (e.g. for tests) has no side effect beyond building the UI.
if __name__ == "__main__":
    iface.launch()