Shaltiel committed on
Commit
e68d30b
1 Parent(s): 6235062

Update BertForJointParsing.py

Browse files
Files changed (1) hide show
  1. BertForJointParsing.py +4 -0
BertForJointParsing.py CHANGED
@@ -273,6 +273,8 @@ def combine_token_wordpieces(input_ids: List[int], offset_mapping: torch.Tensor,
273
  ret = []
274
  special_toks = tokenizer.all_special_tokens
275
  special_toks.remove(tokenizer.unk_token)
 
 
276
  for token, offsets in zip(tokenizer.convert_ids_to_tokens(input_ids), offset_mapping):
277
  if token in special_toks: continue
278
  if token.startswith('##'):
@@ -287,6 +289,8 @@ def ner_parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer
287
 
288
  special_toks = tokenizer.all_special_tokens
289
  special_toks.remove(tokenizer.unk_token)
 
 
290
  for batch_idx in range(len(sentences)):
291
 
292
  ret = []
 
273
  ret = []
274
  special_toks = tokenizer.all_special_tokens
275
  special_toks.remove(tokenizer.unk_token)
276
+ special_toks.remove(tokenizer.mask_token)
277
+
278
  for token, offsets in zip(tokenizer.convert_ids_to_tokens(input_ids), offset_mapping):
279
  if token in special_toks: continue
280
  if token.startswith('##'):
 
289
 
290
  special_toks = tokenizer.all_special_tokens
291
  special_toks.remove(tokenizer.unk_token)
292
+ special_toks.remove(tokenizer.mask_token)
293
+
294
  for batch_idx in range(len(sentences)):
295
 
296
  ret = []