kavg committed on
Commit 16f0a9b
Parent: 75c54a9

commit before changing entity merging process

Files changed (4)
  1. main.py +23 -13
  2. ocr.py +4 -4
  3. preprocess.py +9 -0
  4. token_classification.py +1 -1
main.py CHANGED
@@ -44,9 +44,13 @@ async def ProcessDocument(file: UploadFile):
         raise HTTPException(status_code=400, detail="Cannot apply OCR to the image")
     try:
         tokenClassificationOutput, img_size = LabelTokens(ocr_df, image)
-        reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
     except:
-        raise HTTPException(status_code=400, detail="Invalid Image")
+        raise HTTPException(status_code=400, detail="Entity identification failed")
+
+    try:
+        reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
+    except:
+        raise HTTPException(status_code=400, detail="Relation extraction failed")
     return reOutput
 
 @app.post("/submit-doc-base64")
@@ -78,22 +82,27 @@ def ApplyOCR(content):
     except:
         raise HTTPException(status_code=400, detail="Handwritting detection failed")
 
-    try:
-        trocr_client = ocr.TrOCRClient(config['settings'].TROCR_API_URL)
-        handwritten_ocr_df = trocr_client.ocr(handwritten_imgs, image)
-    except:
-        raise HTTPException(status_code=400, detail="handwritten OCR process failed")
-
     try:
         jpeg_bytes = io.BytesIO()
-        printed_img.save(jpeg_bytes, format='JPEG')
-        jpeg_content = jpeg_bytes.getvalue()
+        printed_img.save(jpeg_bytes, format='PNG')
+        # printed_img.save('temp/printed_text_image.jpeg', format='PNG')
+        printed_content = jpeg_bytes.getvalue()
         vision_client = ocr.VisionClient(config['settings'].GCV_AUTH)
-        printed_ocr_df = vision_client.ocr(jpeg_content, printed_img)
-    except:
+        printed_ocr_df = vision_client.ocr(printed_content, printed_img)
+        # printed_ocr_df.to_csv('temp/complete_image_ocr.csv', index=False)
+        # return printed_ocr_df, image
+    except Exception as e:
         raise HTTPException(status_code=400, detail="Printed OCR process failed")
 
+    try:
+        trocr_client = ocr.TrOCRClient(config['settings'].TROCR_API_URL)
+        handwritten_ocr_df = trocr_client.ocr(handwritten_imgs, image)
+    except Exception as e:
+        print(e)
+        raise HTTPException(status_code=400, detail="handwritten OCR process failed")
+
     ocr_df = pd.concat([handwritten_ocr_df, printed_ocr_df])
+    # ocr_df = printed_ocr_df
     return ocr_df, image
 
 
@@ -103,13 +112,14 @@ def LabelTokens(ocr_df, image):
     return {"token_labels": token_labels, "input_ids": input_ids, "bbox":bbox, "attention_mask":attention_mask}, image.size
 
 def ExtractRelations(tokenClassificationOutput, ocr_df, img_size):
+    print(tokenClassificationOutput)
     token_labels = tokenClassificationOutput['token_labels']
     input_ids = tokenClassificationOutput['input_ids']
     attention_mask = tokenClassificationOutput["attention_mask"]
     bbox_org = tokenClassificationOutput["bbox"]
 
     merged_output, merged_words = token_classification.createEntities(config['ser_model'], token_labels, input_ids, ocr_df, config['tokenizer'], img_size, bbox_org)
-
+
     entities = merged_output['entities']
     input_ids = torch.tensor([merged_output['input_ids']]).to(config['device'])
     bbox = torch.tensor([merged_output['bbox']]).to(config['device'])
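A note on the serialization change in ApplyOCR: the buffer is still named jpeg_bytes, but the printed-text image is now encoded as PNG; Google Cloud Vision accepts both formats, so only the variable name is stale. A minimal sketch of the in-memory round-trip, assuming a PIL image (the file name is hypothetical):

    import io
    from PIL import Image

    printed_img = Image.open("printed_page.png")  # hypothetical input
    buf = io.BytesIO()
    printed_img.save(buf, format="PNG")           # encode in memory, no temp file
    printed_content = buf.getvalue()              # raw bytes handed to the OCR client
    assert Image.open(io.BytesIO(printed_content)).format == "PNG"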
ocr.py CHANGED
@@ -1,12 +1,11 @@
 from google.cloud import vision
 from google.oauth2 import service_account
-from google.protobuf.json_format import MessageToJson
 import pandas as pd
 import json
 import numpy as np
-from PIL import Image
 import io
 import requests
+from preprocess import cam_scanner_filter
 
 image_ext = ("*.jpg", "*.jpeg", "*.png")
 
@@ -23,7 +22,7 @@ class VisionClient:
         except ValueError as e:
             print("Image could not be read")
             return
-        response = self.client.document_text_detection(image, timeout=10)
+        response = self.client.document_text_detection(image, timeout=60)
         return response
 
     def get_response(self, content):
@@ -134,7 +133,8 @@ class TrOCRClient():
         boxObjects = []
         for i in range(len(handwritten_imgs)):
             handwritten_img = handwritten_imgs[i]
-            ocr_result = self.send_request(handwritten_img[0])
+            handwritten_img_processed = cam_scanner_filter(handwritten_img[0])
+            ocr_result = self.send_request(handwritten_img_processed)
             boxObjects.append({
                 "id": i-1,
                 "text": ocr_result,
preprocess.py CHANGED
@@ -1,5 +1,8 @@
 import torch
 from transformers import AutoTokenizer
+import cv2
+from PIL import Image
+import numpy as np
 
 def normalize_box(box, width, height):
     return [
@@ -9,6 +12,12 @@ def normalize_box(box, width, height):
         int(1000 * (box[3] / height)),
     ]
 
+def cam_scanner_filter(img):
+    image1 = np.array(img)
+    img = cv2.cvtColor(image1, cv2.COLOR_BGR2GRAY)
+    thresh2 = cv2.adaptiveThreshold(img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 199, 15)
+    return Image.fromarray(thresh2)
+
 # class to turn the keys of a dict into attributes (thanks Stackoverflow)
 class AttrDict(dict):
     def __init__(self, *args, **kwargs):
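The new cam_scanner_filter binarizes a crop with Gaussian adaptive thresholding (block size 199, offset 15) to sharpen handwriting before TrOCR. One caveat: np.array on a PIL image yields RGB data, so cv2.COLOR_RGB2GRAY would be the strictly matching flag; COLOR_BGR2GRAY still runs but weights the red and blue channels as if swapped. A usage sketch with a hypothetical file name:

    from PIL import Image
    from preprocess import cam_scanner_filter

    crop = Image.open("handwritten_crop.png").convert("RGB")  # hypothetical crop
    binarized = cam_scanner_filter(crop)  # PIL image of 0/255 pixels
    binarized.save("handwritten_crop_bw.png")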
token_classification.py CHANGED
@@ -198,7 +198,7 @@ def createEntities(model, predictions, input_ids, ocr_df, tokenizer, img_size, b
 
     merged_words = []
     for i,merged_tagging in enumerate(merged_taggings):
-        if len(merged_tagging) > 1:
+        if ((len(merged_tagging) > 1) or (merged_tagging['label']=='ANSWER')):
             new_word = {}
             merging_word = " ".join([word['text'] for word in merged_tagging])
             merging_box = [merged_tagging[0]['box'][0]-5,merged_tagging[0]['box'][1]-10,merged_tagging[-1]['box'][2]+5,merged_tagging[-1]['box'][3]+10]
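A caveat on the widened condition: merged_tagging is indexed as a list of word dicts elsewhere in this hunk (merged_tagging[0]['box'], merged_tagging[-1]['box']), so merged_tagging['label'] raises TypeError exactly when the or-branch is reached (single-word groups). If the intent is to also keep one-word ANSWER entities, the label presumably lives on the first word; a sketch against that assumed word shape:

    # assumed word shape: {"text": str, "box": [x0, y0, x1, y1], "label": str}
    for i, merged_tagging in enumerate(merged_taggings):
        # keep multi-word groups, plus single-word groups labelled ANSWER
        if len(merged_tagging) > 1 or merged_tagging[0]['label'] == 'ANSWER':
            ...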