wjbmattingly committed on
Commit
f8ba7b0
1 Parent(s): c10450c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -0
app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
+ import torch
4
+ import spaces
5
+ import subprocess
6
+ import json
7
+ from PIL import Image, ImageDraw
8
+ import os
9
+ import tempfile
10
+
11
# Dictionary of model names and their corresponding HuggingFace model IDs
# (UI display label -> Hub repo id; labels feed the Gradio dropdown below).
MODEL_OPTIONS = {
    "Microsoft Handwritten": "microsoft/trocr-base-handwritten",
    "Medieval Base": "medieval-data/trocr-medieval-base",
    "Medieval Latin Caroline": "medieval-data/trocr-medieval-latin-caroline",
    "Medieval Castilian Hybrida": "medieval-data/trocr-medieval-castilian-hybrida",
    "Medieval Humanistica": "medieval-data/trocr-medieval-humanistica",
    "Medieval Textualis": "medieval-data/trocr-medieval-textualis",
    "Medieval Cursiva": "medieval-data/trocr-medieval-cursiva",
    "Medieval Semitextualis": "medieval-data/trocr-medieval-semitextualis",
    "Medieval Praegothica": "medieval-data/trocr-medieval-praegothica",
    "Medieval Semihybrida": "medieval-data/trocr-medieval-semihybrida",
    "Medieval Print": "medieval-data/trocr-medieval-print"
}

# Global variables to store the current model and processor.
# load_model() mutates these so the most recently selected checkpoint is
# cached across requests instead of being re-downloaded on every call.
current_model = None        # cached VisionEncoderDecoderModel (None until first load)
current_processor = None    # cached TrOCRProcessor (None until first load)
current_model_name = None   # MODEL_OPTIONS key of the cached model
30
+
31
def load_model(model_name):
    """Return the (processor, model) pair for *model_name*, caching globally.

    Downloads/loads from the HuggingFace Hub only when the requested name
    differs from the currently cached one. The model is moved to the GPU on
    every call (a no-op when it is already there).

    Parameters
    ----------
    model_name : str
        A key of MODEL_OPTIONS.

    Returns
    -------
    tuple
        (TrOCRProcessor, VisionEncoderDecoderModel) ready for inference.
    """
    global current_model, current_processor, current_model_name

    needs_reload = model_name != current_model_name
    if needs_reload:
        hub_id = MODEL_OPTIONS[model_name]
        current_processor = TrOCRProcessor.from_pretrained(hub_id)
        current_model = VisionEncoderDecoderModel.from_pretrained(hub_id)
        current_model_name = model_name

    # Unconditional, idempotent move — mirrors the original behaviour of
    # placing the cached model on the GPU on every invocation.
    current_model = current_model.to('cuda')

    return current_processor, current_model
44
+
45
@spaces.GPU
def process_image(image, model_name):
    """Segment *image* into text lines with Kraken and transcribe each line.

    Parameters
    ----------
    image : PIL.Image.Image
        Page image uploaded through the Gradio UI.
    model_name : str
        Key of MODEL_OPTIONS selecting the TrOCR checkpoint.

    Returns
    -------
    tuple
        (image with detected baselines drawn in red,
         newline-joined transcription of all detected lines)

    Raises
    ------
    subprocess.CalledProcessError
        If the Kraken CLI exits non-zero.
    """
    # Save the uploaded image to a temporary file for Kraken.
    # JPEG cannot store an alpha channel, so normalise to RGB first
    # (the original crashed on RGBA/PNG uploads).
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_img:
        image.convert("RGB").save(temp_img, format="JPEG")
        temp_img_path = temp_img.name

    # Unique output path instead of a fixed "lines.json" so concurrent
    # requests cannot clobber each other's segmentation results.
    json_fd, lines_json_path = tempfile.mkstemp(suffix=".json")
    os.close(json_fd)

    try:
        # Run Kraken for baseline segmentation. Argument-list form with the
        # default shell=False avoids shell injection / quoting problems that
        # the original f-string + shell=True command had.
        subprocess.run(
            ["kraken", "-i", temp_img_path, lines_json_path,
             "binarize", "segment", "-bl"],
            check=True,
        )

        # Load the detected lines from the JSON file Kraken wrote.
        with open(lines_json_path, 'r') as f:
            lines_data = json.load(f)

        processor, model = load_model(model_name)

        # Transcribe each detected line independently.
        transcriptions = []
        for line in lines_data['lines']:
            # The original cropped with only the first/last baseline points,
            # which produces a degenerate (zero-height) box for horizontal
            # baselines. Prefer the line's boundary polygon when Kraken
            # provides one; otherwise bound all baseline points.
            # NOTE(review): exact schema depends on the Kraken version — confirm.
            points = line.get('boundary') or line['baseline']
            xs = [pt[0] for pt in points]
            ys = [pt[1] for pt in points]
            line_image = image.crop((min(xs), min(ys), max(xs), max(ys)))

            # Prepare the crop for TrOCR and move it to the GPU.
            pixel_values = processor(line_image, return_tensors="pt").pixel_values
            pixel_values = pixel_values.to('cuda')

            # Greedy generation (no beam search), no gradient bookkeeping.
            with torch.no_grad():
                generated_ids = model.generate(pixel_values)

            generated_text = processor.batch_decode(
                generated_ids, skip_special_tokens=True
            )[0]
            transcriptions.append(generated_text)
    finally:
        # Always remove the temp files, even if Kraken or the model raised
        # (the original leaked them on any failure).
        for path in (temp_img_path, lines_json_path):
            if os.path.exists(path):
                os.unlink(path)

    # Draw the detected baselines on the image for visual feedback.
    # JSON gives coordinates as lists; ImageDraw expects tuples.
    draw = ImageDraw.Draw(image)
    for line in lines_data['lines']:
        coords = [tuple(pt) for pt in line['baseline']]
        draw.line(coords, fill="red", width=2)

    return image, "\n".join(transcriptions)
96
+
97
# Gradio interface: page image + model picker in, annotated image + text out.
with gr.Blocks() as iface:
    gr.Markdown("# Medieval Document Transcription")
    gr.Markdown("Upload an image of a medieval document and select a model to transcribe it. The tool will detect lines and transcribe each line separately.")

    # Input row: the page to transcribe and the checkpoint to use.
    with gr.Row():
        page_input = gr.Image(type="pil", label="Input Image")
        model_choice = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), label="Select Model", value="Medieval Base")

    # Output row: image annotated with detected lines plus the transcription.
    with gr.Row():
        annotated_page = gr.Image(type="pil", label="Detected Lines")
        transcript_box = gr.Textbox(label="Transcription", lines=10)

    transcribe_btn = gr.Button("Transcribe")
    transcribe_btn.click(fn=process_image, inputs=[page_input, model_choice], outputs=[annotated_page, transcript_box])

iface.launch()