# ocr / app.py
# Last commit by yashnd: "Update app.py" (0149895, verified)
import streamlit as st
import torch
from PIL import Image
import requests
from io import BytesIO
from transformers import AutoProcessor
import sys
import os
import importlib.util
from huggingface_hub import snapshot_download
import easyocr
import re
from typing import Tuple
import numpy as np
# Force CPU execution: replace torch.cuda.is_available so that any code which
# consults it afterwards is told no GPU is present.
def _cuda_unavailable() -> bool:
    return False


torch.cuda.is_available = _cuda_unavailable
# Page-level Streamlit settings: tab title and full-width layout.
st.set_page_config(layout="wide", page_title="Bilingual OCR App")
@st.cache_resource
def setup_got_model() -> Tuple[object, object]:
    """Download GOT-OCR2_0 from the Hugging Face Hub and build (model, processor).

    Decorated with ``st.cache_resource`` so the loaded objects are reused
    across Streamlit reruns instead of being rebuilt each time.

    Returns:
        ``(model, processor)``: an instance of the
        ``GOTVisionT5ForConditionalGeneration`` class taken from the
        downloaded ``modeling_GOT.py``, and the object returned by
        ``AutoProcessor.from_pretrained`` for the same snapshot directory.

    Raises:
        Any exception from the download, the dynamic import, or
        model/processor loading propagates to the caller; ``main()`` catches
        it, shows a warning, and falls back to EasyOCR.
    """
    st.info("Using CPU for computation. This may take longer for processing.")
    # Download model files
    # NOTE(review): revision="main" tracks a moving branch, and Python code from
    # this snapshot is executed below; pinning a commit hash would make both
    # the weights and the executed code reproducible.
    model_path = snapshot_download("ucaslcl/GOT-OCR2_0", revision="main")
    # Dynamically import the custom model class
    # modeling_GOT.py is loaded as a standalone module named "modeling_GOT".
    # NOTE(review): it is neither imported as part of a package nor registered
    # in sys.modules; if the file relies on relative imports of sibling files
    # in the snapshot, exec_module will raise. Confirm against the snapshot.
    spec = importlib.util.spec_from_file_location("modeling_GOT", os.path.join(model_path, "modeling_GOT.py"))
    modeling_GOT = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(modeling_GOT)
    # NOTE(review): confirm modeling_GOT.py really defines this class name; if
    # not, the AttributeError reaches main(), which falls back to EasyOCR.
    GOTVisionT5ForConditionalGeneration = modeling_GOT.GOTVisionT5ForConditionalGeneration
    # Initialize the model and processor
    model = GOTVisionT5ForConditionalGeneration.from_pretrained(model_path, trust_remote_code=True)
    # NOTE(review): assumes the snapshot ships processor/tokenizer config files
    # that AutoProcessor can resolve. Confirm before relying on this path.
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    return model, processor
@st.cache_resource
def setup_easyocr() -> easyocr.Reader:
    """Create the EasyOCR reader used when the GOT model is unavailable.

    The reader recognizes English and Hindi and is built with ``gpu=False``.
    ``st.cache_resource`` keeps the constructed reader across reruns.
    """
    supported_languages = ['en', 'hi']
    return easyocr.Reader(supported_languages, gpu=False)
def perform_got_ocr(model, processor, image: Image.Image) -> str:
    """Transcribe *image* with the GOT model and return the first decoded string.

    The processor converts the image into PyTorch tensors (``return_tensors="pt"``),
    the model generates token ids capped at ``max_length=512``, and the
    processor decodes them with special tokens removed.
    """
    decoded_batch = processor.batch_decode(
        model.generate(**processor(images=image, return_tensors="pt"), max_length=512),
        skip_special_tokens=True,
    )
    return decoded_batch[0]
def perform_easyocr(ocr_reader, image: np.ndarray) -> str:
    """Run EasyOCR on *image* and return the recognized paragraphs.

    ``readtext`` is called with ``detail=0`` (text only, no boxes or scores)
    and ``paragraph=True``; the returned strings are joined with newlines.
    """
    paragraphs = ocr_reader.readtext(image, paragraph=True, detail=0)
    return "\n".join(paragraphs)
def highlight_keywords(text: str, keywords: str) -> str:
    """Return *text* as HTML with case-insensitive matches of *keywords* in <mark>.

    ``keywords`` is matched literally as one phrase (regex metacharacters are
    escaped); it is not split into separate words. Every fragment of *text* is
    HTML-escaped because the result is rendered with ``unsafe_allow_html=True``
    and the text comes from OCR of a user-supplied image, so it must not be
    able to inject markup.

    Args:
        text: The text to search.
        keywords: Phrase to highlight. An empty string highlights nothing.

    Returns:
        HTML-escaped *text* with each match wrapped in ``<mark>...</mark>``.
    """
    import html

    def _esc(fragment: str) -> str:
        # quote=False: the output is element content, not an attribute value.
        return html.escape(fragment, quote=False)

    if not keywords:
        # An empty pattern would match between every character and wrap
        # empty <mark></mark> tags around nothing all over the text.
        return _esc(text)

    pattern = re.compile(f"({re.escape(keywords)})", re.IGNORECASE)
    # With exactly one capturing group, re.split returns
    # [before, match, between, match, ..., after], so odd indices are matches.
    # Escaping each piece separately keeps the <mark> tags intact while
    # neutralizing any markup inside the text or the keyword.
    parts = pattern.split(text)
    return "".join(
        f"<mark>{_esc(part)}</mark>" if index % 2 else _esc(part)
        for index, part in enumerate(parts)
    )
def main():
    """Render the bilingual (Hindi/English) OCR Streamlit page.

    Flow: load an OCR backend (GOT-OCR2_0, falling back to EasyOCR), let the
    user upload an image, show the extracted text, and highlight occurrences
    of a user-supplied keyword in that text.
    """
    import hashlib

    # Initialize models. ``ocr_reader`` starts unset and is created on demand,
    # so EasyOCR is also available as a fallback when GOT loaded successfully
    # but fails on a particular image.
    ocr_reader = None
    try:
        with st.spinner("Setting up GOT model... This may take a few minutes."):
            model, processor = setup_got_model()
        ocr_model = 'GOT-OCR2_0'
    except Exception as e:
        st.warning(f"GOT model failed to load: {str(e)}. Falling back to EasyOCR.")
        model, processor = None, None
        ocr_reader = setup_easyocr()
        ocr_model = 'EasyOCR'

    st.title("Bilingual OCR Application")

    # Sidebar for instructions
    st.sidebar.header("Instructions")
    st.sidebar.markdown("""
1. **Upload an image** containing text in Hindi and/or English.
2. **Extracted text** will be displayed below.
3. **Enter keywords** to search within the extracted text.
4. **Matching sections** will be highlighted in the results.
""")

    # File uploader
    uploaded_file = st.file_uploader("Upload an Image for OCR", type=['png', 'jpg', 'jpeg'])
    if uploaded_file is None:
        st.info("Please upload an image file to get started.")
        return

    try:
        # getvalue() returns the whole upload regardless of the stream
        # position; read() returns only what remains after the cursor.
        image_bytes = uploaded_file.getvalue()
        image_pil = Image.open(BytesIO(image_bytes)).convert('RGB')
        st.image(image_pil, caption="Uploaded Image", use_column_width=True)

        # Streamlit reruns this whole script on every widget interaction
        # (e.g. entering a keyword). Remember the OCR result for this exact
        # image and backend so the expensive OCR step is not repeated.
        cache_key = (hashlib.sha256(image_bytes).hexdigest(), ocr_model)
        if st.session_state.get("ocr_cache_key") == cache_key:
            extracted_text = st.session_state["ocr_text"]
        else:
            with st.spinner(f"Performing OCR using {ocr_model}..."):
                extracted_text = None
                if model is not None and processor is not None:
                    try:
                        extracted_text = perform_got_ocr(model, processor, image_pil)
                    except Exception as e:
                        st.warning(f"GOT OCR failed: {str(e)}. Falling back to EasyOCR.")
                if extracted_text is None:
                    if ocr_reader is None:
                        ocr_reader = setup_easyocr()
                    # perform_easyocr takes a numpy array, not a PIL image.
                    extracted_text = perform_easyocr(ocr_reader, np.array(image_pil))
            st.session_state["ocr_cache_key"] = cache_key
            st.session_state["ocr_text"] = extracted_text

        st.subheader("Extracted Text:")
        st.text_area("Text", extracted_text, height=200)

        # Keyword search
        keyword = st.text_input("Enter keyword to search in extracted text:")
        if keyword:
            highlighted = highlight_keywords(extracted_text, keyword)
            st.subheader("Search Results:")
            st.markdown(highlighted, unsafe_allow_html=True)
    except Exception as e:
        # UI boundary: report any processing failure instead of crashing the page.
        st.error(f"An error occurred while processing the image: {str(e)}")
if __name__ == "__main__":
    # Launch the UI only when this file is executed as the main script.
    main()