fixed entity merging issue
- handwritting_detection.py +1 -1
- main.py +9 -4
- token_classification.py +1 -2
handwritting_detection.py
@@ -36,6 +36,6 @@ def DetectHandwritting(image):
     cpy = image.copy()
     handwritten_parts = []
     for prediction in result['predictions']:
-        cpy = draw_rectangle(cpy, **prediction)
         handwritten_parts.append(crop_image(cpy, **prediction))
+        cpy = draw_rectangle(cpy, **prediction)
     return cpy, handwritten_parts
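With this reorder, each region is cropped from the working copy before its own bounding box is drawn onto it, so the saved crop no longer carries the visualization rectangle. A minimal sketch of that ordering, assuming Pillow and a (left, top, right, bottom) box format rather than the repo's actual draw_rectangle/crop_image helpers and prediction layout:

from PIL import Image, ImageDraw

def annotate_and_crop(image, boxes):
    # boxes: assumed (left, top, right, bottom) tuples, not the repo's real format
    cpy = image.copy()
    draw = ImageDraw.Draw(cpy)
    crops = []
    for box in boxes:
        crops.append(cpy.crop(box))                  # crop before the overlay exists
        draw.rectangle(box, outline="red", width=2)  # then draw the visualization box
    return cpy, crops

if __name__ == "__main__":
    img = Image.new("RGB", (200, 100), "white")
    annotated, parts = annotate_and_crop(img, [(10, 10, 60, 40)])
    print(annotated.size, [c.size for c in parts])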
main.py
@@ -49,7 +49,8 @@ async def ProcessDocument(file: UploadFile):

     try:
         reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
-    except:
+    except Exception as e:
+        print(e)
         raise HTTPException(status_code=400, detail="Relation extraction failed")
     return reOutput

@@ -66,9 +67,14 @@ async def ProcessDocument(file: str = Form(...)):
         raise HTTPException(status_code=400, detail="Cannot apply OCR to the image")
     try:
         tokenClassificationOutput, img_size = LabelTokens(ocr_df, image)
+    except Exception as e:
+        print(e)
+        raise HTTPException(status_code=400, detail="Entity identification failed")
+    try:
         reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
-    except:
-
+    except Exception as e:
+        print(e)
+        raise HTTPException(status_code=400, detail="Relation extraction failed")
     return reOutput

 def ApplyOCR(content):
@@ -112,7 +118,6 @@ def LabelTokens(ocr_df, image):
     return {"token_labels": token_labels, "input_ids": input_ids, "bbox":bbox, "attention_mask":attention_mask}, image.size

 def ExtractRelations(tokenClassificationOutput, ocr_df, img_size):
-    print(tokenClassificationOutput)
     token_labels = tokenClassificationOutput['token_labels']
     input_ids = tokenClassificationOutput['input_ids']
     attention_mask = tokenClassificationOutput["attention_mask"]
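main.py now prints the caught exception and splits the single try block in the second endpoint, so a failure inside LabelTokens returns "Entity identification failed" while a failure inside ExtractRelations returns "Relation extraction failed". A stripped-down sketch of that pattern, with the two pipeline steps passed in as hypothetical stand-ins rather than imported from the project:

from fastapi import HTTPException

def run_pipeline(ocr_df, image, LabelTokens, ExtractRelations):
    # Stage 1: token labeling; its failures get their own error message.
    try:
        tokenClassificationOutput, img_size = LabelTokens(ocr_df, image)
    except Exception as e:
        print(e)  # keep the underlying reason visible in the server output
        raise HTTPException(status_code=400, detail="Entity identification failed")
    # Stage 2: relation extraction, reported separately.
    try:
        return ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
    except Exception as e:
        print(e)
        raise HTTPException(status_code=400, detail="Relation extraction failed")

The bare except: in the old code dropped the underlying error entirely, which made it hard to tell which stage of the pipeline caused the 400; printing e before re-raising keeps that information in the logs.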
token_classification.py
@@ -195,10 +195,9 @@ def createEntities(model, predictions, input_ids, ocr_df, tokenizer, img_size, b
     for i in range(len(df_words)):
         if i not in visited:
             merged_taggings.append(dfs(i,[], width, height, visited, df_words))
-
     merged_words = []
     for i,merged_tagging in enumerate(merged_taggings):
-        if ((len(merged_tagging) > 1) or (merged_tagging['label']=='ANSWER')
+        if ((len(merged_tagging) > 1)) or (merged_tagging[0]['label'] == 'ANSWER'):
             new_word = {}
             merging_word = " ".join([word['text'] for word in merged_tagging])
             merging_box = [merged_tagging[0]['box'][0]-5,merged_tagging[0]['box'][1]-10,merged_tagging[-1]['box'][2]+5,merged_tagging[-1]['box'][3]+10]
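The corrected condition closes the unbalanced parenthesis, adds the missing colon, and indexes the first word of the group (merged_tagging[0]['label']) instead of indexing the list itself, so a group is merged when it spans more than one word or when its single word is labeled ANSWER. A small self-contained sketch of that filter on made-up word dicts; carrying the first word's label onto the merged entity is an assumption, since the diff only shows the text and box being built:

def merge_groups(merged_taggings):
    # merged_taggings: list of groups, each group a list of word dicts
    # with 'text', 'label', and 'box' = [x0, y0, x1, y1], as in the diff.
    merged_words = []
    for merged_tagging in merged_taggings:
        if (len(merged_tagging) > 1) or (merged_tagging[0]['label'] == 'ANSWER'):
            merging_word = " ".join(word['text'] for word in merged_tagging)
            merging_box = [merged_tagging[0]['box'][0] - 5,
                           merged_tagging[0]['box'][1] - 10,
                           merged_tagging[-1]['box'][2] + 5,
                           merged_tagging[-1]['box'][3] + 10]
            merged_words.append({"text": merging_word,
                                 "label": merged_tagging[0]['label'],  # assumed
                                 "box": merging_box})
    return merged_words

if __name__ == "__main__":
    groups = [
        [{"text": "Invoice", "label": "QUESTION", "box": [10, 10, 60, 22]},
         {"text": "No.", "label": "QUESTION", "box": [65, 10, 90, 22]}],   # kept: multi-word
        [{"text": "12345", "label": "ANSWER", "box": [95, 10, 140, 22]}],  # kept: single ANSWER
        [{"text": "Page", "label": "OTHER", "box": [10, 40, 45, 52]}],     # dropped
    ]
    print(merge_groups(groups))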