|
import streamlit as st |
|
import torch |
|
from PIL import Image |
|
import requests |
|
from io import BytesIO |
|
from transformers import AutoProcessor |
|
import sys |
|
import os |
|
import importlib.util |
|
from huggingface_hub import snapshot_download |
|
import easyocr |
|
import re |
|
from typing import Tuple |
|
import numpy as np |
|
|
|
|
|
def _cuda_unavailable() -> bool:
    """Always report CUDA as unavailable so torch-based code stays on the CPU."""
    return False


# Force CPU execution: anything that looks up torch.cuda.is_available() after
# this assignment (including model code loaded later) sees False.
torch.cuda.is_available = _cuda_unavailable

# Page configuration. Streamlit expects this before other st.* calls that
# render content, so it sits at module level ahead of main().
st.set_page_config(page_title="Bilingual OCR App", layout="wide")
|
|
|
@st.cache_resource
def setup_got_model() -> Tuple[object, object]:
    """Download GOT-OCR2_0 and load its model and processor on the CPU.

    The model class comes from the ``modeling_GOT.py`` file shipped inside the
    downloaded snapshot, so that file's code is executed here.
    ``st.cache_resource`` keeps the loaded pair for reuse across reruns.

    Returns:
        ``(model, processor)``.

    Raises:
        ImportError: if ``modeling_GOT.py`` cannot be turned into an import spec.
        Any exception raised while downloading, executing the module, or
        calling ``from_pretrained`` propagates to the caller unchanged.
    """
    st.info("Using CPU for computation. This may take longer for processing.")

    # Download the repository, or reuse the local Hugging Face cache.
    model_path = snapshot_download("ucaslcl/GOT-OCR2_0", revision="main")

    module_name = "modeling_GOT"
    module_file = os.path.join(model_path, "modeling_GOT.py")
    spec = importlib.util.spec_from_file_location(module_name, module_file)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for {module_file}")

    modeling_GOT = importlib.util.module_from_spec(spec)
    # Register the module before executing it, as in the importlib "import a
    # source file directly" recipe. Code that resolves its own module through
    # sys.modules (dataclasses, pickling, typing.get_type_hints) then works.
    sys.modules[module_name] = modeling_GOT
    try:
        spec.loader.exec_module(modeling_GOT)
    except Exception:
        # Do not leave a half-initialised module behind for later lookups.
        sys.modules.pop(module_name, None)
        raise

    # NOTE(review): confirm that this snapshot's modeling_GOT.py defines
    # GOTVisionT5ForConditionalGeneration and uses only absolute imports.
    # A missing class or a package-relative import in that file makes
    # loading fail here.
    GOTVisionT5ForConditionalGeneration = modeling_GOT.GOTVisionT5ForConditionalGeneration

    model = GOTVisionT5ForConditionalGeneration.from_pretrained(model_path, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

    return model, processor
|
|
|
@st.cache_resource
def setup_easyocr() -> easyocr.Reader:
    """Construct an EasyOCR reader for English and Hindi with GPU disabled.

    ``st.cache_resource`` makes Streamlit build the reader once and reuse it
    on later calls.
    """
    languages = ['en', 'hi']
    return easyocr.Reader(languages, gpu=False)
|
|
|
def perform_got_ocr(model, processor, image: Image.Image) -> str:
    """Run generation on ``image`` and return the first decoded sequence.

    ``processor`` must be callable as ``processor(images=..., return_tensors="pt")``,
    returning keyword arguments for ``model.generate``, and must provide
    ``batch_decode``. This is the interface of a Hugging Face processor.
    Generation is capped at ``max_length=512`` tokens.
    """
    model_inputs = processor(images=image, return_tensors="pt")
    output_ids = model.generate(**model_inputs, max_length=512)
    decoded_batch = processor.batch_decode(output_ids, skip_special_tokens=True)
    # A single image was passed, so only the first decoded string is relevant.
    return decoded_batch[0]
|
|
|
def perform_easyocr(ocr_reader, image: np.ndarray) -> str:
    """OCR ``image`` with ``ocr_reader`` and return the paragraphs newline-joined.

    ``detail=0`` asks the reader for text only (no boxes or confidences).
    ``paragraph=True`` asks it to merge nearby detections into paragraphs.
    """
    paragraphs = ocr_reader.readtext(image, detail=0, paragraph=True)
    return "\n".join(paragraphs)
|
|
|
def highlight_keywords(text: str, keywords: str) -> str:
    """Return ``text`` as HTML-safe markup with every match of ``keywords`` in <mark>.

    ``keywords`` is matched literally (regex metacharacters are escaped) and
    case-insensitively. Both the surrounding text and the matched spans are
    HTML-escaped, because the result is meant to be rendered as HTML and
    ``text`` can come from OCR of an untrusted upload. Escaping each piece
    *after* splitting keeps a keyword such as ``amp`` from matching inside an
    ``&amp;`` entity.

    An empty ``keywords`` string returns the escaped text with no highlighting.
    An empty pattern would otherwise match between every character.
    """
    import html  # local import keeps this helper self-contained

    if not keywords:
        return html.escape(text, quote=False)

    pattern = re.compile(f"({re.escape(keywords)})", re.IGNORECASE)
    # With one capturing group, split() alternates [non-match, match, non-match, ...],
    # so odd indices are exactly the matched spans.
    pieces = pattern.split(text)
    parts = []
    for index, piece in enumerate(pieces):
        safe = html.escape(piece, quote=False)
        parts.append(f"<mark>{safe}</mark>" if index % 2 else safe)
    return "".join(parts)
|
|
|
def _load_ocr_backends():
    """Load GOT-OCR2_0, or set up EasyOCR if GOT fails to load.

    Returns:
        ``(ocr_model, model, processor, ocr_reader)``, where ``ocr_model`` is
        ``'GOT-OCR2_0'`` or ``'EasyOCR'`` and the slots the chosen engine does
        not use are ``None``.
    """
    try:
        with st.spinner("Setting up GOT model... This may take a few minutes."):
            model, processor = setup_got_model()
        return 'GOT-OCR2_0', model, processor, None
    except Exception as e:
        st.warning(f"GOT model failed to load: {str(e)}. Falling back to EasyOCR.")
        return 'EasyOCR', None, None, setup_easyocr()


def _run_ocr(ocr_model, model, processor, ocr_reader, image_pil, image_np) -> str:
    """Run the selected OCR engine, falling back to EasyOCR if GOT inference fails."""
    if ocr_model == 'GOT-OCR2_0' and model is not None and processor is not None:
        try:
            return perform_got_ocr(model, processor, image_pil)
        except Exception as e:
            st.warning(f"GOT OCR failed: {e}. Falling back to EasyOCR.")
    if ocr_reader is None:
        # GOT was loaded, so EasyOCR was never set up; build it on demand (cached).
        ocr_reader = setup_easyocr()
    return perform_easyocr(ocr_reader, image_np)


def main():
    """Render the app: upload an image, OCR it, and search the extracted text."""
    import hashlib  # only used here, to key the per-image OCR cache

    ocr_model, model, processor, ocr_reader = _load_ocr_backends()

    st.title("Bilingual OCR Application")

    st.sidebar.header("Instructions")
    st.sidebar.markdown("""

    1. **Upload an image** containing text in Hindi and/or English.

    2. **Extracted text** will be displayed below.

    3. **Enter keywords** to search within the extracted text.

    4. **Matching sections** will be highlighted in the results.

    """)

    uploaded_file = st.file_uploader("Upload an Image for OCR", type=['png', 'jpg', 'jpeg'])

    if uploaded_file is None:
        st.info("Please upload an image file to get started.")
        return

    try:
        # UploadedFile is a BytesIO subclass: getvalue() returns the whole
        # upload regardless of the stream position, whereas read() returns
        # b"" once the stream has been consumed.
        image_bytes = uploaded_file.getvalue()
        image_pil = Image.open(BytesIO(image_bytes)).convert('RGB')
        image_np = np.array(image_pil)

        st.image(image_pil, caption="Uploaded Image", use_column_width=True)

        # Streamlit reruns the whole script on every widget interaction, such
        # as entering a keyword below. Cache the OCR result per (image, engine)
        # in session state so a keyword search does not repeat the OCR.
        cache_key = (hashlib.sha256(image_bytes).hexdigest(), ocr_model)
        cached = st.session_state.get("_ocr_result")
        if cached is not None and cached[0] == cache_key:
            extracted_text = cached[1]
        else:
            with st.spinner(f"Performing OCR using {ocr_model}..."):
                extracted_text = _run_ocr(
                    ocr_model, model, processor, ocr_reader, image_pil, image_np
                )
            st.session_state["_ocr_result"] = (cache_key, extracted_text)

        st.subheader("Extracted Text:")
        st.text_area("Text", extracted_text, height=200)

        keyword = st.text_input("Enter keyword to search in extracted text:")
        if keyword:
            highlighted = highlight_keywords(extracted_text, keyword)
            st.subheader("Search Results:")
            # Markdown collapses single newlines into spaces. A trailing double
            # space forces a hard line break, so the OCR line structure survives.
            st.markdown(highlighted.replace("\n", "  \n"), unsafe_allow_html=True)

    except Exception as e:
        st.error(f"An error occurred while processing the image: {str(e)}")


if __name__ == "__main__":
    main()