|
import functools
import json
import os
import tempfile
import textwrap

import cv2
import gradio as gr
import numpy as np
from paddleocr import PaddleOCR
from PIL import Image, ImageDraw, ImageFont
from transformers import pipeline
|
|
|
|
|
# Field names that extract_field_value_pairs searches for in the recognised text.
FIELDS = [
    "Scheme Name",
    "Folio Number",
    "Number of Units",
    "PAN",
    "Signature",
    "Tax Status",
    "Mobile Number",
    "Email",
    "Address",
    "Bank Account Details",
]

# Single PaddleOCR engine, constructed once at import time and shared by every
# request; use_angle_cls=True also loads PaddleOCR's text-angle classifier.
ocr = PaddleOCR(use_angle_cls=True, lang="en")
|
|
|
def draw_boxes_on_image(image, data):
    """
    Draw OCR bounding boxes and numbered text labels onto an image.

    Args:
        image (PIL Image): The image to draw on. It is modified in place and
            also returned, so pass a copy if the original must stay untouched.
        data (list | None): OCR results for one page. Each item is
            ``[bounding_box, (text, confidence)]``, where ``bounding_box`` is a
            sequence of four ``(x, y)`` corner points. ``None`` is treated as
            "no detections".

    Returns:
        PIL Image: The same ``image`` object with boxes and labels drawn.
    """
    draw = ImageDraw.Draw(image)
    font_size = 20
    try:
        font = ImageFont.truetype("arial.ttf", font_size)
    except OSError:
        # arial.ttf is not available everywhere; fall back to PIL's built-in font.
        font = ImageFont.load_default()

    for item_id, (bounding_box, (text, confidence)) in enumerate(data or [], start=1):
        corners = [tuple(point) for point in np.asarray(bounding_box).astype(int).tolist()]
        # One closed outline: repeat the first corner so the closing edge is drawn.
        draw.line(corners + corners[:1], fill="green", width=2)

        # Put the label just above the first corner, but never above the top
        # edge of the image, where it would be clipped out of view.
        label_x, label_y = corners[0]
        text_position = (label_x, max(0, label_y - font_size))
        draw.text(text_position, f"{item_id}: {text} ({confidence:.2f})", fill="red", font=font)

    return image
|
|
|
def convert_to_json(results):
    """
    Convert one page of OCR results into a JSON-serialisable list of records.

    Args:
        results (list | None): OCR results, each item shaped
            ``[bounding_box, (text, confidence)]``. ``None`` is treated as
            "no detections".

    Returns:
        list[dict]: One record per detection, in input order, with keys
        ``id`` (1-based), ``bounding_box``, ``text`` and ``confidence``.
        NumPy arrays and scalars are converted to plain Python lists and
        floats, so the result can be passed directly to ``json.dump``.
    """
    return [
        {
            "id": item_id,
            # .tolist() turns NumPy arrays into nested Python lists/numbers;
            # plain nested lists come back unchanged.
            "bounding_box": np.asarray(result[0]).tolist(),
            "text": result[1][0],
            # float() also unwraps NumPy scalars such as np.float32, which json rejects.
            "confidence": float(result[1][1]),
        }
        for item_id, result in enumerate(results or [], start=1)
    ]
|
|
|
@functools.lru_cache(maxsize=1)
def _get_ner_pipeline():
    """Build the token-classification pipeline once and reuse it on later calls."""
    # NOTE(review): judging by its name, this checkpoint is an SMS-spam
    # *sequence* classifier rather than a token-classification (NER) model;
    # if so, the "ner" task gets an untrained token head and the LABEL_* tags
    # it emits are not meaningful. Confirm and consider a real NER/KIE model.
    return pipeline("ner", model="mrm8488/bert-tiny-finetuned-sms-spam-detection")


def extract_field_value_pairs(text, chunk_size=256):
    """
    Extract field-value pairs from OCR text using a token-classification pipeline.

    Args:
        text (str): The text to be processed.
        chunk_size (int): Maximum number of characters per chunk sent to the
            model. Chunks break on whitespace, so a single word longer than
            this becomes its own (oversized) chunk.

    Returns:
        dict: Maps a name from ``FIELDS`` to the most recent entity word that
        was tagged ``LABEL_1`` after that field name was seen. Empty when
        ``text`` has no words.
    """
    # Split on word boundaries so no token is cut in half between chunks.
    chunks = textwrap.wrap(text, width=chunk_size, break_long_words=False, break_on_hyphens=False)
    if not chunks:
        return {}

    nlp = _get_ner_pipeline()
    ner_results = []
    for chunk in chunks:
        ner_results.extend(nlp(chunk))

    fields_lower = [(field, field.lower()) for field in FIELDS]
    field_value_pairs = {}
    current_field = None
    for entity in ner_results:
        word = entity["word"]
        word_lower = word.lower()
        # NOTE(review): with the pipeline's default (no aggregation) each entity
        # is presumably a single sub-word token, so multi-word names such as
        # "Scheme Name" may never match here. TODO confirm.
        matched = next((field for field, lowered in fields_lower if lowered in word_lower), None)
        if matched is not None:
            current_field = matched
        # The last LABEL_1 token after a field mention wins.
        if current_field and entity["entity"] == "LABEL_1":
            field_value_pairs[current_field] = word

    return field_value_pairs
|
|
|
def process_image(image):
    """
    Run OCR on an uploaded image and export the results.

    Args:
        image (PIL Image | None): The uploaded image. ``None`` (for example
            when the input is cleared) short-circuits to ``(None, None, None)``.

    Returns:
        tuple: ``(image_with_boxes, ocr_results_path, extracted_fields_path)``.
        The first item is the annotated image; the other two are paths to the
        JSON files written for this call.
    """
    if image is None:
        return None, None, None

    # PaddleOCR is fed 3-channel BGR arrays (the layout cv2.imread produces).
    # PIL uploads may be RGB/RGBA/L, so normalise to RGB and then swap channels.
    # convert() returns a new image, so drawing on it later leaves `image` untouched.
    rgb_image = image.convert("RGB")
    ocr_results = ocr.ocr(cv2.cvtColor(np.array(rgb_image), cv2.COLOR_RGB2BGR), cls=True)

    # One entry per page; PaddleOCR reports a page with no text as None.
    page = (ocr_results[0] if ocr_results else None) or []

    image_with_boxes = draw_boxes_on_image(rgb_image, page)

    # A fresh directory per call, so concurrent requests cannot overwrite each
    # other's downloads. Nothing here deletes these directories afterwards.
    output_dir = tempfile.mkdtemp(prefix="ocr_")

    json_results_path = os.path.join(output_dir, "ocr_results.json")
    with open(json_results_path, "w", encoding="utf-8") as f:
        json.dump(convert_to_json(page), f, indent=4)

    text = " ".join(result[1][0] for result in page)
    field_value_pairs_path = os.path.join(output_dir, "extracted_fields.json")
    with open(field_value_pairs_path, "w", encoding="utf-8") as f:
        json.dump(extract_field_value_pairs(text), f, indent=4)

    return image_with_boxes, json_results_path, field_value_pairs_path
|
|
|
|
|
import gradio as gr |
|
|
|
|
|
# Web UI: upload one image; get back the annotated image plus two JSON downloads.
iface = gr.Interface(
    process_image,
    gr.Image(type="pil"),
    [
        gr.Image(type="pil"),                        # image with OCR boxes drawn
        gr.File(label="Download OCR Results"),       # per-detection OCR records
        gr.File(label="Download Extracted Fields"),  # field/value pairs
    ],
    live=True,  # re-run automatically whenever the input changes
)

if __name__ == "__main__":
    iface.launch()