kavg committed on
Commit
0dd8e27
1 Parent(s): 1ddbc38

added ned entity merging method. Included additional outputs in the response to match the frontend Streamlit app

Browse files
Files changed (4) hide show
  1. main.py +37 -16
  2. models.py +6 -1
  3. preprocess.py +8 -8
  4. token_classification.py +212 -28
main.py CHANGED
@@ -2,7 +2,7 @@ from config import Settings
2
  from preprocess import Preprocessor
3
  import ocr
4
  from PIL import Image
5
- from transformers import LiltForTokenClassification
6
  import token_classification
7
  import torch
8
  from fastapi import FastAPI, UploadFile
@@ -19,6 +19,7 @@ async def lifespan(app: FastAPI):
19
  config['device'] = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  config['vision_client'] = ocr.VisionClient(settings.GCV_AUTH)
21
  config['processor'] = Preprocessor(settings.TOKENIZER)
 
22
  config['ser_model'] = LiltForTokenClassification.from_pretrained(settings.SER_MODEL)
23
  config['re_model'] = LiLTRobertaLikeForRelationExtraction.from_pretrained(settings.RE_MODEL)
24
  yield
@@ -29,8 +30,8 @@ app = FastAPI(lifespan=lifespan)
29
 
30
  @app.post("/submit-doc")
31
  async def ProcessDocument(file: UploadFile):
32
- tokenClassificationOutput = await LabelTokens(file)
33
- reOutput = ExtractRelations(tokenClassificationOutput)
34
  return reOutput
35
 
36
  async def LabelTokens(file):
@@ -39,28 +40,48 @@ async def LabelTokens(file):
39
  ocr_df = config['vision_client'].ocr(content, image)
40
  input_ids, attention_mask, token_type_ids, bbox, token_actual_boxes, offset_mapping = config['processor'].process(ocr_df, image = image)
41
  token_labels = token_classification.classifyTokens(config['ser_model'], input_ids, attention_mask, bbox, offset_mapping)
42
- return {"token_labels": token_labels, "input_ids": input_ids, "bbox":bbox, "offset_mapping":offset_mapping, "attention_mask":attention_mask}
43
 
44
- def ExtractRelations(tokenClassificationOutput):
45
  token_labels = tokenClassificationOutput['token_labels']
46
  input_ids = tokenClassificationOutput['input_ids']
47
- offset_mapping = tokenClassificationOutput["offset_mapping"]
48
  attention_mask = tokenClassificationOutput["attention_mask"]
49
- bbox = tokenClassificationOutput["bbox"]
50
 
51
- entities = token_classification.createEntities(config['ser_model'], token_labels, input_ids, offset_mapping)
52
 
 
 
 
 
 
 
 
 
 
 
 
53
  config['re_model'].to(config['device'])
54
- entity_dict = {'start': [entity[0] for entity in entities], 'end': [entity[1] for entity in entities], 'label': [entity[3] for entity in entities]}
55
  relations = [{'start_index': [], 'end_index': [], 'head': [], 'tail': []}]
56
  with torch.no_grad():
57
  outputs = config['re_model'](input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, entities=[entity_dict], relations=relations)
58
 
59
- print(type(outputs.pred_relations[0]))
60
- print(type(entities))
61
- print(type(input_ids))
62
- print(type(bbox))
63
- print(type(token_labels))
64
- # "pred_relations":json.dumps(outputs.pred_relations[0]), "entities":json.dumps(entities), "input_ids": json.dumps(input_ids.tolist()),
 
 
 
 
 
 
 
 
 
 
65
 
66
- return {"pred_relations":json.dumps(outputs.pred_relations[0]), "entities":json.dumps(entities), "input_ids": json.dumps(input_ids.tolist()), "bboxes": json.dumps(bbox.tolist()),"token_labels":json.dumps(token_labels)}
 
2
  from preprocess import Preprocessor
3
  import ocr
4
  from PIL import Image
5
+ from transformers import LiltForTokenClassification, AutoTokenizer
6
  import token_classification
7
  import torch
8
  from fastapi import FastAPI, UploadFile
 
19
  config['device'] = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  config['vision_client'] = ocr.VisionClient(settings.GCV_AUTH)
21
  config['processor'] = Preprocessor(settings.TOKENIZER)
22
+ config['tokenizer'] = AutoTokenizer.from_pretrained(settings.TOKENIZER)
23
  config['ser_model'] = LiltForTokenClassification.from_pretrained(settings.SER_MODEL)
24
  config['re_model'] = LiLTRobertaLikeForRelationExtraction.from_pretrained(settings.RE_MODEL)
25
  yield
 
30
 
31
  @app.post("/submit-doc")
32
  async def ProcessDocument(file: UploadFile):
33
+ tokenClassificationOutput, ocr_df, img_size = await LabelTokens(file)
34
+ reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
35
  return reOutput
36
 
37
  async def LabelTokens(file):
 
40
  ocr_df = config['vision_client'].ocr(content, image)
41
  input_ids, attention_mask, token_type_ids, bbox, token_actual_boxes, offset_mapping = config['processor'].process(ocr_df, image = image)
42
  token_labels = token_classification.classifyTokens(config['ser_model'], input_ids, attention_mask, bbox, offset_mapping)
43
+ return {"token_labels": token_labels, "input_ids": input_ids, "bbox":bbox, "attention_mask":attention_mask}, ocr_df, image.size
44
 
45
def ExtractRelations(tokenClassificationOutput, ocr_df, img_size):
    """Merge labelled tokens into entities and run the relation-extraction model.

    Args:
        tokenClassificationOutput: dict providing ``'token_labels'``,
            ``'input_ids'`` and ``'bbox'`` (the bbox must support ``.tolist()``).
        ocr_df: OCR result forwarded unchanged to
            ``token_classification.createEntities``.
        img_size: image size forwarded unchanged to ``createEntities``.

    Returns:
        dict of JSON-encoded strings under the keys ``pred_relations``,
        ``entities``, ``input_ids``, ``bboxes``, ``token_labels``,
        ``decoded_entities`` and ``decoded_pred_relations``.  ``input_ids`` is
        the re-tokenized sequence built from the merged entities, while
        ``bboxes`` and ``token_labels`` are the original classifier values.
    """
    token_labels = tokenClassificationOutput['token_labels']
    bbox_org = tokenClassificationOutput["bbox"]
    tokenizer = config['tokenizer']
    device = config['device']

    merged_output, _merged_words = token_classification.createEntities(
        config['ser_model'],
        token_labels,
        tokenClassificationOutput['input_ids'],
        ocr_df,
        tokenizer,
        img_size,
        bbox_org,
    )

    # When no words were merged into entities there is nothing to feed the
    # relation model (the merged output may even lack its keys); answer with
    # empty results instead of failing on a missing key.
    if not merged_output.get('input_ids'):
        return {
            "pred_relations": json.dumps([]),
            "entities": json.dumps([]),
            "input_ids": json.dumps([[]]),
            "bboxes": json.dumps(bbox_org.tolist()),
            "token_labels": json.dumps(token_labels),
            "decoded_entities": json.dumps([]),
            "decoded_pred_relations": json.dumps([]),
        }

    entities = merged_output['entities']
    # Wrap each list in an outer list to add the batch dimension.
    input_ids = torch.tensor([merged_output['input_ids']]).to(device)
    bbox = torch.tensor([merged_output['bbox']]).to(device)
    attention_mask = torch.tensor([merged_output['attention_mask']]).to(device)

    label2id = {"HEADER": 0, "QUESTION": 1, "ANSWER": 2}
    decoded_entities = []
    for entity in entities:
        text = tokenizer.decode(input_ids[0][entity['start']:entity['end']])
        decoded_entities.append((entity['label'], text))
        # The entity dicts are rewritten in place: the label name becomes its
        # integer id, which is also what ends up in the "entities" response field.
        entity['label'] = label2id[entity['label']]

    config['re_model'].to(device)
    entity_dict = {
        'start': [entity['start'] for entity in entities],
        'end': [entity['end'] for entity in entities],
        'label': [entity['label'] for entity in entities],
    }
    relations = [{'start_index': [], 'end_index': [], 'head': [], 'tail': []}]
    with torch.no_grad():
        outputs = config['re_model'](
            input_ids=input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            entities=[entity_dict],
            relations=relations,
        )

    # Each predicted relation carries (start, end) token spans for head and tail;
    # decode both spans back to text for the frontend.
    decoded_pred_relations = []
    for relation in outputs.pred_relations[0]:
        head_start, head_end = relation['head']
        tail_start, tail_end = relation['tail']
        question = tokenizer.decode(input_ids[0][head_start:head_end])
        answer = tokenizer.decode(input_ids[0][tail_start:tail_end])
        decoded_pred_relations.append((question, answer))

    return {
        "pred_relations": json.dumps(outputs.pred_relations[0]),
        "entities": json.dumps(entities),
        "input_ids": json.dumps(input_ids.tolist()),
        "bboxes": json.dumps(bbox_org.tolist()),
        "token_labels": json.dumps(token_labels),
        "decoded_entities": json.dumps(decoded_entities),
        "decoded_pred_relations": json.dumps(decoded_pred_relations),
    }
models.py CHANGED
@@ -196,6 +196,8 @@ class LiLTRobertaLikeForRelationExtraction(LiltPreTrainedModel):
196
  super().__init__(config)
197
 
198
  self.lilt = LiltModel(config, add_pooling_layer=False)
 
 
199
  self.rehead = REHead(config)
200
  self.init_weights()
201
 
@@ -216,6 +218,8 @@ class LiLTRobertaLikeForRelationExtraction(LiltPreTrainedModel):
216
  entities=None,
217
  relations=None,
218
  ):
 
 
219
 
220
  outputs = self.lilt(
221
  input_ids,
@@ -230,7 +234,8 @@ class LiLTRobertaLikeForRelationExtraction(LiltPreTrainedModel):
230
  return_dict=return_dict,
231
  )
232
 
 
233
  sequence_output = outputs[0]
234
-
235
  re_output = self.rehead(sequence_output, entities, relations)
236
  return re_output
 
196
  super().__init__(config)
197
 
198
  self.lilt = LiltModel(config, add_pooling_layer=False)
199
+ # self.dropout = nn.Dropout(config.hidden_dropout_prob)
200
+ # self.extractor = REDecoder(config, config.hidden_size)
201
  self.rehead = REHead(config)
202
  self.init_weights()
203
 
 
218
  entities=None,
219
  relations=None,
220
  ):
221
+ # for param in self.lilt.parameters():
222
+ # param.requires_grad = False
223
 
224
  outputs = self.lilt(
225
  input_ids,
 
234
  return_dict=return_dict,
235
  )
236
 
237
+ seq_length = input_ids.size(1)
238
  sequence_output = outputs[0]
239
+
240
  re_output = self.rehead(sequence_output, entities, relations)
241
  return re_output
preprocess.py CHANGED
@@ -1,6 +1,14 @@
1
  import torch
2
  from transformers import AutoTokenizer
3
 
 
 
 
 
 
 
 
 
4
  # class to turn the keys of a dict into attributes (thanks Stackoverflow)
5
  class AttrDict(dict):
6
  def __init__(self, *args, **kwargs):
@@ -23,14 +31,6 @@ class Preprocessor():
23
  actual_box = [x, y, x+w, y+h] # we turn it into (left, top, left+width, top+height) to get the actual box
24
  actual_boxes.append(actual_box)
25
 
26
- def normalize_box(box, width, height):
27
- return [
28
- int(1000 * (box[0] / width)),
29
- int(1000 * (box[1] / height)),
30
- int(1000 * (box[2] / width)),
31
- int(1000 * (box[3] / height)),
32
- ]
33
-
34
  boxes = []
35
  for box in actual_boxes:
36
  boxes.append(normalize_box(box, width, height))
 
1
  import torch
2
  from transformers import AutoTokenizer
3
 
4
def normalize_box(box, width, height):
    """Scale a pixel-space box onto the 0-1000 integer grid.

    ``box`` is read as ``[x0, y0, x1, y1]``.  The x values are divided by
    ``width`` and the y values by ``height``; each ratio is multiplied by 1000
    and truncated with ``int``.
    """
    x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
    ratios = (x0 / width, y0 / height, x1 / width, y1 / height)
    return [int(1000 * ratio) for ratio in ratios]
11
+
12
  # class to turn the keys of a dict into attributes (thanks Stackoverflow)
13
  class AttrDict(dict):
14
  def __init__(self, *args, **kwargs):
 
31
  actual_box = [x, y, x+w, y+h] # we turn it into (left, top, left+width, top+height) to get the actual box
32
  actual_boxes.append(actual_box)
33
 
 
 
 
 
 
 
 
 
34
  boxes = []
35
  for box in actual_boxes:
36
  boxes.append(normalize_box(box, width, height))
token_classification.py CHANGED
@@ -1,4 +1,6 @@
1
  import numpy as np
 
 
2
 
3
  def classifyTokens(model, input_ids, attention_mask, bbox, offset_mapping):
4
  outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
@@ -6,31 +8,213 @@ def classifyTokens(model, input_ids, attention_mask, bbox, offset_mapping):
6
  predictions = outputs.logits.argmax(-1).squeeze().tolist()
7
  return predictions
8
 
9
- def createEntities(model, predictions, input_ids, offset_mapping):
10
- # we're only interested in tokens which aren't subwords
11
- # we'll use the offset mapping for that
12
- offset_mapping = np.array(offset_mapping)
13
- is_subword = np.array(offset_mapping.squeeze().tolist())[:,0] != 0
14
-
15
- id2label = {"HEADER":0, "QUESTION":1, "ANSWER":2}
16
-
17
- # finally, store recognized "question" and "answer" entities in a list
18
- entities = []
19
- current_entity = None
20
- start = None
21
- end = None
22
-
23
- for idx, (id, pred) in enumerate(zip(input_ids[0].tolist(), predictions)):
24
- if not is_subword[idx]:
25
- predicted_label = model.config.id2label[pred]
26
- if predicted_label.startswith("B") and current_entity is None:
27
- # means we're at the start of a new entity
28
- current_entity = predicted_label.replace("B-", "")
29
- start = idx
30
- if current_entity is not None and current_entity not in predicted_label:
31
- # means we're at the end of a new entity
32
- end = idx
33
- entities.append((start, end, current_entity, id2label[current_entity]))
34
- current_entity = None
35
-
36
- return entities
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import numpy as np
2
+ from preprocess import normalize_box
3
+ import copy
4
 
5
  def classifyTokens(model, input_ids, attention_mask, bbox, offset_mapping):
6
  outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
 
8
  predictions = outputs.logits.argmax(-1).squeeze().tolist()
9
  return predictions
10
 
11
def compare_boxes(b1, b2):
    """Return True when both boxes hold equal coordinates in the same order.

    Boxes of different length compare as not equal.  Each argument may be any
    iterable of numbers; it is materialized once and handed to
    ``np.array_equal``, which performs the array conversion itself.
    """
    return bool(np.array_equal(list(b1), list(b2)))
16
+
17
def mergable(w1, w2, threshold=7):
    """Tell whether two labelled words are candidates for merging.

    Two words qualify when they carry the same ``'label'`` and either their
    ``box[1]`` values or their last box values differ by less than
    ``threshold``.

    Args:
        w1, w2: dicts with a ``'label'`` and a ``'box'`` sequence.
        threshold: maximum (exclusive) coordinate difference; defaults to the
            previously hard-coded value of 7.

    Returns:
        bool: True if the words may be merged.
    """
    if w1['label'] != w2['label']:
        return False
    box1, box2 = w1['box'], w2['box']
    return bool(
        abs(box1[1] - box2[1]) < threshold
        or abs(box1[-1] - box2[-1]) < threshold
    )
24
+
25
def convert_data(data, tokenizer, img_size):
    """Tokenize labelled text lines into token-level model inputs and entity spans.

    Every item of ``data`` is read as a dict with ``"text"``, ``"label"`` and
    ``"words"`` (a list of dicts with ``"text"`` and ``"box"``).  The words are
    consumed in order to give each token a bounding box, which is then scaled
    to the 0-1000 range using ``img_size``.

    Args:
        data: iterable of line dicts as described above.  It is not modified.
        tokenizer: callable invoked as ``tokenizer(text,
            add_special_tokens=False, return_offsets_mapping=True,
            return_attention_mask=True)``; it must also expose
            ``_tokenizer.normalizer.normalize_str``.
        img_size: ``(width, height)`` used to normalize the boxes.

    Returns:
        dict with list values under ``"input_ids"``, ``"bbox"``, ``"labels"``,
        ``"attention_mask"`` and ``"entities"``.  All five keys are always
        present, even when ``data`` yields no tokens.

    Raises:
        IndexError: if a line runs out of words before its tokens are covered.
    """
    # NOTE(review): token id 6 is hard-coded; presumably it is the bare
    # word-boundary piece of the configured vocabulary -- confirm against the
    # tokenizer in use.
    placeholder_token_id = 6
    chunk_size = 512

    def normalize_bbox(bbox, size):
        return [
            int(1000 * bbox[0] / size[0]),
            int(1000 * bbox[1] / size[1]),
            int(1000 * bbox[2] / size[0]),
            int(1000 * bbox[3] / size[1]),
        ]

    def simplify_bbox(bbox):
        # Collapse a coordinate list to [min x, min y, max x, max y].
        return [
            min(bbox[0::2]),
            min(bbox[1::2]),
            max(bbox[2::2]),
            max(bbox[3::2]),
        ]

    def merge_bbox(bbox_list):
        x0, y0, x1, y1 = list(zip(*bbox_list))
        return [min(x0), min(y0), max(x1), max(y1)]

    def fill_placeholders(boxes):
        # A placeholder becomes a zero-area box at the top-left corner of the
        # next resolved token.  Trailing placeholders fall back to the nearest
        # earlier token, and [0, 0, 0, 0] is used when the line has no resolved
        # token at all (best-effort choices for inputs that used to crash).
        filled = list(boxes)
        for pos, box in enumerate(boxes):
            if box is not None:
                continue
            anchor = next((b for b in boxes[pos + 1:] if b is not None), None)
            if anchor is None:
                anchor = next((b for b in reversed(boxes[:pos]) if b is not None), None)
            if anchor is None:
                filled[pos] = [0, 0, 0, 0]
            else:
                filled[pos] = [anchor[0], anchor[1], anchor[0], anchor[1]]
        return filled

    tokenized_doc = {"input_ids": [], "bbox": [], "labels": [], "attention_mask": []}
    entities = []
    # Raw boxes behind the most recently resolved token; deliberately carried
    # across lines so a token that consumes no word can reuse them.
    last_box = None
    for line in data:
        if len(line["text"]) == 0:
            continue
        tokenized_inputs = tokenizer(
            line["text"],
            add_special_tokens=False,
            return_offsets_mapping=True,
            return_attention_mask=True,
        )
        ocr_words = line["words"]
        next_word = 0  # read cursor into ocr_words (the input list is left intact)
        text_length = 0
        ocr_length = 0
        bbox = []
        for token_id, offset in zip(tokenized_inputs["input_ids"], tokenized_inputs["offset_mapping"]):
            if token_id == placeholder_token_id:
                bbox.append(None)
                continue
            text_length += offset[1] - offset[0]
            tmp_box = []
            # Pull words until their normalized text covers the characters
            # consumed by the tokens seen so far.
            while ocr_length < text_length:
                ocr_word = ocr_words[next_word]
                next_word += 1
                ocr_length += len(
                    tokenizer._tokenizer.normalizer.normalize_str(ocr_word["text"].strip())
                )
                tmp_box.append(simplify_bbox(ocr_word["box"]))
            if not tmp_box:
                if last_box is None:
                    # No word consumed and nothing to reuse yet: defer the
                    # token like a placeholder instead of reading an unset box.
                    bbox.append(None)
                    continue
                tmp_box = last_box
            bbox.append(normalize_bbox(merge_bbox(tmp_box), img_size))
            last_box = tmp_box
        if not bbox:
            # The text produced no tokens, so there is nothing to label or append.
            continue
        bbox = fill_placeholders(bbox)

        start = len(tokenized_doc["input_ids"])
        if line["label"] == "other":
            label = ["O"] * len(bbox)
        else:
            tag = line["label"].upper()
            label = [f"B-{tag}"] + [f"I-{tag}"] * (len(bbox) - 1)
            entities.append(
                {
                    "start": start,
                    "end": start + len(tokenized_inputs["input_ids"]),
                    "label": tag,
                }
            )
        tokenized_doc["input_ids"].extend(tokenized_inputs["input_ids"])
        tokenized_doc["bbox"].extend(bbox)
        tokenized_doc["labels"].extend(label)
        tokenized_doc["attention_mask"].extend(tokenized_inputs["attention_mask"])

    output = {key: [] for key in tokenized_doc}
    output["entities"] = []
    # NOTE(review): the token lists are cut into 512-token chunks and then
    # concatenated again, while entity offsets are rewritten relative to their
    # chunk.  Entities that cross a chunk border, or whose exclusive end equals
    # it, are dropped.  Behaviour kept as-is -- confirm this is intended for
    # documents longer than one chunk.
    for index in range(0, len(tokenized_doc["input_ids"]), chunk_size):
        upper = index + chunk_size
        for key in tokenized_doc:
            output[key] += tokenized_doc[key][index:upper]
        for entity in entities:
            if index <= entity["start"] < upper and index <= entity["end"] < upper:
                entity["start"] -= index
                entity["end"] -= index
                output["entities"].append(entity)
    return output
129
+
130
def dfs(i, merged, width, height, visited, df_words):
    """Collect, depth-first, every word transitively linked to word ``i``.

    Two words are linked when they share a ``'label'`` and the boxes of their
    first ``'words'`` entries are both
    * vertically aligned: ``box[1]`` values or last box values differ by less
      than 1% of ``height``; and
    * horizontally close: ``box[0]`` values or second-to-last box values
      differ by less than 8% of ``width``.

    The traversal uses an explicit stack, so a long chain of linked words
    cannot exceed the interpreter's recursion limit.  Visiting order matches
    the recursive formulation: candidates are scanned in ascending index order.

    Args:
        i: index of the starting word in ``df_words``.
        merged: list that visited words are appended to (mutated and returned).
        width, height: image dimensions the thresholds are derived from.
        visited: set of already visited indices (mutated).
        df_words: list of word dicts with ``'label'`` and ``'words'``.

    Returns:
        The ``merged`` list.
    """
    v_threshold = int(.01 * height)
    h_threshold = int(.08 * width)

    def linked(a, b):
        if df_words[a]['label'] != df_words[b]['label']:
            return False
        box_a = df_words[a]['words'][0]['box']
        box_b = df_words[b]['words'][0]['box']
        same_row = (abs(box_a[1] - box_b[1]) < v_threshold
                    or abs(box_a[-1] - box_b[-1]) < v_threshold)
        close_by = (abs(box_a[0] - box_b[0]) < h_threshold
                    or abs(box_a[-2] - box_b[-2]) < h_threshold)
        return same_row and close_by

    n_words = len(df_words)
    visited.add(i)
    merged.append(df_words[i])
    # Each frame is (word index, next candidate index to examine).
    pending = [(i, 0)]
    while pending:
        node, j = pending.pop()
        while j < n_words:
            if j not in visited and linked(node, j):
                visited.add(j)
                merged.append(df_words[j])
                # Resume this word's scan after its newly found neighbour has
                # been fully explored.
                pending.append((node, j + 1))
                pending.append((j, 0))
                break
            j += 1
    return merged
147
+
148
def createEntities(model, predictions, input_ids, ocr_df, tokenizer, img_size, bbox):
    """Group labelled OCR words into multi-word entities and re-tokenize them.

    Steps:
      1. Build one record per row of ``ocr_df`` with its pixel box and its
         normalized box.
      2. Drop special tokens, then give every word the predicted labels of the
         tokens whose box equals the word's normalized box.
      3. Keep the words tagged with something other than ``'O'`` and group
         linked words with ``dfs``.
      4. Turn every group of more than one word into a merged word and pass the
         merged words to ``convert_data``.

    Args:
        model: object exposing ``config.id2label`` for decoding predictions.
        predictions: per-token label ids, parallel to ``input_ids[0]``.
        input_ids: token ids; ``input_ids[0].tolist()`` must give the sequence.
        ocr_df: frame iterated with ``iterrows()``; rows provide ``'left'``,
            ``'top'``, ``'width'``, ``'height'`` and ``'text'``.
        tokenizer: supplies ``cls_token_id``, ``sep_token_id``,
            ``pad_token_id`` and is forwarded to ``convert_data``.
        img_size: ``(width, height)`` of the page image.
        bbox: per-token boxes; ``bbox.squeeze().tolist()`` must give one box
            per token.

    Returns:
        ``(output, merged_words)``: the ``convert_data`` result for the merged
        words, and the merged word dicts themselves.
    """
    width, height = img_size

    words = []
    for _, row in ocr_df.iterrows():
        word_box = [
            row['left'],
            row['top'],
            row['left'] + row['width'],
            row['top'] + row['height'],
        ]
        words.append({
            'word_text': row['text'],
            'word_box': word_box,
            'normalized_box': normalize_box(word_box, width, height),
        })

    raw_input_ids = input_ids[0].tolist()
    token_boxes = bbox.squeeze().tolist()
    special_tokens = [tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id]

    predictions = [
        model.config.id2label[prediction]
        for i, prediction in enumerate(predictions)
        if raw_input_ids[i] not in special_tokens
    ]
    actual_boxes = [
        box for i, box in enumerate(token_boxes)
        if raw_input_ids[i] not in special_tokens
    ]

    assert len(actual_boxes) == len(predictions)

    # Index the token labels by box once so each word is resolved with a single
    # lookup instead of a scan over every token.  Token order is preserved
    # inside each bucket.
    labels_by_box = {}
    for box, label in zip(actual_boxes, predictions):
        labels_by_box.setdefault(tuple(box), []).append(label)

    for word in words:
        token_labels = list(labels_by_box.get(tuple(word['normalized_box']), []))
        # Strip the two-character prefix of non-'O' labels (presumably the
        # "B-"/"I-" marker) to obtain the bare tag.
        word_labels = [label[2:] if label != 'O' else 'O' for label in token_labels]
        if word_labels:
            # Prefer the first token's tag; fall back to the last token's tag
            # when the first token is 'O'.
            word_tagging = word_labels[0] if word_labels[0] != 'O' else word_labels[-1]
        else:
            word_tagging = 'O'
        word['word_labels'] = token_labels
        word['word_tagging'] = word_tagging

    filtered_words = [
        {
            'id': i,
            'text': word['word_text'],
            'label': word['word_tagging'],
            'box': word['word_box'],
            'words': [{'box': word['word_box'], 'text': word['word_text']}],
        }
        for i, word in enumerate(words)
        if word['word_tagging'] != 'O'
    ]

    merged_taggings = []
    visited = set()
    for i in range(len(filtered_words)):
        if i not in visited:
            merged_taggings.append(dfs(i, [], width, height, visited, filtered_words))

    merged_words = []
    for i, merged_tagging in enumerate(merged_taggings):
        # NOTE(review): groups holding a single word are not turned into
        # entities -- confirm that single-word entities should be left out.
        if len(merged_tagging) <= 1:
            continue
        first, last = merged_tagging[0], merged_tagging[-1]
        merged_words.append({
            'text': " ".join([word['text'] for word in merged_tagging]),
            # NOTE(review): the box is padded from the first and last word in
            # traversal order rather than from the union of all members.
            'box': [first['box'][0] - 5, first['box'][1] - 10,
                    last['box'][2] + 5, last['box'][3] + 10],
            'label': first['label'],
            # Ids continue after the highest id among the filtered words.
            'id': filtered_words[-1]['id'] + i + 1,
            'words': [{'box': word['box'], 'text': word['text']} for word in merged_tagging],
        })

    # Hand convert_data a deep copy so merged_words is returned untouched.
    output = convert_data(copy.deepcopy(merged_words), tokenizer, img_size)
    return output, merged_words
220
+