Update BertForJointParsing.py
BertForJointParsing.py (+4, -0)
@@ -273,6 +273,8 @@ def combine_token_wordpieces(input_ids: List[int], offset_mapping: torch.Tensor,
     ret = []
     special_toks = tokenizer.all_special_tokens
     special_toks.remove(tokenizer.unk_token)
+    special_toks.remove(tokenizer.mask_token)
+
     for token, offsets in zip(tokenizer.convert_ids_to_tokens(input_ids), offset_mapping):
         if token in special_toks: continue
         if token.startswith('##'):
@@ -287,6 +289,8 @@ def ner_parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer
 
     special_toks = tokenizer.all_special_tokens
     special_toks.remove(tokenizer.unk_token)
+    special_toks.remove(tokenizer.mask_token)
+
     for batch_idx in range(len(sentences)):
 
         ret = []
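The change removes the mask token from the skip list in both functions, so [MASK] tokens in the input are kept and aligned to output words rather than silently dropped, the same way [UNK] already was. Below is a minimal sketch of that skip-list behaviour, assuming a standard 'bert-base-uncased' tokenizer and a made-up input sentence rather than this model's own tokenizer or the surrounding functions from the file:

```python
# Minimal sketch (assumes 'bert-base-uncased'; not the model's own tokenizer)
# of the skip-list behaviour both patched functions rely on: special tokens
# are filtered out of the output, except those removed from the skip list.
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

special_toks = tokenizer.all_special_tokens   # e.g. ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]']
special_toks.remove(tokenizer.unk_token)      # keep [UNK] tokens in the output (already the case)
special_toks.remove(tokenizer.mask_token)     # keep [MASK] tokens too (what this commit adds)

input_ids = tokenizer('the [MASK] sat')['input_ids']
kept = [tok for tok in tokenizer.convert_ids_to_tokens(input_ids)
        if tok not in special_toks]
print(kept)  # ['the', '[MASK]', 'sat']: [CLS]/[SEP] still skipped, [MASK] now retained
```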