import collections

import numpy as np

MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"])


def is_start_piece(piece):
    """Check if the current word piece is the starting piece (BERT).

    When a word has been split into WordPieces, the first token does not
    have any marker and any subsequent tokens are prefixed with ##. So
    whenever we see the ## prefix, we append the piece to the previous
    set of word indexes.
    """
    return not piece.startswith("##")


def create_masked_lm_predictions(tokens,
                                 vocab_id_list,
                                 vocab_id_to_token_dict,
                                 masked_lm_prob,
                                 cls_id,
                                 sep_id,
                                 mask_id,
                                 max_predictions_per_seq,
                                 np_rng,
                                 max_ngrams=3,
                                 do_whole_word_mask=True,
                                 favor_longer_ngram=False,
                                 do_permutation=False,
                                 geometric_dist=False,
                                 masking_style="bert",
                                 zh_tokenizer=None):
    """Creates the predictions for the masked LM objective.

    Note: Tokens here are vocab ids and not text tokens.

    Modified from Megatron-LM.

    Args:
        tokens: input sequence of vocab ids
        vocab_id_list: list of all token ids in the vocabulary
        vocab_id_to_token_dict: mapping from token id to token string
        masked_lm_prob: probability of masking a token
        cls_id, sep_id, mask_id: ids of the special tokens
        max_predictions_per_seq: maximum number of masked positions
        np_rng: numpy random state driving the masking decisions
        max_ngrams: maximum word-span (n-gram) length
        do_whole_word_mask: whether to apply whole word masking
        favor_longer_ngram: whether to favor longer n-grams
        do_permutation: whether to additionally permute a set of tokens
        geometric_dist: sample span lengths with np_rng.geometric instead
        masking_style: masking style ("bert" or "t5")
        zh_tokenizer: word segmenter for Chinese whole word masking,
            e.g. jieba.lcut
    """
    cand_indexes = []
    # Note(mingdachen): We create a list for recording if the piece is
    # the starting piece of current token, where 1 means true, so that
    # on-the-fly whole word masking is possible.
    token_boundary = [0] * len(tokens)
    # If no Chinese word segmenter is given, fall back to the ## convention.
    if zh_tokenizer is None:
        for (i, token) in enumerate(tokens):
            if token == cls_id or token == sep_id:
                token_boundary[i] = 1
                continue
            # Whole Word Masking means that we mask all of the wordpieces
            # corresponding to an original word.
            #
            # Note that Whole Word Masking does *not* change the training code
            # at all -- we still predict each WordPiece independently, softmaxed
            # over the entire vocabulary.
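            # For example, "unbreakable" tokenized as ["un", "##break",
            # "##able"] becomes a single candidate [i, i+1, i+2], so all
            # three pieces are masked (or kept) together.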
            if (do_whole_word_mask and len(cand_indexes) >= 1 and
                    not is_start_piece(vocab_id_to_token_dict[token])):
                cand_indexes[-1].append(i)
            else:
                cand_indexes.append([i])
                if is_start_piece(vocab_id_to_token_dict[token]):
                    token_boundary[i] = 1
    else:
        # A Chinese word segmenter was given: segment the raw text first,
        # then derive the word boundaries from the segmentation.
        # Recover the raw tokens with CLS/SEP removed.
        raw_tokens = []
        for t in tokens:
            if t != cls_id and t != sep_id:
                raw_tokens.append(t)
        raw_tokens = [vocab_id_to_token_dict[i] for i in raw_tokens]
        # Segment, then record for every leading character the length of
        # the longest segmented word that starts with it.
        word_list = set(zh_tokenizer(''.join(raw_tokens), HMM=True))
        word_length_dict = {}
        for w in word_list:
            if len(w) < 1:
                continue
            if w[0] not in word_length_dict:
                word_length_dict[w[0]] = len(w)
            elif word_length_dict[w[0]] < len(w):
                word_length_dict[w[0]] = len(w)
        i = 0
        # Scan the tokens and look the words up in the segmentation result.
        while i < len(tokens):
            token_id = tokens[i]
            token = vocab_id_to_token_dict[token_id]
            if len(token) == 0 or token_id == cls_id or token_id == sep_id:
                token_boundary[i] = 1
                i += 1
                continue
            word_max_length = 1
            if token[0] in word_length_dict:
                word_max_length = word_length_dict[token[0]]
            j = 0
            word = ''
            word_end = i + 1
            # Stay compatible with the old ##-style: if the following pieces
            # start with ##, join them onto the current piece as one word.
            old_style = False
            while (word_end < len(tokens) and
                   vocab_id_to_token_dict[tokens[word_end]].startswith('##')):
                old_style = True
                word_end += 1
            if not old_style:
                while j < word_max_length and i + j < len(tokens):
                    cur_token = tokens[i + j]
                    word += vocab_id_to_token_dict[cur_token]
                    j += 1
                    if word in word_list:
                        word_end = i + j
            cand_indexes.append([p for p in range(i, word_end)])
            token_boundary[i] = 1
            i = word_end

    output_tokens = list(tokens)

    masked_lm_positions = []
    masked_lm_labels = []

    if masked_lm_prob == 0:
        return (output_tokens, masked_lm_positions,
                masked_lm_labels, token_boundary)

    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(tokens) * masked_lm_prob))))

    ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
    if not geometric_dist:
        # Note(mingdachen):
        # By default, we set the probabilities to favor shorter ngram sequences.
        pvals = 1. / np.arange(1, max_ngrams + 1)
        pvals /= pvals.sum(keepdims=True)
        if favor_longer_ngram:
            pvals = pvals[::-1]

    # For each word, record the candidate n-grams that start at it:
    # the word itself, the word plus the next one, and so on.
    ngram_indexes = []
    for idx in range(len(cand_indexes)):
        ngram_index = []
        for n in ngrams:
            ngram_index.append(cand_indexes[idx:idx + n])
        ngram_indexes.append(ngram_index)

    np_rng.shuffle(ngram_indexes)

    (masked_lms, masked_spans) = ([], [])
    covered_indexes = set()
    for cand_index_set in ngram_indexes:
        if len(masked_lms) >= num_to_predict:
            break
        if not cand_index_set:
            continue
        # Note(mingdachen):
        # Skip current piece if they are covered in lm masking or previous ngrams.
        for index_set in cand_index_set[0]:
            for index in index_set:
                if index in covered_indexes:
                    continue

        if not geometric_dist:
            n = np_rng.choice(ngrams[:len(cand_index_set)],
                              p=pvals[:len(cand_index_set)] /
                              pvals[:len(cand_index_set)].sum(keepdims=True))
        else:
            # Sampling "n" from the geometric distribution and clipping it to
            # the max_ngrams. Using p=0.2 default from the SpanBERT paper
            # https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1)
            n = min(np_rng.geometric(0.2), max_ngrams)

        index_set = sum(cand_index_set[n - 1], [])
        n -= 1
        # Note(mingdachen):
        # Repeatedly looking for a candidate that does not exceed the
        # maximum number of predictions by trying shorter ngrams.
        while len(masked_lms) + len(index_set) > num_to_predict:
            if n == 0:
                break
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1
        # If adding a whole-word mask would exceed the maximum number of
        # predictions, then just skip this candidate.
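        # (Shorter n-grams were already tried in the loop above, so the
        # candidate is dropped entirely rather than truncated, and the
        # realized masking rate can fall slightly below masked_lm_prob.)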
        if len(masked_lms) + len(index_set) > num_to_predict:
            continue
        is_any_index_covered = False
        for index in index_set:
            if index in covered_indexes:
                is_any_index_covered = True
                break
        if is_any_index_covered:
            continue
        for index in index_set:
            covered_indexes.add(index)
            masked_token = None
            token_id = tokens[index]
            if masking_style == "bert":
                # 80% of the time, replace with [MASK]
                if np_rng.random() < 0.8:
                    masked_token = mask_id
                else:
                    # 10% of the time, keep original
                    if np_rng.random() < 0.5:
                        masked_token = tokens[index]
                    # 10% of the time, replace with random word
                    else:
                        masked_token = vocab_id_list[
                            np_rng.randint(0, len(vocab_id_list))]
            elif masking_style == "t5":
                masked_token = mask_id
            else:
                raise ValueError("invalid value of masking style")

            output_tokens[index] = masked_token
            masked_lms.append(MaskedLmInstance(index=index, label=token_id))

        masked_spans.append(MaskedLmInstance(
            index=index_set,
            label=[tokens[index] for index in index_set]))

    assert len(masked_lms) <= num_to_predict
    np_rng.shuffle(ngram_indexes)

    select_indexes = set()
    if do_permutation:
        for cand_index_set in ngram_indexes:
            if len(select_indexes) >= num_to_predict:
                break
            if not cand_index_set:
                continue
            # Note(mingdachen):
            # Skip current piece if they are covered in lm masking or previous ngrams.
            for index_set in cand_index_set[0]:
                for index in index_set:
                    if index in covered_indexes or index in select_indexes:
                        continue

            n = np_rng.choice(ngrams[:len(cand_index_set)],
                              p=pvals[:len(cand_index_set)] /
                              pvals[:len(cand_index_set)].sum(keepdims=True))
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1

            while len(select_indexes) + len(index_set) > num_to_predict:
                if n == 0:
                    break
                index_set = sum(cand_index_set[n - 1], [])
                n -= 1
            # If adding a whole-word mask would exceed the maximum number of
            # predictions, then just skip this candidate.
            if len(select_indexes) + len(index_set) > num_to_predict:
                continue

            is_any_index_covered = False
            for index in index_set:
                if index in covered_indexes or index in select_indexes:
                    is_any_index_covered = True
                    break
            if is_any_index_covered:
                continue

            for index in index_set:
                select_indexes.add(index)
        assert len(select_indexes) <= num_to_predict

        select_indexes = sorted(select_indexes)
        permute_indexes = list(select_indexes)
        np_rng.shuffle(permute_indexes)
        orig_token = list(output_tokens)

        for src_i, tgt_i in zip(select_indexes, permute_indexes):
            output_tokens[src_i] = orig_token[tgt_i]
            masked_lms.append(MaskedLmInstance(index=src_i,
                                               label=orig_token[src_i]))

    masked_lms = sorted(masked_lms, key=lambda x: x.index)
    # Sort the spans by the index of the first token in each span.
    masked_spans = sorted(masked_spans, key=lambda x: x.index[0])

    for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)
    return (output_tokens, masked_lm_positions, masked_lm_labels,
            token_boundary, masked_spans)
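

# Minimal usage sketch (not part of the original module): the toy vocabulary,
# token ids, and seed below are made up purely for illustration. It masks one
# short WordPiece sequence with BERT-style 80/10/10 masking; "un ##break
# ##able" forms a single whole-word candidate while "the" and "glass" are
# single-piece candidates.
if __name__ == "__main__":
    # Hypothetical toy vocabulary: ids 0-2 are specials, the rest wordpieces.
    vocab = {0: "[CLS]", 1: "[SEP]", 2: "[MASK]",
             3: "un", 4: "##break", 5: "##able", 6: "the", 7: "glass"}
    vocab_id_list = list(vocab.keys())
    np_rng = np.random.RandomState(seed=42)

    # [CLS] un ##break ##able the glass [SEP]
    tokens = [0, 3, 4, 5, 6, 7, 1]
    (output_tokens, positions, labels,
     boundary, spans) = create_masked_lm_predictions(
        tokens=tokens,
        vocab_id_list=vocab_id_list,
        vocab_id_to_token_dict=vocab,
        masked_lm_prob=0.15,
        cls_id=0, sep_id=1, mask_id=2,
        max_predictions_per_seq=2,
        np_rng=np_rng,
        masking_style="bert")
    print("masked tokens :", [vocab[t] for t in output_tokens])
    print("positions     :", positions)
    print("labels        :", [vocab[t] for t in labels])
    print("boundaries    :", boundary)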