commit before changing entity merging process

Files changed:
- main.py  +23 -13
- ocr.py  +4 -4
- preprocess.py  +9 -0
- token_classification.py  +1 -1
main.py
CHANGED

@@ -44,9 +44,13 @@ async def ProcessDocument(file: UploadFile):
         raise HTTPException(status_code=400, detail="Cannot apply OCR to the image")
     try:
         tokenClassificationOutput, img_size = LabelTokens(ocr_df, image)
-        reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
     except:
-        raise HTTPException(status_code=400, detail="
+        raise HTTPException(status_code=400, detail="Entity identification failed")
+
+    try:
+        reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
+    except:
+        raise HTTPException(status_code=400, detail="Relation extraction failed")
     return reOutput
 
 @app.post("/submit-doc-base64")
@@ -78,22 +82,27 @@ def ApplyOCR(content):
     except:
         raise HTTPException(status_code=400, detail="Handwritting detection failed")
 
-    try:
-        trocr_client = ocr.TrOCRClient(config['settings'].TROCR_API_URL)
-        handwritten_ocr_df = trocr_client.ocr(handwritten_imgs, image)
-    except:
-        raise HTTPException(status_code=400, detail="handwritten OCR process failed")
-
     try:
         jpeg_bytes = io.BytesIO()
-        printed_img.save(jpeg_bytes, format='
-        printed_content = jpeg_bytes.getvalue()
+        printed_img.save(jpeg_bytes, format='PNG')
+        # printed_img.save('temp/printed_text_image.jpeg', format='PNG')
+        printed_content = jpeg_bytes.getvalue()
         vision_client = ocr.VisionClient(config['settings'].GCV_AUTH)
-        printed_ocr_df = vision_client.ocr(
-    except:
+        printed_ocr_df = vision_client.ocr(printed_content, printed_img)
+        # printed_ocr_df.to_csv('temp/complete_image_ocr.csv', index=False)
+        # return printed_ocr_df, image
+    except Exception as e:
         raise HTTPException(status_code=400, detail="Printed OCR process failed")
 
+    try:
+        trocr_client = ocr.TrOCRClient(config['settings'].TROCR_API_URL)
+        handwritten_ocr_df = trocr_client.ocr(handwritten_imgs, image)
+    except Exception as e:
+        print(e)
+        raise HTTPException(status_code=400, detail="handwritten OCR process failed")
+
     ocr_df = pd.concat([handwritten_ocr_df, printed_ocr_df])
+    # ocr_df = printed_ocr_df
     return ocr_df, image
 
 
@@ -103,13 +112,14 @@ def LabelTokens(ocr_df, image):
     return {"token_labels": token_labels, "input_ids": input_ids, "bbox":bbox, "attention_mask":attention_mask}, image.size
 
 def ExtractRelations(tokenClassificationOutput, ocr_df, img_size):
+    print(tokenClassificationOutput)
     token_labels = tokenClassificationOutput['token_labels']
     input_ids = tokenClassificationOutput['input_ids']
     attention_mask = tokenClassificationOutput["attention_mask"]
     bbox_org = tokenClassificationOutput["bbox"]
 
     merged_output, merged_words = token_classification.createEntities(config['ser_model'], token_labels, input_ids, ocr_df, config['tokenizer'], img_size, bbox_org)
-
+
     entities = merged_output['entities']
     input_ids = torch.tensor([merged_output['input_ids']]).to(config['device'])
     bbox = torch.tensor([merged_output['bbox']]).to(config['device'])
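The main.py change splits what was a single try/except around both LabelTokens and ExtractRelations into one try block per stage, so a failing request reports which stage broke instead of returning one generic message. A minimal sketch of that pattern, assuming FastAPI; the stage functions below are hypothetical stand-ins, not the project's real ones:

# Sketch of the per-stage error-handling pattern (stand-in stage functions).
from fastapi import FastAPI, HTTPException

app = FastAPI()

def label_tokens(text):
    # stand-in for LabelTokens
    return {"tokens": text.split()}

def extract_relations(token_output):
    # stand-in for ExtractRelations
    return {"relations": [], "tokens": token_output["tokens"]}

@app.post("/process")
async def process(text: str):
    # One try/except per stage: the HTTP 400 detail names the stage
    # that failed instead of blaming the whole pipeline.
    try:
        token_output = label_tokens(text)
    except Exception:
        raise HTTPException(status_code=400, detail="Entity identification failed")
    try:
        re_output = extract_relations(token_output)
    except Exception:
        raise HTTPException(status_code=400, detail="Relation extraction failed")
    return re_output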
ocr.py
CHANGED

@@ -1,12 +1,11 @@
 from google.cloud import vision
 from google.oauth2 import service_account
-from google.protobuf.json_format import MessageToJson
 import pandas as pd
 import json
 import numpy as np
-from PIL import Image
 import io
 import requests
+from preprocess import cam_scanner_filter
 
 image_ext = ("*.jpg", "*.jpeg", "*.png")
 
@@ -23,7 +22,7 @@ class VisionClient:
         except ValueError as e:
             print("Image could not be read")
             return
-        response = self.client.document_text_detection(image, timeout=
+        response = self.client.document_text_detection(image, timeout=60)
         return response
 
     def get_response(self, content):
@@ -134,7 +133,8 @@ class TrOCRClient():
         boxObjects = []
         for i in range(len(handwritten_imgs)):
             handwritten_img = handwritten_imgs[i]
-            ocr_result = self.send_request(handwritten_img[0])
+            handwritten_img_processed = cam_scanner_filter(handwritten_img[0])
+            ocr_result = self.send_request(handwritten_img_processed)
             boxObjects.append({
                 "id": i-1,
                 "text": ocr_result,
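Two behavioural changes here: document_text_detection now gets an explicit 60-second timeout, and each handwritten crop is run through cam_scanner_filter before being sent to the TrOCR service. A minimal sketch of the bounded Vision call, assuming a recent google-cloud-vision client (>= 2.0) and a service-account file; the wrapper name detect_printed_text is illustrative:

# Sketch of a Vision OCR call with an explicit timeout.
from google.cloud import vision
from google.oauth2 import service_account

def detect_printed_text(image_bytes, creds_path):
    creds = service_account.Credentials.from_service_account_file(creds_path)
    client = vision.ImageAnnotatorClient(credentials=creds)
    image = vision.Image(content=image_bytes)
    # timeout=60 bounds the RPC, so a stalled request fails fast
    # instead of hanging the whole document-processing endpoint.
    response = client.document_text_detection(image=image, timeout=60)
    return response.full_text_annotation.text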
preprocess.py
CHANGED

@@ -1,5 +1,8 @@
 import torch
 from transformers import AutoTokenizer
+import cv2
+from PIL import Image
+import numpy as np
 
 def normalize_box(box, width, height):
     return [
@@ -9,6 +12,12 @@ def normalize_box(box, width, height):
         int(1000 * (box[3] / height)),
     ]
 
+def cam_scanner_filter(img):
+    image1 = np.array(img)
+    img = cv2.cvtColor(image1, cv2.COLOR_BGR2GRAY)
+    thresh2 = cv2.adaptiveThreshold(img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 199, 15)
+    return Image.fromarray(thresh2)
+
 # class to turn the keys of a dict into attributes (thanks Stackoverflow)
 class AttrDict(dict):
     def __init__(self, *args, **kwargs):
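cam_scanner_filter is a CamScanner-style binarization: the crop is converted to grayscale, then adaptively thresholded with a Gaussian-weighted 199x199 neighbourhood (the block size must be odd) minus a constant offset of 15, so faint background texture drops to white while pen strokes stay black. A usage sketch; the file names are illustrative:

# Usage sketch for cam_scanner_filter (file names are illustrative).
from PIL import Image
from preprocess import cam_scanner_filter

crop = Image.open("handwritten_crop.png").convert("RGB")  # hypothetical crop
binarized = cam_scanner_filter(crop)       # grayscale + adaptive threshold
binarized.save("handwritten_crop_bw.png")  # high-contrast input for TrOCR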
token_classification.py
CHANGED

@@ -198,7 +198,7 @@ def createEntities(model, predictions, input_ids, ocr_df, tokenizer, img_size, b
 
     merged_words = []
     for i,merged_tagging in enumerate(merged_taggings):
-        if len(merged_tagging) > 1:
+        if ((len(merged_tagging) > 1) or (merged_tagging['label']=='ANSWER')):
             new_word = {}
             merging_word = " ".join([word['text'] for word in merged_tagging])
             merging_box = [merged_tagging[0]['box'][0]-5,merged_tagging[0]['box'][1]-10,merged_tagging[-1]['box'][2]+5,merged_tagging[-1]['box'][3]+10]
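The widened condition keeps single-word groups labelled ANSWER, which the old len(merged_tagging) > 1 test silently dropped. A toy illustration of the rule, under the assumption that each merged tagging group is a list of word dicts; the group label is read from the first word here to keep the example runnable:

# Toy illustration of the widened merge rule (data shapes assumed).
merged_taggings = [
    [{"text": "Name:", "label": "QUESTION", "box": [10, 10, 60, 22]}],
    [{"text": "John", "label": "ANSWER", "box": [70, 10, 105, 22]}],   # single-word answer
    [{"text": "New", "label": "ANSWER", "box": [70, 30, 98, 42]},
     {"text": "York", "label": "ANSWER", "box": [102, 30, 140, 42]}],  # multi-word group
]

for group in merged_taggings:
    # Old rule: only groups longer than one word were merged.
    # New rule: single-word ANSWER groups are merged as well.
    if len(group) > 1 or group[0]["label"] == "ANSWER":
        print(" ".join(word["text"] for word in group))  # -> "John", "New York"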