Added new entity merging method. Included additional outputs in the response to match the frontend Streamlit app.
- main.py +37 -16
- models.py +6 -1
- preprocess.py +8 -8
- token_classification.py +212 -28
main.py
CHANGED
@@ -2,7 +2,7 @@ from config import Settings
 from preprocess import Preprocessor
 import ocr
 from PIL import Image
-from transformers import LiltForTokenClassification
+from transformers import LiltForTokenClassification, AutoTokenizer
 import token_classification
 import torch
 from fastapi import FastAPI, UploadFile
@@ -19,6 +19,7 @@ async def lifespan(app: FastAPI):
     config['device'] = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     config['vision_client'] = ocr.VisionClient(settings.GCV_AUTH)
     config['processor'] = Preprocessor(settings.TOKENIZER)
+    config['tokenizer'] = AutoTokenizer.from_pretrained(settings.TOKENIZER)
     config['ser_model'] = LiltForTokenClassification.from_pretrained(settings.SER_MODEL)
     config['re_model'] = LiLTRobertaLikeForRelationExtraction.from_pretrained(settings.RE_MODEL)
     yield
@@ -29,8 +30,8 @@ app = FastAPI(lifespan=lifespan)

 @app.post("/submit-doc")
 async def ProcessDocument(file: UploadFile):
-    tokenClassificationOutput = await LabelTokens(file)
-    reOutput = ExtractRelations(tokenClassificationOutput)
+    tokenClassificationOutput, ocr_df, img_size = await LabelTokens(file)
+    reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
     return reOutput

 async def LabelTokens(file):
@@ -39,28 +40,48 @@ async def LabelTokens(file):
     ocr_df = config['vision_client'].ocr(content, image)
     input_ids, attention_mask, token_type_ids, bbox, token_actual_boxes, offset_mapping = config['processor'].process(ocr_df, image = image)
     token_labels = token_classification.classifyTokens(config['ser_model'], input_ids, attention_mask, bbox, offset_mapping)
-    return {"token_labels": token_labels, "input_ids": input_ids, "bbox":bbox, "
+    return {"token_labels": token_labels, "input_ids": input_ids, "bbox":bbox, "attention_mask":attention_mask}, ocr_df, image.size

-def ExtractRelations(tokenClassificationOutput):
+def ExtractRelations(tokenClassificationOutput, ocr_df, img_size):
     token_labels = tokenClassificationOutput['token_labels']
     input_ids = tokenClassificationOutput['input_ids']
-    offset_mapping = tokenClassificationOutput["offset_mapping"]
     attention_mask = tokenClassificationOutput["attention_mask"]
-
+    bbox_org = tokenClassificationOutput["bbox"]

-
+    merged_output, merged_words = token_classification.createEntities(config['ser_model'], token_labels, input_ids, ocr_df, config['tokenizer'], img_size, bbox_org)

+    entities = merged_output['entities']
+    input_ids = torch.tensor([merged_output['input_ids']]).to(config['device'])
+    bbox = torch.tensor([merged_output['bbox']]).to(config['device'])
+    attention_mask = torch.tensor([merged_output['attention_mask']]).to(config['device'])
+
+    id2label = {"HEADER":0, "QUESTION":1, "ANSWER":2}
+    decoded_entities = []
+    for entity in entities:
+        decoded_entities.append((entity['label'], config['tokenizer'].decode(input_ids[0][entity['start']:entity['end']])))
+        entity['label'] = id2label[entity['label']]
+
     config['re_model'].to(config['device'])
-    entity_dict = {'start': [entity[
+    entity_dict = {'start': [entity['start'] for entity in entities], 'end': [entity['end'] for entity in entities], 'label': [entity['label'] for entity in entities]}
     relations = [{'start_index': [], 'end_index': [], 'head': [], 'tail': []}]
     with torch.no_grad():
         outputs = config['re_model'](input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, entities=[entity_dict], relations=relations)

-[… 6 removed lines, content not preserved in this diff view]
+    decoded_pred_relations = []
+    for relation in outputs.pred_relations[0]:
+        head_start, head_end = relation['head']
+        tail_start, tail_end = relation['tail']
+        question = config['tokenizer'].decode(input_ids[0][head_start:head_end])
+        answer = config['tokenizer'].decode(input_ids[0][tail_start:tail_end])
+        decoded_pred_relations.append((question, answer))
+        # print("Question:", question)
+        # print("Answer:", answer)
+        ## This prints bboxes of each question and answer
+        # for item in merged_words:
+        #     if item['text'] == question:
+        #         print('Question', item['box'])
+        #     if item['text'] == answer:
+        #         print('Answer', item['box'])
+        # print("----------")

-    return {"pred_relations":json.dumps(outputs.pred_relations[0]), "entities":json.dumps(entities), "input_ids": json.dumps(input_ids.tolist()), "bboxes": json.dumps(
+    return {"pred_relations":json.dumps(outputs.pred_relations[0]), "entities":json.dumps(entities), "input_ids": json.dumps(input_ids.tolist()), "bboxes": json.dumps(bbox_org.tolist()),"token_labels":json.dumps(token_labels), "decoded_entities": json.dumps(decoded_entities), "decoded_pred_relations":json.dumps(decoded_pred_relations)}
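The extra json.dumps-encoded fields in the /submit-doc response are meant to be decoded again on the client side. As a rough illustration (not part of the commit), a client such as the Streamlit frontend might consume them like this; the server URL and the sample file name are placeholders:

import json
import requests

# Hypothetical client-side sketch: the host/port and file name are placeholders.
with open("sample_form.png", "rb") as f:
    response = requests.post("http://localhost:8000/submit-doc", files={"file": f})
payload = response.json()

# Each field was serialized with json.dumps on the server, so decode it again here.
decoded_entities = json.loads(payload["decoded_entities"])          # list of (label, text) pairs
decoded_relations = json.loads(payload["decoded_pred_relations"])   # list of (question, answer) pairs

for question, answer in decoded_relations:
    print(f"{question} -> {answer}")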
models.py
CHANGED
@@ -196,6 +196,8 @@ class LiLTRobertaLikeForRelationExtraction(LiltPreTrainedModel):
         super().__init__(config)

         self.lilt = LiltModel(config, add_pooling_layer=False)
+        # self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        # self.extractor = REDecoder(config, config.hidden_size)
         self.rehead = REHead(config)
         self.init_weights()

@@ -216,6 +218,8 @@ class LiLTRobertaLikeForRelationExtraction(LiltPreTrainedModel):
         entities=None,
         relations=None,
     ):
+        # for param in self.lilt.parameters():
+        #     param.requires_grad = False

         outputs = self.lilt(
             input_ids,
@@ -230,7 +234,8 @@ class LiLTRobertaLikeForRelationExtraction(LiltPreTrainedModel):
             return_dict=return_dict,
         )

+        seq_length = input_ids.size(1)
         sequence_output = outputs[0]
-
+
         re_output = self.rehead(sequence_output, entities, relations)
         return re_output
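For context (not part of the commit), the forward pass above keeps the calling convention used in main.py: one entity dict and one relation dict per batch element. A minimal sketch of those inputs, with made-up spans and shapes:

import torch

# Hypothetical inputs: batch of 1, sequence length 512.
input_ids = torch.zeros((1, 512), dtype=torch.long)
bbox = torch.zeros((1, 512, 4), dtype=torch.long)
attention_mask = torch.ones((1, 512), dtype=torch.long)

# Entity spans as parallel lists; labels use the HEADER/QUESTION/ANSWER ids from main.py.
entities = [{"start": [5, 12], "end": [9, 20], "label": [1, 2]}]
# Relations start empty at inference time; the model fills outputs.pred_relations.
relations = [{"start_index": [], "end_index": [], "head": [], "tail": []}]

# re_model is assumed to be an already-loaded LiLTRobertaLikeForRelationExtraction:
# outputs = re_model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask,
#                    entities=entities, relations=relations)
# Each predicted relation in outputs.pred_relations[0] carries 'head' and 'tail' spans.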
preprocess.py
CHANGED
@@ -1,6 +1,14 @@
 import torch
 from transformers import AutoTokenizer

+def normalize_box(box, width, height):
+    return [
+        int(1000 * (box[0] / width)),
+        int(1000 * (box[1] / height)),
+        int(1000 * (box[2] / width)),
+        int(1000 * (box[3] / height)),
+    ]
+
 # class to turn the keys of a dict into attributes (thanks Stackoverflow)
 class AttrDict(dict):
     def __init__(self, *args, **kwargs):
@@ -23,14 +31,6 @@ class Preprocessor():
             actual_box = [x, y, x+w, y+h] # we turn it into (left, top, left+widght, top+height) to get the actual box
             actual_boxes.append(actual_box)

-        def normalize_box(box, width, height):
-            return [
-                int(1000 * (box[0] / width)),
-                int(1000 * (box[1] / height)),
-                int(1000 * (box[2] / width)),
-                int(1000 * (box[3] / height)),
-            ]
-
         boxes = []
         for box in actual_boxes:
             boxes.append(normalize_box(box, width, height))
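To make the relocation concrete (illustration only, not part of the commit): normalize_box rescales pixel coordinates into the 0-1000 grid that LiLT expects, so moving it to module level lets token_classification.py reuse it. The box and page size below are invented numbers:

from preprocess import normalize_box

# A word box (left, top, right, bottom) = (100, 50, 300, 80) on a 1000x800-pixel page.
print(normalize_box([100, 50, 300, 80], 1000, 800))
# -> [100, 62, 300, 100]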
token_classification.py
CHANGED
@@ -1,4 +1,6 @@
 import numpy as np
+from preprocess import normalize_box
+import copy

 def classifyTokens(model, input_ids, attention_mask, bbox, offset_mapping):
     outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
@@ -6,31 +8,213 @@ def classifyTokens(model, input_ids, attention_mask, bbox, offset_mapping):
     predictions = outputs.logits.argmax(-1).squeeze().tolist()
     return predictions

-def 
-[… 27 further removed lines, content not preserved in this diff view]
+def compare_boxes(b1,b2):
+    b1 = np.array([c for c in b1])
+    b2 = np.array([c for c in b2])
+    equal = np.array_equal(b1,b2)
+    return equal
+
+def mergable(w1,w2):
+    if w1['label'] == w2['label']:
+        threshold = 7
+        if abs(w1['box'][1] - w2['box'][1]) < threshold or abs(w1['box'][-1] - w2['box'][-1]) < threshold:
+            return True
+        return False
+    return False
+
+def convert_data(data, tokenizer, img_size):
+    def normalize_bbox(bbox, size):
+        return [
+            int(1000 * bbox[0] / size[0]),
+            int(1000 * bbox[1] / size[1]),
+            int(1000 * bbox[2] / size[0]),
+            int(1000 * bbox[3] / size[1]),
+        ]
+
+
+    def simplify_bbox(bbox):
+        return [
+            min(bbox[0::2]),
+            min(bbox[1::2]),
+            max(bbox[2::2]),
+            max(bbox[3::2]),
+        ]
+
+    def merge_bbox(bbox_list):
+        x0, y0, x1, y1 = list(zip(*bbox_list))
+        return [min(x0), min(y0), max(x1), max(y1)]
+
+    tokenized_doc = {"input_ids": [], "bbox": [], "labels": [], "attention_mask":[]}
+    entities = []
+    id2label = {}
+    entity_id_to_index_map = {}
+    empty_entity = set()
+    for line in data:
+        if len(line["text"]) == 0:
+            empty_entity.add(line["id"])
+            continue
+        id2label[line["id"]] = line["label"]
+        tokenized_inputs = tokenizer(
+            line["text"],
+            add_special_tokens=False,
+            return_offsets_mapping=True,
+            return_attention_mask=True,
+        )
+        text_length = 0
+        ocr_length = 0
+        bbox = []
+        for token_id, offset in zip(tokenized_inputs["input_ids"], tokenized_inputs["offset_mapping"]):
+            if token_id == 6:
+                bbox.append(None)
+                continue
+            text_length += offset[1] - offset[0]
+            tmp_box = []
+            while ocr_length < text_length:
+                ocr_word = line["words"].pop(0)
+                ocr_length += len(
+                    tokenizer._tokenizer.normalizer.normalize_str(ocr_word["text"].strip())
+                )
+                tmp_box.append(simplify_bbox(ocr_word["box"]))
+            if len(tmp_box) == 0:
+                tmp_box = last_box
+            bbox.append(normalize_bbox(merge_bbox(tmp_box), img_size))
+            last_box = tmp_box  # noqa
+        bbox = [
+            [bbox[i + 1][0], bbox[i + 1][1], bbox[i + 1][0], bbox[i + 1][1]] if b is None else b
+            for i, b in enumerate(bbox)
+        ]
+        if line["label"] == "other":
+            label = ["O"] * len(bbox)
+        else:
+            label = [f"I-{line['label'].upper()}"] * len(bbox)
+            label[0] = f"B-{line['label'].upper()}"
+        tokenized_inputs.update({"bbox": bbox, "labels": label})
+        if label[0] != "O":
+            entity_id_to_index_map[line["id"]] = len(entities)
+            entities.append(
+                {
+                    "start": len(tokenized_doc["input_ids"]),
+                    "end": len(tokenized_doc["input_ids"]) + len(tokenized_inputs["input_ids"]),
+                    "label": line["label"].upper(),
+                }
+            )
+        for i in tokenized_doc:
+            tokenized_doc[i] = tokenized_doc[i] + tokenized_inputs[i]
+
+    chunk_size = 512
+    output = {}
+    for chunk_id, index in enumerate(range(0, len(tokenized_doc["input_ids"]), chunk_size)):
+        item = {}
+        entities_in_this_span = []
+        for k in tokenized_doc:
+            item[k] = tokenized_doc[k][index : index + chunk_size]
+        global_to_local_map = {}
+        for entity_id, entity in enumerate(entities):
+            if (
+                index <= entity["start"] < index + chunk_size
+                and index <= entity["end"] < index + chunk_size
+            ):
+                entity["start"] = entity["start"] - index
+                entity["end"] = entity["end"] - index
+                global_to_local_map[entity_id] = len(entities_in_this_span)
+                entities_in_this_span.append(entity)
+        item.update(
+            {
+                "entities": entities_in_this_span
+            }
+        )
+        for key in item.keys():
+            output[key] = output.get(key, []) + item[key]
+    return output
+
+def dfs(i, merged, width, height, visited, df_words):
+    v_threshold = int(.01 * height)
+    h_threshold = int(.08 * width)
+    visited.add(i)
+    merged.append(df_words[i])
+
+    for j in range(len(df_words)):
+        if j not in visited:
+            w1 = df_words[i]['words'][0]
+            w2 = df_words[j]['words'][0]
+
+            # and
+            if (abs(w1['box'][1] - w2['box'][1]) < v_threshold or abs(w1['box'][-1] - w2['box'][-1]) < v_threshold) \
+                and (df_words[i]['label'] == df_words[j]['label']) \
+                and (abs(w1['box'][0] - w2['box'][0]) < h_threshold or abs(w1['box'][-2] - w2['box'][-2]) < h_threshold):
+                dfs(j,merged, width, height, visited, df_words)
+    return merged
+
+def createEntities(model, predictions, input_ids, ocr_df, tokenizer, img_size, bbox):
+    width, height = img_size
+    words = []
+    for index,row in ocr_df.iterrows():
+        word = {}
+        origin_box = [row['left'],row['top'],row['left']+row['width'],row['top']+row['height']]
+        word['word_text'] = row['text']
+        word['word_box'] = origin_box
+        word['normalized_box'] = normalize_box(word['word_box'], width, height)
+        words.append(word)
+
+    raw_input_ids = input_ids[0].tolist()
+    token_boxes = bbox.squeeze().tolist()
+    special_tokens = [tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id]
+
+    input_ids = [id for id in raw_input_ids if id not in special_tokens]
+    predictions = [model.config.id2label[prediction] for i,prediction in enumerate(predictions) if not (raw_input_ids[i] in special_tokens)]
+    actual_boxes = [box for i,box in enumerate(token_boxes) if not (raw_input_ids[i] in special_tokens )]
+
+    assert(len(actual_boxes) == len(predictions))
+
+    for word in words:
+        word_labels = []
+        token_labels = []
+        word_tagging = None
+        for i,box in enumerate(actual_boxes,start=0):
+            if compare_boxes(word['normalized_box'],box):
+                if predictions[i] != 'O':
+                    word_labels.append(predictions[i][2:])
+                else:
+                    word_labels.append('O')
+                token_labels.append(predictions[i])
+        if word_labels != []:
+            word_tagging = word_labels[0] if word_labels[0] != 'O' else word_labels[-1]
+        else:
+            word_tagging = 'O'
+        word['word_labels'] = token_labels
+        word['word_tagging'] = word_tagging
+
+    filtered_words = [{'id':i,'text':word['word_text'],
+                       'label':word['word_tagging'],
+                       'box':word['word_box'],
+                       'words':[{'box':word['word_box'],'text':word['word_text']}]} for i,word in enumerate(words) if word['word_tagging'] != 'O']
+
+    merged_taggings = []
+    df_words = filtered_words.copy()
+    visited = set()
+    for i in range(len(df_words)):
+        if i not in visited:
+            merged_taggings.append(dfs(i,[], width, height, visited, df_words))
+
+    merged_words = []
+    for i,merged_tagging in enumerate(merged_taggings):
+        if len(merged_tagging) > 1:
+            new_word = {}
+            merging_word = " ".join([word['text'] for word in merged_tagging])
+            merging_box = [merged_tagging[0]['box'][0]-5,merged_tagging[0]['box'][1]-10,merged_tagging[-1]['box'][2]+5,merged_tagging[-1]['box'][3]+10]
+            new_word['text'] = merging_word
+            new_word['box'] = merging_box
+            new_word['label'] = merged_tagging[0]['label']
+            new_word['id'] = filtered_words[-1]['id']+i+1
+            new_word['words'] = [{'box':word['box'],'text':word['text']} for word in merged_tagging]
+            # new_word['start'] =
+            merged_words.append(new_word)
+
+    filtered_words.extend(merged_words)
+    predictions = [word['label'] for word in filtered_words]
+    actual_boxes = [word['box'] for word in filtered_words]
+    unique_taggings = set(predictions)
+
+    output = convert_data(copy.deepcopy(merged_words), tokenizer, img_size)
+    return output, merged_words
+
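As a small illustration of the word-merging helpers added above (not part of the commit): compare_boxes is an exact coordinate match, and mergable only pairs words that share a label and roughly share a text line. The boxes and labels below are invented:

from token_classification import compare_boxes, mergable

print(compare_boxes([100, 62, 300, 100], [100, 62, 300, 100]))  # True: identical coordinates
print(compare_boxes([100, 62, 300, 100], [101, 62, 300, 100]))  # False: any difference fails

w1 = {"label": "QUESTION", "box": [100, 62, 300, 100]}
w2 = {"label": "QUESTION", "box": [320, 65, 400, 101]}
w3 = {"label": "ANSWER",   "box": [320, 65, 400, 101]}
print(mergable(w1, w2))  # True: same label, top y-coordinates differ by less than 7
print(mergable(w1, w3))  # False: labels differ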