import streamlit as st |
import torch |
from PIL import Image |
import requests |
from io import BytesIO |
from transformers import AutoProcessor |
import sys |
import os |
import importlib.util |
from huggingface_hub import snapshot_download |
import easyocr |
import re |
from typing import Tuple |
import numpy as np |
# Report CUDA as unavailable for the rest of this process, so any code that
# consults torch.cuda.is_available() takes its CPU branch.
torch.cuda.is_available = lambda: False

# Page-level Streamlit settings: browser tab title and wide layout.
st.set_page_config(
    page_title="Bilingual OCR App",
    layout="wide",
)
@st.cache_resource
def setup_got_model() -> Tuple[object, object]:
    """Download the GOT-OCR2_0 snapshot and build its model and processor.

    The snapshot's ``modeling_GOT.py`` is loaded as a module straight from the
    download directory, and the model class is looked up on that module.

    Returns:
        A ``(model, processor)`` tuple.

    Raises:
        Whatever the download, the module import, the class lookup or either
        ``from_pretrained`` call raises; nothing is caught here.
    """
    st.info("Using CPU for computation. This may take longer for processing.")

    # NOTE(review): this executes Python code fetched from the hub at the
    # moving "main" revision (and passes trust_remote_code=True below).
    # Consider pinning a commit hash -- confirm with the repo owners.
    snapshot_dir = snapshot_download("ucaslcl/GOT-OCR2_0", revision="main")

    # Import modeling_GOT.py by file path; it is not an installed package.
    module_file = os.path.join(snapshot_dir, "modeling_GOT.py")
    module_spec = importlib.util.spec_from_file_location("modeling_GOT", module_file)
    got_module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(got_module)

    model_cls = got_module.GOTVisionT5ForConditionalGeneration

    # Tuple elements are evaluated left to right: model first, then processor.
    return (
        model_cls.from_pretrained(snapshot_dir, trust_remote_code=True),
        AutoProcessor.from_pretrained(snapshot_dir, trust_remote_code=True),
    )
@st.cache_resource
def setup_easyocr() -> easyocr.Reader:
    """Create the EasyOCR reader for English and Hindi with the GPU disabled."""
    languages = ['en', 'hi']
    return easyocr.Reader(languages, gpu=False)
def perform_got_ocr(model, processor, image: Image.Image) -> str:
    """Run *image* through *processor* and *model* and return the decoded text.

    Args:
        model: Object exposing ``generate(**inputs, max_length=...)``.
        processor: Callable that turns the image into model inputs and also
            exposes ``batch_decode``.
        image: The image to recognise.

    Returns:
        The first entry of the decoded batch.
    """
    model_inputs = processor(images=image, return_tensors="pt")
    token_ids = model.generate(**model_inputs, max_length=512)
    decoded_batch = processor.batch_decode(token_ids, skip_special_tokens=True)
    # Only one image was submitted, so only the first decoded entry is used.
    return decoded_batch[0]
def perform_easyocr(ocr_reader, image: np.ndarray) -> str:
    """Read text from *image* with *ocr_reader* and join the pieces with newlines.

    Args:
        ocr_reader: Object exposing ``readtext(image, detail=..., paragraph=...)``.
        image: The image as a NumPy array.

    Returns:
        The strings returned by ``readtext``, one per line.
    """
    text_chunks = ocr_reader.readtext(image, detail=0, paragraph=True)
    return '\n'.join(text_chunks)
def highlight_keywords(text: str, keywords: str) -> str:
    """Return *text* as HTML with every occurrence of *keywords* in ``<mark>``.

    Matching is literal (regex metacharacters in *keywords* have no special
    meaning) and case-insensitive. The result is meant to be rendered as HTML,
    so all text coming from *text* is HTML-escaped; only the ``<mark>`` tags
    added here are live markup.

    Args:
        text: Plain text to search, e.g. OCR output.
        keywords: The phrase to highlight, treated as one literal string.

    Returns:
        An HTML string. When *keywords* is empty, the escaped text with no
        highlighting.
    """
    import html  # function-scope import: the file header does not import it

    # An empty pattern matches at every position and would insert an empty
    # <mark></mark> between all characters.
    if not keywords:
        return html.escape(text, quote=False)

    pattern = re.compile(re.escape(keywords), re.IGNORECASE)

    # Match against the raw text and escape each segment separately. Escaping
    # first would let a keyword such as "amp" match inside "&amp;".
    pieces = []
    cursor = 0
    for match in pattern.finditer(text):
        pieces.append(html.escape(text[cursor:match.start()], quote=False))
        pieces.append(f"<mark>{html.escape(match.group(0), quote=False)}</mark>")
        cursor = match.end()
    pieces.append(html.escape(text[cursor:], quote=False))
    return "".join(pieces)
def _load_ocr_backend():
    """Load GOT-OCR2_0, falling back to EasyOCR if its setup raises.

    Returns:
        A ``(ocr_model, model, processor, ocr_reader)`` tuple. All four values
        are always bound: on success ``ocr_reader`` is ``None``; on fallback
        ``model`` and ``processor`` are ``None``.
    """
    try:
        with st.spinner("Setting up GOT model... This may take a few minutes."):
            model, processor = setup_got_model()
    except Exception as e:
        # Top-level boundary: any setup failure switches the app to EasyOCR.
        st.warning(f"GOT model failed to load: {str(e)}. Falling back to EasyOCR.")
        return 'EasyOCR', None, None, setup_easyocr()
    return 'GOT-OCR2_0', model, processor, None


def main():
    """Render the OCR page: load a backend, accept an image, show and search text.

    Any exception raised while processing the uploaded image is reported on
    the page with ``st.error`` instead of propagating.
    """
    ocr_model, model, processor, ocr_reader = _load_ocr_backend()

    st.title("Bilingual OCR Application")
    st.sidebar.header("Instructions")
    st.sidebar.markdown("""
1. **Upload an image** containing text in Hindi and/or English.
2. **Extracted text** will be displayed below.
3. **Enter keywords** to search within the extracted text.
4. **Matching sections** will be highlighted in the results.
""")

    uploaded_file = st.file_uploader("Upload an Image for OCR", type=['png', 'jpg', 'jpeg'])
    if uploaded_file is None:
        st.info("Please upload an image file to get started.")
        return

    try:
        image_bytes = uploaded_file.read()
        image_pil = Image.open(BytesIO(image_bytes)).convert('RGB')
        image_np = np.array(image_pil)
        st.image(image_pil, caption="Uploaded Image", use_column_width=True)

        # Select the backend with explicit None checks, not object truthiness.
        # GOT is used exactly when its setup succeeded.
        use_got = model is not None and processor is not None
        with st.spinner(f"Performing OCR using {ocr_model}..."):
            if use_got:
                extracted_text = perform_got_ocr(model, processor, image_pil)
            else:
                extracted_text = perform_easyocr(ocr_reader, image_np)

        st.subheader("Extracted Text:")
        st.text_area("Text", extracted_text, height=200)

        keyword = st.text_input("Enter keyword to search in extracted text:")
        if keyword:
            highlighted = highlight_keywords(extracted_text, keyword)
            st.subheader("Search Results:")
            # highlight_keywords returns HTML, so raw HTML rendering is enabled.
            st.markdown(highlighted, unsafe_allow_html=True)
    except Exception as e:
        st.error(f"An error occurred while processing the image: {str(e)}")
# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()