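# OCR_Demo / app.py
# Author: Prasada — last commit e3ae3eb ("Update app.py", verified); file size 4.55 kB.
# (Copied from the hosting page's file viewer; its "raw", "history", "blame" and
# "No virus" scan labels are page UI, not part of the program.)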
import json
import os
import tempfile
from functools import lru_cache

import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from paddleocr import PaddleOCR
from transformers import pipeline
import gradio as gr
# One PaddleOCR engine, built once at import time and shared by every request.
# use_angle_cls=True enables PaddleOCR's text-direction classifier (see the PaddleOCR
# docs); process_image requests it per call with cls=True.
ocr = PaddleOCR(lang='en', use_angle_cls=True)
# Field names that extract_field_value_pairs looks for. They are checked in this
# order, and the first name that matches a token wins.
FIELDS = [
    "Scheme Name",
    "Folio Number",
    "Number of Units",
    "PAN",
    "Signature",
    "Tax Status",
    "Mobile Number",
    "Email",
    "Address",
    "Bank Account Details",
]
def draw_boxes_on_image(image, data):
    """
    Draw each OCR bounding box and a numbered label onto an image.

    Args:
        image (PIL Image): Image to draw on. It is modified in place and also
            returned; pass a copy if the original must stay untouched.
        data (list or None): OCR lines, each ``[box, (text, confidence)]`` where
            ``box`` is a sequence of (x, y) corner points. ``None`` or an empty
            list draws nothing.

    Returns:
        PIL Image: The same ``image`` object, with green outlines and red
        ``"<id>: <text> (<confidence>)"`` labels drawn. IDs start at 1.
    """
    if not data:
        return image

    draw = ImageDraw.Draw(image)
    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except OSError:
        # arial.ttf could not be loaded; fall back to Pillow's built-in font.
        font = ImageFont.load_default()

    for item_id, item in enumerate(data, start=1):
        bounding_box, (text, confidence) = item
        # Plain Python ints so Pillow receives ordinary coordinate tuples.
        corners = [tuple(point) for point in np.asarray(bounding_box).astype(int).tolist()]
        # Trace the outline and return to the first corner to close it.
        draw.line(corners + [corners[0]], fill="green", width=2)

        left, top = corners[0]
        # Label sits 20 px above the first corner. Clamp it so boxes touching the
        # top edge still get a visible label instead of one drawn off-canvas.
        text_position = (left, max(top - 20, 0))
        draw.text(text_position, f"{item_id}: {text} ({confidence:.2f})", fill="red", font=font)
    return image
def convert_to_json(results):
    """
    Convert OCR lines into JSON-serialisable records with 1-based IDs.

    IDs are assigned by enumerating ``results`` from 1, the same numbering
    draw_boxes_on_image prints next to each box.

    Args:
        results (list or None): OCR lines, each ``[bounding_box, (text, confidence)]``.
            ``None`` (no detections) is treated like an empty list.

    Returns:
        list[dict]: One ``{"id", "bounding_box", "text", "confidence"}`` record per
        line. numpy arrays/scalars are converted to plain lists/floats, so the
        result can be passed directly to ``json.dump``.
    """
    if results is None:
        return []
    return [
        {
            "id": item_id,
            # asarray(...).tolist() turns numpy boxes into nested Python lists and
            # leaves lists of Python numbers unchanged.
            "bounding_box": np.asarray(result[0]).tolist(),
            "text": result[1][0],
            "confidence": float(result[1][1]),
        }
        for item_id, result in enumerate(results, start=1)
    ]
# NOTE(review): judging by its name, this checkpoint is an SMS spam-detection
# classifier, not an NER model. Its LABEL_0/LABEL_1 token tags presumably do not
# mark field values; confirm, or switch to a real NER / key-information model.
_NER_MODEL = "mrm8488/bert-tiny-finetuned-sms-spam-detection"


@lru_cache(maxsize=2)
def _get_ner_pipeline(model_name):
    """Build the token-classification pipeline for ``model_name`` once and reuse it."""
    return pipeline("ner", model=model_name)


def _chunk_text(text, chunk_size):
    """Split ``text`` on whitespace into chunks of at most ``chunk_size`` characters."""
    chunks = []
    current = ""
    for word in text.split():
        # A single word longer than a chunk has to be hard-split.
        while len(word) > chunk_size:
            if current:
                chunks.append(current)
                current = ""
            chunks.append(word[:chunk_size])
            word = word[chunk_size:]
        candidate = f"{current} {word}" if current else word
        if len(candidate) > chunk_size:
            chunks.append(current)
            current = word
        else:
            current = candidate
    if current:
        chunks.append(current)
    return chunks


def extract_field_value_pairs(text, chunk_size=256, model_name=_NER_MODEL):
    """
    Extract field-value pairs from OCR text using a token-classification pipeline.

    The text is split into whitespace-aligned chunks, and each chunk is tagged by
    the pipeline. The tagged tokens are then scanned in order:
    - A token whose text contains one of FIELDS (case-insensitive substring; the
      first matching field wins) becomes the current field.
    - Every later token tagged ``LABEL_1`` is stored as that field's value. A later
      token replaces an earlier one.

    Args:
        text (str): Text to process, e.g. the OCR lines joined by spaces.
        chunk_size (int): Maximum characters per chunk sent to the model.
            Defaults to 256.
        model_name (str): Hugging Face model id for the pipeline.

    Returns:
        dict: Mapping of field name to the last value token assigned to it.
        Empty when ``text`` has no non-whitespace content.

    Raises:
        ValueError: If ``chunk_size`` is less than 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

    chunks = _chunk_text(text, chunk_size)
    if not chunks:
        # Nothing to tag, so there is no need to load the model.
        return {}

    nlp = _get_ner_pipeline(model_name)
    ner_results = []
    for chunk in chunks:
        ner_results.extend(nlp(chunk))

    field_keys = [(field, field.lower()) for field in FIELDS]
    field_value_pairs = {}
    current_field = None
    for entity in ner_results:
        word = entity['word']
        lowered = word.lower()
        # NOTE(review): the pipeline presumably yields one sub-word token per
        # entity, so multi-word names such as "Scheme Name" may never match. Short
        # names like "PAN" can also match inside unrelated words (e.g. "company").
        matched = next((field for field, key in field_keys if key in lowered), None)
        if matched is not None:
            current_field = matched
        if current_field and entity['entity'] == "LABEL_1":
            field_value_pairs[current_field] = word
    return field_value_pairs
def process_image(image):
    """
    Run OCR on an uploaded image and build the app's three outputs.

    Args:
        image (PIL Image or None): The uploaded image. Gradio passes None when
            the input is cleared; with live=True, clearing also triggers a call.

    Returns:
        tuple: (image with numbered OCR boxes drawn, path to ``ocr_results.json``,
        path to ``extracted_fields.json``). All three are None when ``image`` is None.
    """
    if image is None:
        return None, None, None

    # Normalise uploads (possibly RGBA, L or P) to 3-channel RGB. Then give OCR a
    # BGR array, the channel order cv2.imread produces for PaddleOCR's file inputs.
    rgb_image = image.convert("RGB")
    bgr_array = cv2.cvtColor(np.array(rgb_image), cv2.COLOR_RGB2BGR)
    ocr_results = ocr.ocr(bgr_array, cls=True)

    # One entry per page. PaddleOCR may report a page with no detected text as None.
    lines = (ocr_results[0] if ocr_results else None) or []

    # convert() returned a new image, so drawing on it leaves the upload untouched.
    image_with_boxes = draw_boxes_on_image(rgb_image, lines)

    # Give each request its own temp directory, so concurrent users cannot
    # overwrite each other's downloads. The file names stay the same.
    output_dir = tempfile.mkdtemp(prefix="ocr_demo_")

    json_results = convert_to_json(lines)
    json_results_path = os.path.join(output_dir, "ocr_results.json")
    with open(json_results_path, "w", encoding="utf-8") as f:
        json.dump(json_results, f, indent=4)

    # Feed every recognised line, joined by spaces, to the field extractor.
    text = " ".join(line[1][0] for line in lines)
    field_value_pairs = extract_field_value_pairs(text)
    field_value_pairs_path = os.path.join(output_dir, "extracted_fields.json")
    with open(field_value_pairs_path, "w", encoding="utf-8") as f:
        json.dump(field_value_pairs, f, indent=4)

    return image_with_boxes, json_results_path, field_value_pairs_path
# Gradio front end. One PIL image goes in. The three values returned by
# process_image come out, in order: the annotated image and two downloadable JSON files.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil"),                        # image with OCR boxes drawn
        gr.File(label="Download OCR Results"),       # ocr_results.json
        gr.File(label="Download Extracted Fields"),  # extracted_fields.json
    ],
    # Ask Gradio to re-run process_image whenever the input changes.
    live=True,
)

if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    iface.launch()