|
import numpy as np |
|
from preprocess import normalize_box |
|
import copy |
|
|
|
def classifyTokens(model, input_ids, attention_mask, bbox, offset_mapping): |
|
outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask) |
|
|
|
predictions = outputs.logits.argmax(-1).squeeze().tolist() |
|
return predictions |
|
|
|
def compare_boxes(b1,b2): |
|
b1 = np.array([c for c in b1]) |
|
b2 = np.array([c for c in b2]) |
|
equal = np.array_equal(b1,b2) |
|
return equal |
|
|
|
def mergable(w1,w2): |
|
if w1['label'] == w2['label']: |
|
threshold = 7 |
|
if abs(w1['box'][1] - w2['box'][1]) < threshold or abs(w1['box'][-1] - w2['box'][-1]) < threshold: |
|
return True |
|
return False |
|
return False |
|
|
|
def convert_data(data, tokenizer, img_size): |
|
def normalize_bbox(bbox, size): |
|
return [ |
|
int(1000 * bbox[0] / size[0]), |
|
int(1000 * bbox[1] / size[1]), |
|
int(1000 * bbox[2] / size[0]), |
|
int(1000 * bbox[3] / size[1]), |
|
] |
|
|
|
|
|
def simplify_bbox(bbox): |
|
return [ |
|
min(bbox[0::2]), |
|
min(bbox[1::2]), |
|
max(bbox[2::2]), |
|
max(bbox[3::2]), |
|
] |
|
|
|
def merge_bbox(bbox_list): |
|
x0, y0, x1, y1 = list(zip(*bbox_list)) |
|
return [min(x0), min(y0), max(x1), max(y1)] |
|
|
|
tokenized_doc = {"input_ids": [], "bbox": [], "labels": [], "attention_mask":[]} |
|
entities = [] |
|
id2label = {} |
|
entity_id_to_index_map = {} |
|
empty_entity = set() |
|
for line in data: |
|
if len(line["text"]) == 0: |
|
empty_entity.add(line["id"]) |
|
continue |
|
id2label[line["id"]] = line["label"] |
|
tokenized_inputs = tokenizer( |
|
line["text"], |
|
add_special_tokens=False, |
|
return_offsets_mapping=True, |
|
return_attention_mask=True, |
|
) |
|
text_length = 0 |
|
ocr_length = 0 |
|
bbox = [] |
|
for token_id, offset in zip(tokenized_inputs["input_ids"], tokenized_inputs["offset_mapping"]): |
|
if token_id == 6: |
|
bbox.append(None) |
|
continue |
|
text_length += offset[1] - offset[0] |
|
tmp_box = [] |
|
while ocr_length < text_length: |
|
ocr_word = line["words"].pop(0) |
|
ocr_length += len( |
|
tokenizer._tokenizer.normalizer.normalize_str(ocr_word["text"].strip()) |
|
) |
|
tmp_box.append(simplify_bbox(ocr_word["box"])) |
|
if len(tmp_box) == 0: |
|
tmp_box = last_box |
|
bbox.append(normalize_bbox(merge_bbox(tmp_box), img_size)) |
|
last_box = tmp_box |
|
bbox = [ |
|
[bbox[i + 1][0], bbox[i + 1][1], bbox[i + 1][0], bbox[i + 1][1]] if b is None else b |
|
for i, b in enumerate(bbox) |
|
] |
|
if line["label"] == "other": |
|
label = ["O"] * len(bbox) |
|
else: |
|
label = [f"I-{line['label'].upper()}"] * len(bbox) |
|
label[0] = f"B-{line['label'].upper()}" |
|
tokenized_inputs.update({"bbox": bbox, "labels": label}) |
|
if label[0] != "O": |
|
entity_id_to_index_map[line["id"]] = len(entities) |
|
entities.append( |
|
{ |
|
"start": len(tokenized_doc["input_ids"]), |
|
"end": len(tokenized_doc["input_ids"]) + len(tokenized_inputs["input_ids"]), |
|
"label": line["label"].upper(), |
|
} |
|
) |
|
for i in tokenized_doc: |
|
tokenized_doc[i] = tokenized_doc[i] + tokenized_inputs[i] |
|
|
|
chunk_size = 512 |
|
output = {} |
|
for chunk_id, index in enumerate(range(0, len(tokenized_doc["input_ids"]), chunk_size)): |
|
item = {} |
|
entities_in_this_span = [] |
|
for k in tokenized_doc: |
|
item[k] = tokenized_doc[k][index : index + chunk_size] |
|
global_to_local_map = {} |
|
for entity_id, entity in enumerate(entities): |
|
if ( |
|
index <= entity["start"] < index + chunk_size |
|
and index <= entity["end"] < index + chunk_size |
|
): |
|
entity["start"] = entity["start"] - index |
|
entity["end"] = entity["end"] - index |
|
global_to_local_map[entity_id] = len(entities_in_this_span) |
|
entities_in_this_span.append(entity) |
|
item.update( |
|
{ |
|
"entities": entities_in_this_span |
|
} |
|
) |
|
for key in item.keys(): |
|
output[key] = output.get(key, []) + item[key] |
|
return output |
|
|
|
def dfs(i, merged, width, height, visited, df_words): |
|
v_threshold = int(.01 * height) |
|
h_threshold = int(.08 * width) |
|
visited.add(i) |
|
merged.append(df_words[i]) |
|
|
|
for j in range(len(df_words)): |
|
if j not in visited: |
|
w1 = df_words[i]['words'][0] |
|
w2 = df_words[j]['words'][0] |
|
|
|
|
|
if (abs(w1['box'][1] - w2['box'][1]) < v_threshold or abs(w1['box'][-1] - w2['box'][-1]) < v_threshold) \ |
|
and (df_words[i]['label'] == df_words[j]['label']) \ |
|
and (abs(w1['box'][0] - w2['box'][0]) < h_threshold or abs(w1['box'][-2] - w2['box'][-2]) < h_threshold): |
|
dfs(j,merged, width, height, visited, df_words) |
|
return merged |
|
|
|
def createEntities(model, predictions, input_ids, ocr_df, tokenizer, img_size, bbox): |
|
width, height = img_size |
|
words = [] |
|
for index,row in ocr_df.iterrows(): |
|
word = {} |
|
origin_box = [row['left'],row['top'],row['left']+row['width'],row['top']+row['height']] |
|
word['word_text'] = row['text'] |
|
word['word_box'] = origin_box |
|
word['normalized_box'] = normalize_box(word['word_box'], width, height) |
|
words.append(word) |
|
|
|
raw_input_ids = input_ids[0].tolist() |
|
token_boxes = bbox.squeeze().tolist() |
|
special_tokens = [tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id] |
|
|
|
input_ids = [id for id in raw_input_ids if id not in special_tokens] |
|
predictions = [model.config.id2label[prediction] for i,prediction in enumerate(predictions) if not (raw_input_ids[i] in special_tokens)] |
|
actual_boxes = [box for i,box in enumerate(token_boxes) if not (raw_input_ids[i] in special_tokens )] |
|
|
|
assert(len(actual_boxes) == len(predictions)) |
|
|
|
for word in words: |
|
word_labels = [] |
|
token_labels = [] |
|
word_tagging = None |
|
for i,box in enumerate(actual_boxes,start=0): |
|
if compare_boxes(word['normalized_box'],box): |
|
if predictions[i] != 'O': |
|
word_labels.append(predictions[i][2:]) |
|
else: |
|
word_labels.append('O') |
|
token_labels.append(predictions[i]) |
|
if word_labels != []: |
|
word_tagging = word_labels[0] if word_labels[0] != 'O' else word_labels[-1] |
|
else: |
|
word_tagging = 'O' |
|
word['word_labels'] = token_labels |
|
word['word_tagging'] = word_tagging |
|
|
|
filtered_words = [{'id':i,'text':word['word_text'], |
|
'label':word['word_tagging'], |
|
'box':word['word_box'], |
|
'words':[{'box':word['word_box'],'text':word['word_text']}]} for i,word in enumerate(words) if word['word_tagging'] != 'O'] |
|
|
|
merged_taggings = [] |
|
df_words = filtered_words.copy() |
|
visited = set() |
|
for i in range(len(df_words)): |
|
if i not in visited: |
|
merged_taggings.append(dfs(i,[], width, height, visited, df_words)) |
|
merged_words = [] |
|
for i,merged_tagging in enumerate(merged_taggings): |
|
if ((len(merged_tagging) > 1)) or (merged_tagging[0]['label'] == 'ANSWER'): |
|
new_word = {} |
|
merging_word = " ".join([word['text'] for word in merged_tagging]) |
|
merging_box = [merged_tagging[0]['box'][0]-5,merged_tagging[0]['box'][1]-10,merged_tagging[-1]['box'][2]+5,merged_tagging[-1]['box'][3]+10] |
|
new_word['text'] = merging_word |
|
new_word['box'] = merging_box |
|
new_word['label'] = merged_tagging[0]['label'] |
|
new_word['id'] = filtered_words[-1]['id']+i+1 |
|
new_word['words'] = [{'box':word['box'],'text':word['text']} for word in merged_tagging] |
|
|
|
merged_words.append(new_word) |
|
|
|
filtered_words.extend(merged_words) |
|
predictions = [word['label'] for word in filtered_words] |
|
actual_boxes = [word['box'] for word in filtered_words] |
|
unique_taggings = set(predictions) |
|
|
|
output = convert_data(copy.deepcopy(merged_words), tokenizer, img_size) |
|
return output, merged_words |
|
|
|
|