# Hugging Face Space status: Running on Zero (ZeroGPU hardware)
import base64
import io
import os
import tempfile
import urllib.request

import gradio as gr
import numpy as np
import spaces
import torch
from loadimg import load_img
from PIL import Image
from transformers import AutoConfig, AutoModel, AutoTokenizer
# Page heading and introductory Markdown rendered at the top of the demo.
title = """# 🙋🏻♂️Welcome to Tonic's🫴🏻📸GOT-OCR"""
# The opening delimiter is exactly three quotes; an extra quote would become
# a literal `"` at the start of the rendered Markdown.
description = """
The GOT-OCR model is a revolutionary step in the evolution of OCR systems, boasting 580M parameters and the ability to process various forms of "characters." It features a high-compression encoder and a long-context decoder, making it well-suited for both scene- and document-style images. The model also supports multi-page and dynamic resolution OCR for added practicality.
The model can output results in a variety of formats, including plain text, markdown, and even complex outputs like TikZ diagrams or molecular SMILES strings. Interactive OCR allows users to specify regions of interest for OCR using coordinates or colors.
## Features
- **Plain Text OCR**: Recognizes and extracts plain text from images.
- **Formatted Text OCR**: Extracts text while preserving its formatting (tables, formulas, etc.).
- **Fine-grained OCR**: Box-based and color-based OCR for precise text extraction from specific regions.
- **Multi-crop OCR**: Processes multiple cropped regions within an image.
- **Rendered Formatted OCR Results**: Outputs OCR results in markdown, TikZ, SMILES, or other formats with rendered formatting.
GOT-OCR-2.0 can handle:
- Plain text
- Math/molecular formulas
- Tables
- Charts
- Sheet music
- Geometric shapes
## How to Use
1. Select a task from the dropdown menu.
2. Upload an image.
3. (Optional) Fill in additional parameters based on the task.
4. Click **Process** to see the results.
---
### Join us :
🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️community 👻 [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/qdfnvSPcqP) On 🤗Huggingface:[MultiTransformer](https://huggingface.co./MultiTransformer) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to🌟 [Build Tonic](https://git.tonic-ai.com/contribute)🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
"""
# Single source of truth for the checkpoint id; every loader below uses it.
model_name = 'ucaslcl/GOT-OCR2_0'

# trust_remote_code=True loads the checkpoint's own modeling code.
# Presumably that code provides the `chat` / `chat_crop` methods called
# further down, which are not part of the standard transformers API.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# NOTE(review): `config` is not referenced elsewhere in the visible code; it is
# kept as a module-level name in case other code imports it.
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(
    model_name,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    device_map='cuda',
    use_safetensors=True,
    # Reuse EOS as the padding token so generation does not warn about a
    # missing pad token.
    pad_token_id=tokenizer.eos_token_id,
)
# Inference only: eval() disables dropout. The model is already placed on CUDA
# via device_map; .cuda() is kept to mirror the original loading sequence.
model = model.eval().cuda()
model.config.pad_token_id = tokenizer.eos_token_id
def load_image(image_file):
    """Return *image_file* as a new RGB ``PIL.Image.Image``.

    Args:
        image_file: A local file path, an ``http://`` / ``https://`` URL, or an
            object exposing a PIL-style ``convert(mode)`` method (e.g. a PIL
            image).

    Returns:
        The image converted to RGB mode.

    Raises:
        ValueError: If *image_file* is ``None`` (for example, no image was
            uploaded in the UI).
    """
    if image_file is None:
        raise ValueError("No image provided. Please upload an image.")
    if isinstance(image_file, str):
        # Match on the full scheme so a local file named e.g. "http_scan.png"
        # is not mistaken for a URL.
        if image_file.startswith(("http://", "https://")):
            # Fetch with the standard library and buffer the bytes, because
            # PIL needs a seekable stream. The timeout prevents a stalled
            # server from hanging the request forever.
            with urllib.request.urlopen(image_file, timeout=30) as response:
                data = response.read()
            with Image.open(io.BytesIO(data)) as remote_img:
                return remote_img.convert('RGB')
        # convert() yields a new image, so the source file handle can be
        # closed as soon as the context manager exits.
        with Image.open(image_file) as local_img:
            return local_img.convert('RGB')
    return image_file.convert('RGB')
# On ZeroGPU Spaces, CUDA work must run inside a function decorated with
# @spaces.GPU; outside ZeroGPU the decorator is documented as having no effect.
@spaces.GPU
def process_image(image, task, ocr_type=None, ocr_box=None, ocr_color=None, render=False):
    """Run the selected GOT-OCR *task* on *image*.

    Args:
        image: Anything accepted by ``load_image`` (file path, URL or image
            object with a ``convert`` method).
        task: One of the task labels offered in the UI dropdown.
        ocr_type: ``"ocr"`` or ``"format"`` for the fine-grained tasks;
            defaults to ``"ocr"`` when nothing was selected.
        ocr_box: Region as ``"x1,y1,x2,y2"`` (surrounding brackets optional),
            required for "Fine-grained OCR (Box)".
        ocr_color: ``"red"``, ``"green"`` or ``"blue"``, required for
            "Fine-grained OCR (Color)".
        render: Unused; kept for backward compatibility. Rendering is chosen
            by the "Render Formatted OCR" task.

    Returns:
        ``(result_text, html)``: *html* is the rendered page for
        "Render Formatted OCR" and ``None`` otherwise. On any failure the
        first element is the error message and *html* is ``None``.
    """
    try:
        img = load_image(image)
        # A fresh scratch directory per call means concurrent requests never
        # share files. It is deleted on exit, including when the model call
        # raises or a branch returns early.
        with tempfile.TemporaryDirectory() as workdir:
            img_path = os.path.join(workdir, "input.png")
            img.save(img_path)
            return _run_task(task, img_path, workdir, ocr_type, ocr_box, ocr_color)
    except Exception as e:
        # UI boundary: surface the message in the result box instead of
        # crashing the Gradio handler.
        return str(e), None


def _run_task(task, img_path, workdir, ocr_type, ocr_box, ocr_color):
    """Dispatch *task* to the matching GOT-OCR call and return ``(text, html)``."""
    if task == "Plain Text OCR":
        return model.chat(tokenizer, img_path, ocr_type='ocr'), None
    if task == "Format Text OCR":
        return model.chat(tokenizer, img_path, ocr_type='format'), None
    if task == "Fine-grained OCR (Box)":
        res = model.chat(
            tokenizer, img_path,
            ocr_type=ocr_type or 'ocr',
            ocr_box=_parse_ocr_box(ocr_box),
        )
        return res, None
    if task == "Fine-grained OCR (Color)":
        if not ocr_color:
            raise ValueError("Please select an OCR color.")
        res = model.chat(tokenizer, img_path, ocr_type=ocr_type or 'ocr', ocr_color=ocr_color)
        return res, None
    if task == "Multi-crop OCR":
        return model.chat_crop(tokenizer, image_file=img_path), None
    if task == "Render Formatted OCR":
        # Render into the per-request directory instead of a shared
        # ./results/demo.html, which may not exist and races between users.
        render_path = os.path.join(workdir, "render.html")
        res = model.chat(tokenizer, img_path, ocr_type='format', render=True, save_render_file=render_path)
        with open(render_path, 'r', encoding='utf-8') as f:
            return res, f.read()
    raise ValueError(f"Unknown task: {task!r}")


def _parse_ocr_box(ocr_box):
    """Validate a user-entered box and return it as ``"[x1,y1,x2,y2]"``."""
    # NOTE(review): this string is handed to the checkpoint's remote `chat`
    # code, which parses it itself. Normalizing it to a list literal of four
    # integers keeps arbitrary user text away from that parser. Confirm the
    # expected format against the checkpoint's modeling code.
    if not ocr_box or not ocr_box.strip():
        raise ValueError("Please enter an OCR box as x1,y1,x2,y2.")
    parts = ocr_box.strip().strip("[]()").split(",")
    try:
        coords = [int(part.strip()) for part in parts]
    except ValueError:
        raise ValueError(
            f"Invalid OCR box {ocr_box!r}: expected four integers x1,y1,x2,y2."
        ) from None
    if len(coords) != 4:
        raise ValueError(
            f"Invalid OCR box {ocr_box!r}: expected four integers x1,y1,x2,y2."
        )
    return "[" + ",".join(str(c) for c in coords) + "]"
def update_inputs(task):
    """Compute visibility updates for the task-specific input controls.

    The returned list holds four ``gr.update`` values, in this order: the
    OCR-type dropdown, the box textbox, the color dropdown and the render
    checkbox. A task label outside the known set yields ``None``.
    """
    if task in ("Plain Text OCR", "Format Text OCR", "Multi-crop OCR"):
        return [gr.update(visible=False)] * 4
    if task == "Render Formatted OCR":
        return [gr.update(visible=False)] * 3 + [gr.update(visible=True)]
    if task not in ("Fine-grained OCR (Box)", "Fine-grained OCR (Color)"):
        return None

    # Both fine-grained modes expose the OCR-type selector. Exactly one of
    # the box / color inputs is shown, depending on the mode.
    box_mode = task == "Fine-grained OCR (Box)"
    color_update = (
        gr.update(visible=False)
        if box_mode
        else gr.update(visible=True, choices=["red", "green", "blue"])
    )
    return [
        gr.update(visible=True, choices=["ocr", "format"]),
        gr.update(visible=box_mode),
        color_update,
        gr.update(visible=False),
    ]
def ocr_demo(image, task, ocr_type, ocr_box, ocr_color):
    """Gradio click handler: run the OCR task and return ``(text, html)``.

    Any falsy HTML payload (``None`` or an empty string) is reported as
    ``None`` so the HTML output component is cleared.
    """
    text, html = process_image(image, task, ocr_type, ocr_box, ocr_color)
    return text, (html or None)
# Build the Gradio UI: inputs on the left, results on the right, then wire the
# task selector (visibility of optional controls) and the Process button.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            # type="filepath" hands the handler a local path string.
            image_input = gr.Image(type="filepath", label="Input Image")
            task_dropdown = gr.Dropdown(
                choices=[
                    "Plain Text OCR",
                    "Format Text OCR",
                    "Fine-grained OCR (Box)",
                    "Fine-grained OCR (Color)",
                    "Multi-crop OCR",
                    "Render Formatted OCR",
                ],
                label="Select Task",
                value="Plain Text OCR",
            )
            # Optional controls start hidden; update_inputs toggles them per task.
            ocr_type_dropdown = gr.Dropdown(
                choices=["ocr", "format"],
                label="OCR Type",
                visible=False,
            )
            ocr_box_input = gr.Textbox(
                label="OCR Box (x1,y1,x2,y2)",
                placeholder="e.g., 100,100,200,200",
                visible=False,
            )
            ocr_color_dropdown = gr.Dropdown(
                choices=["red", "green", "blue"],
                label="OCR Color",
                visible=False,
            )
            # NOTE(review): this checkbox is made visible for "Render Formatted
            # OCR", but its value is not among the click handler's inputs.
            # Rendering is selected by the task itself.
            render_checkbox = gr.Checkbox(
                label="Render Result",
                visible=False,
            )
            submit_button = gr.Button("Process")
        with gr.Column():
            output_text = gr.Textbox(label="OCR Result")
            output_html = gr.HTML(label="Rendered HTML Output")
    # Markdown body lines stay at column 0: leading spaces would turn them
    # into a Markdown code block.
    gr.Markdown("""## GOT-OCR 2.0
This compact **580M-parameter** OCR model can handle various text recognition tasks with high accuracy.
### Model Information
- **Model Name**: GOT-OCR 2.0
- **Hugging Face Repository**: [ucaslcl/GOT-OCR2_0](https://huggingface.co./ucaslcl/GOT-OCR2_0)
- **Environment**: CUDA 11.8 + PyTorch 2.0.1
""")
    # The output order must match update_inputs' four-element return list.
    task_dropdown.change(
        update_inputs,
        inputs=[task_dropdown],
        outputs=[ocr_type_dropdown, ocr_box_input, ocr_color_dropdown, render_checkbox],
    )
    submit_button.click(
        ocr_demo,
        inputs=[image_input, task_dropdown, ocr_type_dropdown, ocr_box_input, ocr_color_dropdown],
        outputs=[output_text, output_html],
    )

if __name__ == "__main__":
    demo.launch()