kavg commited on
Commit
ccb5ac8
1 Parent(s): 16f0a9b

fixed entity merging issue

Browse files
Files changed (3) hide show
  1. handwritting_detection.py +1 -1
  2. main.py +9 -4
  3. token_classification.py +1 -2
handwritting_detection.py CHANGED
@@ -36,6 +36,6 @@ def DetectHandwritting(image):
36
  cpy = image.copy()
37
  handwritten_parts = []
38
  for prediction in result['predictions']:
39
- cpy = draw_rectangle(cpy, **prediction)
40
  handwritten_parts.append(crop_image(cpy, **prediction))
 
41
  return cpy, handwritten_parts
 
36
  cpy = image.copy()
37
  handwritten_parts = []
38
  for prediction in result['predictions']:
 
39
  handwritten_parts.append(crop_image(cpy, **prediction))
40
+ cpy = draw_rectangle(cpy, **prediction)
41
  return cpy, handwritten_parts
main.py CHANGED
@@ -49,7 +49,8 @@ async def ProcessDocument(file: UploadFile):
49
 
50
  try:
51
  reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
52
- except:
 
53
  raise HTTPException(status_code=400, detail="Relation extraction failed")
54
  return reOutput
55
 
@@ -66,9 +67,14 @@ async def ProcessDocument(file: str = Form(...)):
66
  raise HTTPException(status_code=400, detail="Cannot apply OCR to the image")
67
  try:
68
  tokenClassificationOutput, img_size = LabelTokens(ocr_df, image)
 
 
 
 
69
  reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
70
- except:
71
- raise HTTPException(status_code=400, detail="Invalid Image")
 
72
  return reOutput
73
 
74
  def ApplyOCR(content):
@@ -112,7 +118,6 @@ def LabelTokens(ocr_df, image):
112
  return {"token_labels": token_labels, "input_ids": input_ids, "bbox":bbox, "attention_mask":attention_mask}, image.size
113
 
114
  def ExtractRelations(tokenClassificationOutput, ocr_df, img_size):
115
- print(tokenClassificationOutput)
116
  token_labels = tokenClassificationOutput['token_labels']
117
  input_ids = tokenClassificationOutput['input_ids']
118
  attention_mask = tokenClassificationOutput["attention_mask"]
 
49
 
50
  try:
51
  reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
52
+ except Exception as e:
53
+ print(e)
54
  raise HTTPException(status_code=400, detail="Relation extraction failed")
55
  return reOutput
56
 
 
67
  raise HTTPException(status_code=400, detail="Cannot apply OCR to the image")
68
  try:
69
  tokenClassificationOutput, img_size = LabelTokens(ocr_df, image)
70
+ except Exception as e:
71
+ print(e)
72
+ raise HTTPException(status_code=400, detail="Entity identification failed")
73
+ try:
74
  reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
75
+ except Exception as e:
76
+ print(e)
77
+ raise HTTPException(status_code=400, detail="Relation extraction failed")
78
  return reOutput
79
 
80
  def ApplyOCR(content):
 
118
  return {"token_labels": token_labels, "input_ids": input_ids, "bbox":bbox, "attention_mask":attention_mask}, image.size
119
 
120
  def ExtractRelations(tokenClassificationOutput, ocr_df, img_size):
 
121
  token_labels = tokenClassificationOutput['token_labels']
122
  input_ids = tokenClassificationOutput['input_ids']
123
  attention_mask = tokenClassificationOutput["attention_mask"]
token_classification.py CHANGED
@@ -195,10 +195,9 @@ def createEntities(model, predictions, input_ids, ocr_df, tokenizer, img_size, b
195
  for i in range(len(df_words)):
196
  if i not in visited:
197
  merged_taggings.append(dfs(i,[], width, height, visited, df_words))
198
-
199
  merged_words = []
200
  for i,merged_tagging in enumerate(merged_taggings):
201
- if ((len(merged_tagging) > 1) or (merged_tagging['label']=='ANSWER')):
202
  new_word = {}
203
  merging_word = " ".join([word['text'] for word in merged_tagging])
204
  merging_box = [merged_tagging[0]['box'][0]-5,merged_tagging[0]['box'][1]-10,merged_tagging[-1]['box'][2]+5,merged_tagging[-1]['box'][3]+10]
 
195
  for i in range(len(df_words)):
196
  if i not in visited:
197
  merged_taggings.append(dfs(i,[], width, height, visited, df_words))
 
198
  merged_words = []
199
  for i,merged_tagging in enumerate(merged_taggings):
200
+ if ((len(merged_tagging) > 1)) or (merged_tagging[0]['label'] == 'ANSWER'):
201
  new_word = {}
202
  merging_word = " ".join([word['text'] for word in merged_tagging])
203
  merging_box = [merged_tagging[0]['box'][0]-5,merged_tagging[0]['box'][1]-10,merged_tagging[-1]['box'][2]+5,merged_tagging[-1]['box'][3]+10]