File size: 8,390 Bytes
a228fac 0dd8e27 a228fac 0dd8e27 ccb5ac8 0dd8e27 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 |
import numpy as np
from preprocess import normalize_box
import copy
def classifyTokens(model, input_ids, attention_mask, bbox, offset_mapping):
    """Run the token-classification model and return the predicted class ids.

    The model is called with keyword arguments ``input_ids``, ``bbox`` and
    ``attention_mask``. The highest-scoring class along the last axis of
    ``outputs.logits`` is taken per position. Singleton dimensions are
    squeezed away, so a batch of one yields a flat Python list. A single
    token would collapse to a bare scalar.

    ``offset_mapping`` is accepted for interface compatibility and is not used.
    """
    logits = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask).logits
    best_class = logits.argmax(-1)
    return best_class.squeeze().tolist()
def compare_boxes(b1, b2):
    """Return True when ``b1`` and ``b2`` hold exactly the same coordinates.

    Each argument is iterated element by element into a NumPy array, and the
    two arrays are compared with ``np.array_equal``. Lists, tuples and arrays
    may therefore be mixed, and sequences of different length compare unequal.
    """
    left, right = (np.array(list(box)) for box in (b1, b2))
    return np.array_equal(left, right)
def mergable(w1, w2):
    """Decide whether two labelled words sit on roughly the same line.

    The words must carry the same ``label``. In addition, either their
    ``box[1]`` values or their ``box[-1]`` values must differ by less than
    7 units. Presumably these are the top and bottom edges; confirm the box
    layout with callers.

    Returns a plain ``bool``.
    """
    if w1['label'] != w2['label']:
        return False
    threshold = 7
    # `or` short-circuits, so box[-1] is only read when the box[1] test fails.
    return bool(
        abs(w1['box'][1] - w2['box'][1]) < threshold
        or abs(w1['box'][-1] - w2['box'][-1]) < threshold
    )
def convert_data(data, tokenizer, img_size):
    """Tokenize labelled OCR lines into one flat document encoding.

    Each element of ``data`` is a line dict with ``id``, ``text``, ``label``
    and ``words``. ``words`` is a list of ``{"text", "box"}`` dicts with pixel
    boxes. For every non-empty line:

    * the text is tokenized without special tokens;
    * each token gets the merged box of the OCR word(s) its characters cover,
      scaled to the 0-1000 range using ``img_size``;
    * tokens get BIO labels (``O`` everywhere for lines labelled ``"other"``);
    * lines not labelled ``"other"`` are recorded as entities.

    ``data`` is not mutated.

    Args:
        data: iterable of line dicts as described above.
        tokenizer: called with ``return_offsets_mapping=True``. Its private
            ``_tokenizer.normalizer.normalize_str`` is used to measure OCR word
            lengths. Presumably a HuggingFace *fast* tokenizer; confirm.
        img_size: ``(width, height)`` of the page image.

    Returns:
        ``{}`` when no token is produced. Otherwise a dict with flat lists
        ``input_ids``, ``bbox``, ``labels`` and ``attention_mask`` for the
        whole document, plus ``entities``: dicts with ``start``/``end`` token
        offsets (end exclusive) and an upper-cased ``label``.
    """
    # Token id treated as a stand-alone separator that covers no OCR word.
    # Presumably SentencePiece's "▁" piece in an XLM-R style vocabulary;
    # confirm for the tokenizer in use.
    separator_token_id = 6
    chunk_size = 512

    def normalize_bbox(bbox, size):
        # Scale pixel coordinates to the 0-1000 range.
        return [
            int(1000 * bbox[0] / size[0]),
            int(1000 * bbox[1] / size[1]),
            int(1000 * bbox[2] / size[0]),
            int(1000 * bbox[3] / size[1]),
        ]

    def simplify_bbox(bbox):
        # Collapse a flat [x, y, x, y, ...] coordinate list to [x0, y0, x1, y1].
        # An already-ordered 4-value box comes back unchanged.
        return [
            min(bbox[0::2]),
            min(bbox[1::2]),
            max(bbox[2::2]),
            max(bbox[3::2]),
        ]

    def merge_bbox(bbox_list):
        # Smallest box enclosing every box in the list.
        x0, y0, x1, y1 = list(zip(*bbox_list))
        return [min(x0), min(y0), max(x1), max(y1)]

    def fill_missing_boxes(boxes):
        # Replace each None with a degenerate point box. The point is the
        # top-left corner of the nearest following real box or, when none
        # follows, the top-right corner of the nearest preceding one. If the
        # list holds no real box at all, the Nones are left in place.
        filled = list(boxes)
        following = None
        for pos in range(len(filled) - 1, -1, -1):
            if filled[pos] is None:
                if following is not None:
                    filled[pos] = [following[0], following[1], following[0], following[1]]
            else:
                following = filled[pos]
        preceding = None
        for pos, box in enumerate(filled):
            if box is None:
                if preceding is not None:
                    filled[pos] = [preceding[2], preceding[1], preceding[2], preceding[1]]
            else:
                preceding = box
        return filled

    tokenized_doc = {"input_ids": [], "bbox": [], "labels": [], "attention_mask": []}
    entities = []
    for line in data:
        if len(line["text"]) == 0:
            continue
        encoding = tokenizer(
            line["text"],
            add_special_tokens=False,
            return_offsets_mapping=True,
            return_attention_mask=True,
        )
        ocr_words = line["words"]
        next_word = 0  # index of the next unconsumed OCR word
        text_length = 0  # characters covered by this line's tokens so far
        ocr_length = 0  # normalized characters covered by consumed OCR words
        last_box = None  # OCR boxes used by the previous token of this line
        bbox = []
        for token_id, offset in zip(encoding["input_ids"], encoding["offset_mapping"]):
            if token_id == separator_token_id:
                bbox.append(None)
                continue
            text_length += offset[1] - offset[0]
            token_ocr_boxes = []
            # Consume OCR words until they cover at least as many characters
            # as the tokens seen so far (or the words run out).
            while ocr_length < text_length and next_word < len(ocr_words):
                ocr_word = ocr_words[next_word]
                next_word += 1
                ocr_length += len(
                    tokenizer._tokenizer.normalizer.normalize_str(ocr_word["text"].strip())
                )
                token_ocr_boxes.append(simplify_bbox(ocr_word["box"]))
            if not token_ocr_boxes:
                # The token's characters were already covered by earlier
                # words: reuse the previous token's boxes from this line.
                token_ocr_boxes = last_box
            if token_ocr_boxes is None:
                # Nothing placed yet on this line; filled from a neighbour below.
                bbox.append(None)
                continue
            bbox.append(normalize_bbox(merge_bbox(token_ocr_boxes), img_size))
            last_box = token_ocr_boxes
        bbox = fill_missing_boxes(bbox)
        if not bbox or any(box is None for box in bbox):
            # No token of this line could be placed (empty tokenization or
            # separators only); skip it.
            continue

        if line["label"] == "other":
            labels = ["O"] * len(bbox)
        else:
            tag = line["label"].upper()
            labels = [f"B-{tag}"] + [f"I-{tag}"] * (len(bbox) - 1)
            doc_length = len(tokenized_doc["input_ids"])
            entities.append(
                {
                    "start": doc_length,
                    "end": doc_length + len(encoding["input_ids"]),
                    "label": tag,
                }
            )

        tokenized_doc["input_ids"].extend(encoding["input_ids"])
        tokenized_doc["bbox"].extend(bbox)
        tokenized_doc["labels"].extend(labels)
        tokenized_doc["attention_mask"].extend(encoding["attention_mask"])

    num_tokens = len(tokenized_doc["input_ids"])
    if num_tokens == 0:
        return {}
    # Concatenating chunk_size-long slices of each list would reproduce the
    # whole list, so the token-level fields are returned as they are.
    output = dict(tokenized_doc)
    output["entities"] = []
    for chunk_start in range(0, num_tokens, chunk_size):
        chunk_end = chunk_start + chunk_size
        for entity in entities:
            # Keep entities lying wholly inside this window. "end" is
            # exclusive, so end == chunk_end still fits.
            if (
                chunk_start <= entity["start"] < chunk_end
                and chunk_start <= entity["end"] <= chunk_end
            ):
                output["entities"].append(
                    {
                        **entity,
                        "start": entity["start"] - chunk_start,
                        "end": entity["end"] - chunk_start,
                    }
                )
    # NOTE(review): entity offsets are made window-local while the token lists
    # above stay global. For documents longer than chunk_size, the offsets of
    # later windows do not index output["input_ids"], and entities crossing a
    # window boundary are dropped. Confirm what the consumer expects.
    return output
def dfs(i, merged, width, height, visited, df_words):
    """Collect the connected group of words reachable from ``df_words[i]``.

    Two words are connected when all of the following hold for the boxes of
    their first OCR word (``words[0]['box']``, presumably
    ``[x0, y0, x1, y1]``):

    * ``box[1]`` or ``box[-1]`` differ by less than ``int(0.01 * height)``;
    * the word dicts share the same ``label``;
    * ``box[0]`` or ``box[-2]`` differ by less than ``int(0.08 * width)``.

    Words are appended to ``merged`` in depth-first order, taking the lowest
    unvisited index first at every step. Their indices are added to
    ``visited``. An explicit stack replaces recursion, so long chains of
    adjacent words cannot exceed Python's recursion limit.

    Args:
        i: index of the starting word in ``df_words``.
        merged: list that receives the group's word dicts (mutated).
        width: page width used to derive the horizontal threshold.
        height: page height used to derive the vertical threshold.
        visited: set of indices already assigned to a group (mutated).
        df_words: list of word dicts with ``label`` and ``words``.

    Returns:
        ``merged``.
    """
    v_threshold = int(.01 * height)
    h_threshold = int(.08 * width)

    def connected(a, b):
        box_a = df_words[a]['words'][0]['box']
        box_b = df_words[b]['words'][0]['box']
        return (
            (abs(box_a[1] - box_b[1]) < v_threshold or abs(box_a[-1] - box_b[-1]) < v_threshold)
            and df_words[a]['label'] == df_words[b]['label']
            and (abs(box_a[0] - box_b[0]) < h_threshold or abs(box_a[-2] - box_b[-2]) < h_threshold)
        )

    visited.add(i)
    merged.append(df_words[i])
    # Each frame is (word index, next candidate index to examine). This is
    # where a recursive loop would resume after returning from a child.
    stack = [(i, 0)]
    while stack:
        node, start = stack[-1]
        for j in range(start, len(df_words)):
            if j not in visited and connected(node, j):
                stack[-1] = (node, j + 1)
                visited.add(j)
                merged.append(df_words[j])
                stack.append((j, 0))
                break
        else:
            stack.pop()
    return merged
def createEntities(model, predictions, input_ids, ocr_df, tokenizer, img_size, bbox):
    """Turn token-level predictions into merged, labelled OCR entities.

    Steps:
      1. Build one record per row of ``ocr_df`` (columns ``left``, ``top``,
         ``width``, ``height``, ``text``) holding its pixel box and the box
         returned by ``normalize_box``.
      2. Drop CLS/SEP/PAD positions and map each remaining prediction id to a
         label through ``model.config.id2label``.
      3. Tag each OCR word with the labels of the tokens whose box equals the
         word's normalized box, with the two-character BIO prefix stripped.
      4. Group adjacent same-label words with ``dfs``. Groups of several words
         are kept, as are single ``ANSWER`` words.
      5. Encode the kept groups with ``convert_data``.

    Args:
        model: object exposing ``config.id2label``.
        predictions: per-position predicted class ids aligned with
            ``input_ids[0]``.
        input_ids: token ids; only ``input_ids[0].tolist()`` is used
            (presumably a ``(1, seq_len)`` tensor).
        ocr_df: OCR rows as a DataFrame.
        tokenizer: provides ``cls_token_id``/``sep_token_id``/``pad_token_id``
            and is forwarded to ``convert_data``.
        img_size: ``(width, height)`` of the page image.
        bbox: token boxes; ``bbox.squeeze().tolist()`` must give one box per
            position (presumably a ``(1, seq_len, 4)`` tensor).

    Returns:
        ``(encoding, merged_words)``. ``encoding`` is ``convert_data``'s output
        for a deep copy of ``merged_words``. ``merged_words`` is a list of
        dicts with ``text``, ``box``, ``label``, ``id`` and ``words``.

    Raises:
        ValueError: if the non-special token boxes and predictions differ in
            count.
    """
    width, height = img_size

    # 1. One record per OCR row.
    words = []
    for _, row in ocr_df.iterrows():
        origin_box = [row['left'], row['top'], row['left'] + row['width'], row['top'] + row['height']]
        words.append({
            'word_text': row['text'],
            'word_box': origin_box,
            'normalized_box': normalize_box(origin_box, width, height),
        })

    # 2. Keep only real (non-special) token positions.
    raw_input_ids = input_ids[0].tolist()
    token_boxes = bbox.squeeze().tolist()
    special_tokens = {tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id}
    token_labels = [
        model.config.id2label[prediction]
        for i, prediction in enumerate(predictions)
        if raw_input_ids[i] not in special_tokens
    ]
    actual_boxes = [box for i, box in enumerate(token_boxes) if raw_input_ids[i] not in special_tokens]
    if len(actual_boxes) != len(token_labels):
        raise ValueError(
            f"misaligned tokens: {len(actual_boxes)} boxes vs {len(token_labels)} predictions"
        )

    # 3. Index token positions by box once, so each word is a dict lookup
    #    instead of a comparison against every token.
    positions_by_box = {}
    for position, box in enumerate(actual_boxes):
        positions_by_box.setdefault(tuple(box), []).append(position)
    for word in words:
        matched = positions_by_box.get(tuple(word['normalized_box']), [])
        # Strip the BIO prefix ("B-"/"I-") to keep only the entity type.
        word_labels = [
            token_labels[p] if token_labels[p] == 'O' else token_labels[p][2:]
            for p in matched
        ]
        if not word_labels:
            word['word_tagging'] = 'O'
        elif word_labels[0] != 'O':
            word['word_tagging'] = word_labels[0]
        else:
            word['word_tagging'] = word_labels[-1]

    filtered_words = [
        {
            'id': i,
            'text': word['word_text'],
            'label': word['word_tagging'],
            'box': word['word_box'],
            'words': [{'box': word['word_box'], 'text': word['word_text']}],
        }
        for i, word in enumerate(words)
        if word['word_tagging'] != 'O'
    ]

    # 4. Group adjacent same-label words.
    visited = set()
    merged_taggings = []
    for i in range(len(filtered_words)):
        if i not in visited:
            merged_taggings.append(dfs(i, [], width, height, visited, filtered_words))

    merged_words = []
    for i, group in enumerate(merged_taggings):
        # Lone words are dropped unless labelled ANSWER (kept as in the original logic).
        if len(group) > 1 or group[0]['label'] == 'ANSWER':
            boxes = [w['box'] for w in group]
            # Union of the group's boxes, padded by 5px horizontally and 10px
            # vertically. The DFS order of `group` is not spatial order, so
            # first/last word boxes cannot be relied on.
            merging_box = [
                min(b[0] for b in boxes) - 5,
                min(b[1] for b in boxes) - 10,
                max(b[2] for b in boxes) + 5,
                max(b[3] for b in boxes) + 10,
            ]
            merged_words.append({
                'text': " ".join(w['text'] for w in group),
                'box': merging_box,
                'label': group[0]['label'],
                # Ids above the largest filtered-word id, so they never collide.
                'id': filtered_words[-1]['id'] + i + 1,
                'words': [{'box': w['box'], 'text': w['text']} for w in group],
            })

    # 5. Encode; a deep copy shields merged_words from any mutation inside convert_data.
    output = convert_data(copy.deepcopy(merged_words), tokenizer, img_size)
    return output, merged_words
|