Prasada committed
Commit 024ddf4 (1 parent: 348ac05)

Update app.py

Files changed (1)
  1. app.py +85 -74
app.py CHANGED
@@ -3,119 +3,130 @@ import numpy as np
  from PIL import Image, ImageDraw, ImageFont
  import json
  from paddleocr import PaddleOCR
  import gradio as gr
- import os
 
  # Initialize PaddleOCR
  ocr = PaddleOCR(use_angle_cls=True, lang='en')
 
- # Function to draw bounding boxes on the image
- def draw_boxes_on_image(image, data):
-     # Convert the image to RGB (OpenCV uses BGR by default)
-     image_rgb = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB)
-
-     # Load the image into PIL for easier drawing
-     pil_image = Image.fromarray(image_rgb)
-     draw = ImageDraw.Draw(pil_image)
 
-     # Define a font (using DejaVuSans since it's available by default)
      try:
-         font = ImageFont.truetype("DejaVuSans.ttf", 20)
      except IOError:
          font = ImageFont.load_default()
 
-     for item in data:
          bounding_box, (text, confidence) = item
-
-         # Ensure bounding_box is a list of lists
-         if not isinstance(bounding_box[0], list):
-             bounding_box = [bounding_box]
-
          box = np.array(bounding_box).astype(int)
-
-         # Draw the bounding box
          draw.line([tuple(box[0]), tuple(box[1])], fill="green", width=2)
          draw.line([tuple(box[1]), tuple(box[2])], fill="green", width=2)
          draw.line([tuple(box[2]), tuple(box[3])], fill="green", width=2)
          draw.line([tuple(box[3]), tuple(box[0])], fill="green", width=2)
-
-         # Draw the text above the bounding box
          text_position = (box[0][0], box[0][1] - 20)
-         draw.text(text_position, f"{text} ({confidence:.2f})", fill="red", font=font)
 
-     return pil_image
 
- # Function to convert OCR results to JSON
- def convert_to_json(results, output_file):
      """
-     Converts the given results into a JSON file.
-
      Args:
-         results: The list of results containing bounding box coordinates, text, and confidence.
-         output_file: The name of the output JSON file.
      """
      json_data = []
-     for result in results:
          bounding_box = result[0]
          text = result[1][0]
          confidence = result[1][1]
-
          json_data.append({
-             "bounding_box": [list(map(float, coord)) for coord in bounding_box],
              "text": text,
              "confidence": confidence
          })
 
-     with open(output_file, "w") as f:
-         json.dump(json_data, f, indent=4)
-
- # Function to identify 'field', 'value' pairs
- def identify_field_value_pairs(ocr_results, fields):
      field_value_pairs = {}
-     for line in ocr_results:
-         for word_info in line:
-             text, _ = word_info[1]
-             for field in fields:
-                 if field.lower() in text.lower():
-                     # Assuming the value comes immediately after the field
-                     value_index = line.index(word_info) + 1
-                     if value_index < len(line):
-                         field_value_pairs[field] = line[value_index][1][0]
-                     break
      return field_value_pairs
 
- # Function to process the image and generate outputs
  def process_image(image):
      ocr_results = ocr.ocr(np.array(image), cls=True)
-     processed_image = draw_boxes_on_image(image, ocr_results[0])
 
-     # Save OCR results to JSON
-     json_path = "ocr_results.json"
-     convert_to_json(ocr_results[0], json_path)
 
-     # Identify field-value pairs
-     fields = ["Scheme Name", "Folio Number", "Number of Units", "PAN", "Signature", "Tax Status",
-               "Mobile Number", "Email", "Address", "Bank Account Details"]
-     field_value_pairs = identify_field_value_pairs(ocr_results[0], fields)
-     field_value_json_path = "field_value_pairs.json"
 
-     with open(field_value_json_path, 'w') as json_file:
-         json.dump(field_value_pairs, json_file, indent=4)
-
-     return processed_image, json_path, field_value_json_path
-
- # Gradio Interface
- interface = gr.Interface(
-     fn=process_image,
-     inputs="image",
-     outputs=[
-         "image",
-         gr.File(label="Download OCR Results JSON"),
-         gr.File(label="Download Field-Value Pairs JSON")
-     ],
-     title="OCR Web Application",
-     description="Upload an image and get OCR results with bounding boxes and two JSON outputs."
  )
 
  if __name__ == "__main__":
-     interface.launch()
 
  from PIL import Image, ImageDraw, ImageFont
  import json
  from paddleocr import PaddleOCR
+ from transformers import pipeline
  import gradio as gr
 
  # Initialize PaddleOCR
  ocr = PaddleOCR(use_angle_cls=True, lang='en')
 
+ # Predefined fields for extraction
+ FIELDS = ["Scheme Name", "Folio Number", "Number of Units", "PAN", "Signature", "Tax Status",
+           "Mobile Number", "Email", "Address", "Bank Account Details"]
 
+ def draw_boxes_on_image(image, data):
+     """
+     Draw bounding boxes and text on the image.
+
+     Args:
+         image (PIL Image): The input image.
+         data (list): OCR results containing bounding boxes and detected text.
+
+     Returns:
+         PIL Image: The image with drawn boxes.
+     """
+     draw = ImageDraw.Draw(image)
      try:
+         font = ImageFont.truetype("arial.ttf", 20)
      except IOError:
          font = ImageFont.load_default()
 
+     for item_id, item in enumerate(data, start=1):
          bounding_box, (text, confidence) = item
          box = np.array(bounding_box).astype(int)
          draw.line([tuple(box[0]), tuple(box[1])], fill="green", width=2)
          draw.line([tuple(box[1]), tuple(box[2])], fill="green", width=2)
          draw.line([tuple(box[2]), tuple(box[3])], fill="green", width=2)
          draw.line([tuple(box[3]), tuple(box[0])], fill="green", width=2)
          text_position = (box[0][0], box[0][1] - 20)
+         draw.text(text_position, f"{item_id}: {text} ({confidence:.2f})", fill="red", font=font)
 
+     return image
 
+ def convert_to_json(results):
      """
+     Converts the OCR results into a JSON object with bounding box IDs.
+
      Args:
+         results (list): The list of OCR results containing bounding box coordinates, text, and confidence.
+
+     Returns:
+         dict: JSON data with bounding boxes and text.
      """
      json_data = []
+     for item_id, result in enumerate(results, start=1):
          bounding_box = result[0]
          text = result[1][0]
          confidence = result[1][1]
          json_data.append({
+             "id": item_id,
+             "bounding_box": bounding_box,
              "text": text,
              "confidence": confidence
          })
+     return json_data
 
+ def extract_field_value_pairs(text):
+     """
+     Extract field-value pairs from the text using a pre-trained NLP model.
+
+     Args:
+         text (str): The text to be processed.
+
+     Returns:
+         dict: A dictionary with field-value pairs.
+     """
+     nlp = pipeline("ner", model="mrm8488/bert-tiny-finetuned-sms-spam-detection")
+     ner_results = []
+     chunk_size = 256
+     for i in range(0, len(text), chunk_size):
+         chunk = text[i:i+chunk_size]
+         ner_results.extend(nlp(chunk))
+
      field_value_pairs = {}
+     current_field = None
+     for entity in ner_results:
+         word = entity['word']
+         for field in FIELDS:
+             if field.lower() in word.lower():
+                 current_field = field
+                 break
+         if current_field and entity['entity'] == "LABEL_1":
+             field_value_pairs[current_field] = word
+
      return field_value_pairs
 
  def process_image(image):
+     """
+     Process the uploaded image and perform OCR.
+
+     Args:
+         image (PIL Image): The input image.
+
+     Returns:
+         tuple: The image with bounding boxes, OCR results in JSON format, and field-value pairs.
+     """
+     # Perform OCR on the image
      ocr_results = ocr.ocr(np.array(image), cls=True)
 
+     # Draw boxes on the image
+     image_with_boxes = draw_boxes_on_image(image.copy(), ocr_results[0])
 
+     # Convert OCR results to JSON
+     json_results = convert_to_json(ocr_results[0])
 
+     # Extract field-value pairs from the text
+     text = " ".join([result[1][0] for result in ocr_results[0]])
+     field_value_pairs = extract_field_value_pairs(text)
+
+     return image_with_boxes, json_results, field_value_pairs
+
+ # Define Gradio interface
+ iface = gr.Interface(
+     fn=process_image,
+     inputs=gr.Image(type="pil"),
+     outputs=[gr.Image(type="pil"), gr.JSON(), gr.JSON()],
+     live=True
  )
 
  if __name__ == "__main__":
+     iface.launch()
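
For readers tracing the new process_image flow, below is a small standalone sketch of the data shapes it passes around, using a hand-made stand-in for ocr.ocr(np.array(image), cls=True)[0] (each item is [bounding_box, (text, confidence)], as the code above unpacks it). The box coordinates, text, and confidence values are invented for illustration, and nothing here runs PaddleOCR or the transformers pipeline.

import json

# Hand-made stand-in for one OCR page result (ocr.ocr(...)[0]):
# a list of [bounding_box, (text, confidence)] items, where the box is
# four [x, y] corner points. All values below are invented examples.
fake_ocr_page = [
    [[[12, 8], [210, 8], [210, 36], [12, 36]], ("Folio Number", 0.97)],
    [[[220, 8], [340, 8], [340, 36], [220, 36]], ("12345678", 0.95)],
]

# Record shape emitted by the new convert_to_json(): one entry per
# detection, numbered from 1, keeping the raw box, text, and confidence.
json_results = [
    {"id": i, "bounding_box": box, "text": text, "confidence": conf}
    for i, (box, (text, conf)) in enumerate(fake_ocr_page, start=1)
]
print(json.dumps(json_results, indent=2))

# String handed to extract_field_value_pairs(): all detected text joined
# with spaces, in the order PaddleOCR returned it.
joined_text = " ".join(text for _, (text, _) in fake_ocr_page)
print(joined_text)  # -> "Folio Number 12345678"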