# app.py — last commit by Pclanglais: "Update app.py" (8abf603, verified), 3.09 kB.
import difflib
import html
import os
import re
from concurrent.futures import ThreadPoolExecutor

import gradio as gr
import spaces
import torch
import transformers
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# OCR Correction Model
model_name = "PleIAs/OCRonos-Vintage"
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Load the tokenizer and the pre-trained causal LM named by model_name,
# moving the model weights onto the selected device.
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
# Inline <style> block that process_text prepends to its HTML report:
# `.generation` adds side margins and a larger font to the corrected text,
# `.inserted` gives words added by the correction a light-green background.
css = """
<style>
.generation {
margin-left: 2em;
margin-right: 2em;
font-size: 1.2em;
}
.inserted {
background-color: #90EE90;
}
</style>
"""
def generate_html_diff(old_text, new_text):
    """Render *new_text* as HTML, highlighting words that are not in *old_text*.

    Both texts are split on whitespace and compared word by word with
    ``difflib.Differ``:

    * words common to both texts are emitted unchanged;
    * words only in *new_text* are wrapped in ``<span class="inserted">``;
    * words only in *old_text* (``- `` lines) and Differ's ``? `` hint lines
      are dropped.

    Every word is HTML-escaped. The result is rendered as raw HTML, and the
    compared text comes from the user and the model, so unescaped ``<``/``&``
    would otherwise be interpreted as markup.

    Args:
        old_text: The original text.
        new_text: The corrected text.

    Returns:
        The kept words joined with single spaces. Original line breaks and
        runs of whitespace are not preserved.
    """
    pieces = []
    for line in difflib.Differ().compare(old_text.split(), new_text.split()):
        # Differ prefixes each line with a two-character tag: "  ", "+ ", "- " or "? ".
        tag, word = line[:2], html.escape(line[2:])
        if tag == '  ':
            pieces.append(word)
        elif tag == '+ ':
            pieces.append(f'<span class="inserted">{word}</span>')
    return ' '.join(pieces)
def split_text(text, max_tokens=400):
    """Split *text* into consecutive chunks of at most *max_tokens* tokens.

    Chunks are cut on word boundaries where possible, so no word is torn in
    half between two model calls. process_text re-joins corrected chunks with
    a space, so a mid-word cut would otherwise show up as a spurious split
    word. A chunk falls back to exactly *max_tokens* tokens only when its
    window contains no word boundary. Values of *max_tokens* below 1 are
    treated as 1 (one token per chunk).

    Args:
        text: The text to split.
        max_tokens: Maximum number of tokenizer tokens per chunk.

    Returns:
        A list of chunk strings, empty when *text* tokenizes to nothing.
    """
    # GPT-2's byte-level BPE encodes a leading space as "\u0120" ("Ġ") and a
    # newline as "\u010a" ("Ċ"), so a token with either prefix starts a new word.
    word_start_prefixes = ("\u0120", "\u010a")
    tokens = tokenizer.tokenize(text)
    limit = max(max_tokens, 1)
    chunks = []
    start = 0
    while start < len(tokens):
        end = min(start + limit, len(tokens))
        if end < len(tokens):
            # Pull the cut back so the next chunk begins on a word-initial token.
            cut = end
            while cut > start and not tokens[cut].startswith(word_start_prefixes):
                cut -= 1
            if cut > start:
                end = cut
        chunks.append(tokenizer.convert_tokens_to_string(tokens[start:end]))
        start = end
    return chunks
@spaces.GPU
def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
    """Run the OCR-correction model on one chunk of text and return its correction.

    The chunk is wrapped in the "### Text ###" / "### Correction ###" prompt
    template and decoded greedily (``do_sample=False``).

    Args:
        prompt: Raw, possibly OCR-damaged, text to correct.
        max_new_tokens: Upper bound on the number of generated tokens.
        num_threads: Unused; kept so existing callers that pass it keep working.

    Returns:
        The generated correction with surrounding whitespace stripped.
    """
    prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    output = model.generate(
        input_ids,
        # Explicit mask: pad_token_id equals eos_token_id, so an inferred mask
        # would hide any end-of-text token inside the prompt.
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        num_return_sequences=1,
        do_sample=False,  # greedy decoding; top_k would be ignored, so it is not passed
    )
    # Decode only the tokens generated after the prompt. Splitting the full
    # decode on "### Correction ###" returns the wrong segment when the user's
    # text itself contains that marker.
    generated = output[0][input_ids.shape[-1]:]
    return tokenizer.decode(generated, skip_special_tokens=True).strip()
def process_text(user_message):
    """Correct *user_message* chunk by chunk and return an HTML report.

    The message is split with split_text and every chunk is corrected with
    ocr_correction, in order. The corrected chunks are joined with single
    spaces. The report is the ``css`` style block, then a centered heading,
    then a word diff (generate_html_diff) of the original against the
    corrected text.
    """
    corrected_text = ' '.join(ocr_correction(piece) for piece in split_text(user_message))
    diff_markup = generate_html_diff(user_message, corrected_text)
    report = f'<h2 style="text-align:center">OCR Correction</h2>\n<div class="generation">{diff_markup}</div>'
    return css + report
# Gradio UI: a title, a text box for the raw text, a button that triggers
# process_text, and an HTML panel that displays the returned report.
with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
    gr.HTML("""<h1 style="text-align:center">Vintage OCR corrector (Zero-GPU)</h1>""")
    text_input = gr.Textbox(label="Your (bad?) text", type="text", lines=5)
    process_button = gr.Button("Process Text")
    text_output = gr.HTML(label="Processed text")
    process_button.click(fn=process_text, inputs=[text_input], outputs=text_output)

if __name__ == "__main__":
    # Enable Gradio's request queue, then start the web server.
    queued_demo = demo.queue()
    queued_demo.launch()