Upload BertForJointParsing.py
BertForJointParsing.py  (+19, -18)  CHANGED
@@ -81,6 +81,7 @@ class BertForJointParsing(BertPreTrainedModel):

     def set_output_embeddings(self, new_embeddings):
         if self.lex is not None:
+
             self.cls.predictions.decoder = new_embeddings

     def forward(
@@ -207,7 +208,7 @@ class BertForJointParsing(BertPreTrainedModel):
         inputs = {k:v.to(self.device) for k,v in inputs.items()}
         output = self.forward(**inputs, return_dict=True, compute_syntax_mst=compute_syntax_mst)

-        final_output = [dict(text=sentence, tokens=combine_token_wordpieces(ids, tokenizer)) for sentence, ids in zip(sentences, inputs['input_ids'])]
+        final_output = [dict(text=sentence, tokens=combine_token_wordpieces(ids, offsets, tokenizer)) for sentence, ids, offsets in zip(sentences, inputs['input_ids'], offset_mapping)]
         # Syntax logits: each sentence gets a dict(tree: List[dict(word,dep_head,dep_head_idx,dep_func)], root_idx: int)
         if output.syntax_logits is not None:
             for sent_idx,parsed in enumerate(syntax_parse_logits(inputs, sentences, tokenizer, output.syntax_logits)):
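The offsets threaded into combine_token_wordpieces above come from the fast tokenizer's character-offset mapping (return_offsets_mapping=True), which the predict path presumably pops from the encoded inputs before calling self.forward. A minimal sketch of what that tensor looks like, using a placeholder checkpoint rather than this repo's own tokenizer:

from transformers import BertTokenizerFast

# placeholder checkpoint for illustration; the real pipeline uses this model's own tokenizer
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

inputs = tokenizer(['hello world'], return_offsets_mapping=True, return_tensors='pt')

# offset_mapping holds one (char_start, char_end) pair per wordpiece; special tokens get (0, 0)
offset_mapping = inputs.pop('offset_mapping')
print(offset_mapping)
# roughly [[[0, 0], [0, 5], [6, 11], [0, 0]]]  ->  [CLS], 'hello', 'world', [SEP]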
@@ -231,10 +232,10 @@ class BertForJointParsing(BertPreTrainedModel):

         # NER logits each sentence gets a list(tuple(word, ner))
         if output.ner_logits is not None:
-            for sent_idx,parsed in enumerate(ner_parse_logits(inputs, sentences, tokenizer, output.ner_logits, self.config.id2label, offset_mapping)):
+            for sent_idx,parsed in enumerate(ner_parse_logits(inputs, sentences, tokenizer, output.ner_logits, self.config.id2label)):
                 if per_token_ner:
                     merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'ner')
-                final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(parsed)
+                final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(final_output[sent_idx], parsed)

         if output_style in ['ud', 'iahlt_ud']:
             final_output = convert_output_to_ud(final_output, style='htb' if output_style == 'ud' else 'iahlt')
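For reference, with this change each per-sentence entry in final_output carries character offsets both on the individual tokens and on the aggregated NER entities. A hypothetical example of the resulting shape (sentence, labels and numbers are invented; field names come from the helpers rewritten further down; the per-token 'ner' field only appears when per_token_ner=True, and fields added by the other heads are omitted):

# invented single-sentence example of the JSON-style output after this change
final_output = [
    {
        'text': 'Jane Smith visited Haifa',
        'tokens': [
            {'token': 'Jane',    'offsets': {'start': 0,  'end': 4},  'ner': 'B-PER'},
            {'token': 'Smith',   'offsets': {'start': 5,  'end': 10}, 'ner': 'I-PER'},
            {'token': 'visited', 'offsets': {'start': 11, 'end': 18}, 'ner': 'O'},
            {'token': 'Haifa',   'offsets': {'start': 19, 'end': 24}, 'ner': 'B-GPE'},
        ],
        'ner_entities': [
            {'phrase': 'Jane Smith', 'label': 'PER', 'start': 0,  'end': 10, 'token_start': 0, 'token_end': 1},
            {'phrase': 'Haifa',      'label': 'GPE', 'start': 19, 'end': 24, 'token_start': 3, 'token_end': 3},
        ],
    },
]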
@@ -245,36 +246,39 @@



-def aggregate_ner_tokens(parsed):
+def aggregate_ner_tokens(final_output, parsed):
     entities = []
     prev = None
-    for word, pred, start, end in parsed:
+    for token_idx, (d, (word, pred)) in enumerate(zip(final_output['tokens'], parsed)):
         # O does nothing
         if pred == 'O': prev = None
         # B- || I-entity != prev (different entity or none)
         elif pred.startswith('B-') or pred[2:] != prev:
             prev = pred[2:]
-            entities.append([[word], prev, start, end])
+            entities.append([[word], dict(label=prev, start=d['offsets']['start'], end=d['offsets']['end'], token_start=token_idx, token_end=token_idx)])
         else:
             entities[-1][0].append(word)
-            entities[-1][3] = end
+            entities[-1][1]['end'] = d['offsets']['end']
+            entities[-1][1]['token_end'] = token_idx

-    return [dict(phrase=' '.join(words), label=label, start=start, end=end) for words, label, start, end in entities]
+    return [dict(phrase=' '.join(words), **d) for words, d in entities]

 def merge_token_list(src, update, key):
     for token_src, token_update in zip(src, update):
         token_src[key] = token_update

-def combine_token_wordpieces(input_ids: torch.Tensor, tokenizer: BertTokenizerFast):
+def combine_token_wordpieces(input_ids: torch.Tensor, offset_mapping: torch.Tensor, tokenizer: BertTokenizerFast):
+    offset_mapping = offset_mapping.tolist()
     ret = []
-    for token in tokenizer.convert_ids_to_tokens(input_ids):
+    for token, offsets in zip(tokenizer.convert_ids_to_tokens(input_ids), offset_mapping):
         if token in [tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token]: continue
         if token.startswith('##'):
-            ret[-1] += token[2:]
-        else: ret.append(token)
+            ret[-1]['token'] += token[2:]
+            ret[-1]['offsets']['end'] = offsets[1]
+        else: ret.append(dict(token=token, offsets=dict(start=offsets[0], end=offsets[1])))
     return ret

-def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor, id2label: Dict[int, str], offset_mapping: torch.Tensor):
+def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor, id2label: Dict[int, str]):
     input_ids = inputs['input_ids']

     predictions = torch.argmax(logits, dim=-1)
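As a quick sanity check of the rewritten aggregate_ner_tokens above, here is a self-contained toy run: the function body is copied from the new version of the file, while the token dicts and BIO tags are invented stand-ins for what combine_token_wordpieces and ner_parse_logits would produce.

# aggregate_ner_tokens copied verbatim from the new version of the file
def aggregate_ner_tokens(final_output, parsed):
    entities = []
    prev = None
    for token_idx, (d, (word, pred)) in enumerate(zip(final_output['tokens'], parsed)):
        # O does nothing
        if pred == 'O': prev = None
        # B- || I-entity != prev (different entity or none)
        elif pred.startswith('B-') or pred[2:] != prev:
            prev = pred[2:]
            entities.append([[word], dict(label=prev, start=d['offsets']['start'], end=d['offsets']['end'], token_start=token_idx, token_end=token_idx)])
        else:
            entities[-1][0].append(word)
            entities[-1][1]['end'] = d['offsets']['end']
            entities[-1][1]['token_end'] = token_idx

    return [dict(phrase=' '.join(words), **d) for words, d in entities]

# invented tokens in the shape produced by the new combine_token_wordpieces,
# for the made-up sentence "Anne lives in New York"
final_output = {'tokens': [
    {'token': 'Anne',  'offsets': {'start': 0,  'end': 4}},
    {'token': 'lives', 'offsets': {'start': 5,  'end': 10}},
    {'token': 'in',    'offsets': {'start': 11, 'end': 13}},
    {'token': 'New',   'offsets': {'start': 14, 'end': 17}},
    {'token': 'York',  'offsets': {'start': 18, 'end': 22}},
]}
# invented per-token predictions in the shape returned by the new ner_parse_logits
parsed = [('Anne', 'B-PER'), ('lives', 'O'), ('in', 'O'), ('New', 'B-GPE'), ('York', 'I-GPE')]

print(aggregate_ner_tokens(final_output, parsed))
# [{'phrase': 'Anne', 'label': 'PER', 'start': 0, 'end': 4, 'token_start': 0, 'token_end': 0},
#  {'phrase': 'New York', 'label': 'GPE', 'start': 14, 'end': 22, 'token_start': 3, 'token_end': 4}]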
@@ -289,16 +293,13 @@ def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor, id2label: Dict[int, str], offset_mapping: torch.Tensor):

             token = tokenizer._convert_id_to_token(token_id)

-            # get the offsets for this token
-            start_pos, end_pos = offset_mapping[batch_idx, tok_idx]
             # wordpieces should just be appended to the previous word
             # we modify the last token in ret
             # by discarding the original end position and replacing it with the new token's end position
             if token.startswith('##'):
-                ret[-1] = (ret[-1][0] + token[2:], ret[-1][1], ret[-1][2], end_pos.item())
                 continue
-
-            ret.append((token, id2label[predictions[batch_idx, tok_idx].item()], start_pos.item(), end_pos.item()))
+            # for each token, we append a tuple containing: token, label, start position, end position
+            ret.append((token, id2label[predictions[batch_idx, tok_idx].item()]))

     return batch_ret

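The slimmed-down ner_parse_logits keeps only the argmax over the label dimension plus the id2label lookup (wordpiece continuations starting with '##' are still skipped). A toy illustration of that core step, with an invented three-label tag set standing in for self.config.id2label:

import torch

# hypothetical label inventory; the real mapping comes from self.config.id2label
id2label = {0: 'O', 1: 'B-PER', 2: 'I-PER'}

# fake NER logits for one sentence of 4 wordpieces: shape (batch=1, seq_len=4, num_labels=3)
logits = torch.tensor([[[0.1, 2.0, 0.3],
                        [0.2, 0.1, 1.9],
                        [2.5, 0.0, 0.1],
                        [1.8, 0.2, 0.3]]])

predictions = torch.argmax(logits, dim=-1)            # shape (1, 4)
labels = [id2label[p.item()] for p in predictions[0]]
print(labels)                                         # ['B-PER', 'I-PER', 'O', 'O']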