# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors, and NVIDIA.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Most of the code here has been copied from:
#   https://github.com/google-research/albert/blob/master/create_pretraining_data.py
# with some modifications.

import math
import time
import collections
import numpy as np
import re

from fengshen.data.megatron_dataloader.utils import (
    print_rank_0
)
from fengshen.data.megatron_dataloader.blendable_dataset import BlendableDataset
from fengshen.data.megatron_dataloader.indexed_dataset import make_dataset as make_indexed_dataset

DSET_TYPE_BERT = 'standard_bert'
DSET_TYPE_ICT = 'ict'
DSET_TYPE_T5 = 't5'
DSET_TYPE_BERT_CN_WWM = 'bert_cn_wwm'
DSET_TYPE_BART = 'bart'
DSET_TYPE_COCOLM = 'coco_lm'

DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT,
              DSET_TYPE_T5, DSET_TYPE_BERT_CN_WWM,
              DSET_TYPE_BART, DSET_TYPE_COCOLM]
def get_datasets_weights_and_num_samples(data_prefix,
                                         train_valid_test_num_samples):
    # The data prefix should be in the format of:
    #   weight-1, data-prefix-1, weight-2, data-prefix-2, ...
    assert len(data_prefix) % 2 == 0
    num_datasets = len(data_prefix) // 2
    weights = [0] * num_datasets
    prefixes = [0] * num_datasets
    for i in range(num_datasets):
        weights[i] = float(data_prefix[2 * i])
        prefixes[i] = (data_prefix[2 * i + 1]).strip()

    # Normalize weights.
    weight_sum = 0.0
    for weight in weights:
        weight_sum += weight
    assert weight_sum > 0.0
    weights = [weight / weight_sum for weight in weights]

    # Add 0.5% (the 1.005 factor) so that in case the blending dataset does
    # not uniformly distribute the number of samples, we still have
    # samples left to feed to the network.
    datasets_train_valid_test_num_samples = []
    for weight in weights:
        datasets_train_valid_test_num_samples.append(
            [int(math.ceil(val * weight * 1.005))
             for val in train_valid_test_num_samples])

    return prefixes, weights, datasets_train_valid_test_num_samples
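
# Example (illustrative; the prefixes below are hypothetical):
#   >>> get_datasets_weights_and_num_samples(
#   ...     ['0.3', 'corpusA_text_sentence', '0.7', 'corpusB_text_sentence'],
#   ...     [1000, 100, 10])
#   (['corpusA_text_sentence', 'corpusB_text_sentence'],
#    [0.3, 0.7],
#    [[302, 31, 4], [704, 71, 8]])
# Each per-dataset sample count is ceil(weight * requested * 1.005).
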
def compile_helper():
    """Compile helper function at runtime. Make sure this
    is invoked on a single process."""
    import os
    import subprocess
    path = os.path.abspath(os.path.dirname(__file__))
    ret = subprocess.run(['make', '-C', path])
    if ret.returncode != 0:
        print("Making C++ dataset helpers module failed, exiting.")
        import sys
        sys.exit(1)
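
# Usage note (illustrative): this runs `make` in this package's directory to
# build the C++ dataset helpers. Call it from a single process only, e.g.
# guarded by a rank check such as `torch.distributed.get_rank() == 0`
# (assumption: torch.distributed is initialized elsewhere in the training
# setup).
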
def get_a_and_b_segments(sample, np_rng):
    """Divide sample into a and b segments."""

    # Number of sentences in the sample.
    n_sentences = len(sample)
    # Make sure we always have two sentences.
    assert n_sentences > 1, 'make sure each sample has at least two sentences.'

    # First part:
    # `a_end` is how many sentences go into the `A`.
    a_end = 1
    if n_sentences >= 3:
        # Note that randint in numpy is exclusive of the upper bound.
        a_end = np_rng.randint(1, n_sentences)
    tokens_a = []
    for j in range(a_end):
        tokens_a.extend(sample[j])

    # Second part:
    tokens_b = []
    for j in range(a_end, n_sentences):
        tokens_b.extend(sample[j])

    # Random next:
    is_next_random = False
    if np_rng.random() < 0.5:
        is_next_random = True
        tokens_a, tokens_b = tokens_b, tokens_a

    return tokens_a, tokens_b, is_next_random
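
# Usage sketch (illustrative): `sample` is a list of sentences, each a list of
# token ids.
#   >>> np_rng = np.random.RandomState(1234)
#   >>> sample = [[10, 11], [12, 13, 14], [15, 16]]
#   >>> tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng)
# `tokens_a` gets the first `a_end` sentences flattened and `tokens_b` the
# rest; with probability 0.5 the two segments are swapped and
# `is_next_random` is True (the NSP negative case).
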
def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
    """Truncates a pair of sequences to a maximum sequence length."""
    # print(len_a, len_b, max_num_tokens)
    assert len_a > 0
    if len_a + len_b <= max_num_tokens:
        return False
    while len_a + len_b > max_num_tokens:
        if len_a > len_b:
            len_a -= 1
            tokens = tokens_a
        else:
            len_b -= 1
            tokens = tokens_b
        if np_rng.random() < 0.5:
            del tokens[0]
        else:
            tokens.pop()
    return True
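
# Note: truncation happens in place on `tokens_a`/`tokens_b`, removing tokens
# randomly from the front or the back of the longer segment. Illustrative call:
#   >>> a, b = [1, 2, 3, 4], [5, 6, 7]
#   >>> truncate_segments(a, b, len(a), len(b), max_num_tokens=5,
#   ...                   np_rng=np.random.RandomState(0))
#   True
# Afterwards len(a) + len(b) == 5; the return value reports whether anything
# was removed.
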
def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
    """Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""
    tokens = []
    tokentypes = []
    # [CLS].
    tokens.append(cls_id)
    tokentypes.append(0)
    # Segment A.
    for token in tokens_a:
        tokens.append(token)
        tokentypes.append(0)
    # [SEP].
    tokens.append(sep_id)
    tokentypes.append(0)
    # Segment B.
    for token in tokens_b:
        tokens.append(token)
        tokentypes.append(1)
    if tokens_b:
        # [SEP].
        tokens.append(sep_id)
        tokentypes.append(1)
    return tokens, tokentypes
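
# Example (illustrative, with cls_id=101 and sep_id=102):
#   >>> create_tokens_and_tokentypes([7, 8], [9], cls_id=101, sep_id=102)
#   ([101, 7, 8, 102, 9, 102], [0, 0, 0, 0, 1, 1])
# i.e. the standard "[CLS] A [SEP] B [SEP]" layout with token type 0 for the
# [CLS]/A/[SEP] part and 1 for the B/[SEP] part.
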
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                          ["index", "label"])


def is_start_piece(piece):
    """Check if the current word piece is the starting piece (BERT)."""
    # When a word has been split into WordPieces, the first token does not
    # have any marker and any subsequent tokens are prefixed with ##. So
    # whenever we see the ## token, we append it to the previous set of word
    # indexes.
    return not piece.startswith("##")
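
# Example: is_start_piece("un") -> True, is_start_piece("##able") -> False.
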
def create_masked_lm_predictions(tokens,
                                 vocab_id_list, vocab_id_to_token_dict,
                                 masked_lm_prob,
                                 cls_id, sep_id, mask_id,
                                 max_predictions_per_seq,
                                 np_rng,
                                 tokenizer,
                                 max_ngrams=3,
                                 do_whole_word_mask=True,
                                 favor_longer_ngram=False,
                                 do_permutation=False,
                                 geometric_dist=False,
                                 masking_style="bert",
                                 zh_tokenizer=None):
    """Creates the predictions for the masked LM objective.
    Note: Tokens here are vocab ids and not text tokens."""

    cand_indexes = []
    # Note(mingdachen): We create a list for recording if the piece is
    # the starting piece of current token, where 1 means true, so that
    # on-the-fly whole word masking is possible.
    token_boundary = [0] * len(tokens)

    # If no Chinese word segmenter is given, fall back to the ## convention.
    if zh_tokenizer is None:
        for (i, token) in enumerate(tokens):
            if token == cls_id or token == sep_id:
                token_boundary[i] = 1
                continue
            # Whole Word Masking means that we mask all of the wordpieces
            # corresponding to an original word.
            #
            # Note that Whole Word Masking does *not* change the training code
            # at all -- we still predict each WordPiece independently, softmaxed
            # over the entire vocabulary.
            if (do_whole_word_mask and len(cand_indexes) >= 1 and
                    not is_start_piece(vocab_id_to_token_dict[token])):
                cand_indexes[-1].append(i)
            else:
                cand_indexes.append([i])
                if is_start_piece(vocab_id_to_token_dict[token]):
                    token_boundary[i] = 1
    else:
        # If a Chinese word segmenter is given, segment the text first and
        # derive the word boundaries from the segmentation result.
        # Recover the raw text with [CLS] and [SEP] removed.
        raw_tokens = []
        for t in tokens:
            if t != cls_id and t != sep_id:
                raw_tokens.append(t)
        raw_tokens = [vocab_id_to_token_dict[i] for i in raw_tokens]
        # Segment the text, then record, for each starting character, the
        # length of the longest word that begins with it.
        word_list = set(zh_tokenizer(''.join(raw_tokens), HMM=True))
        word_length_dict = {}
        for w in word_list:
            if len(w) < 1:
                continue
            if w[0] not in word_length_dict:
                word_length_dict[w[0]] = len(w)
            elif word_length_dict[w[0]] < len(w):
                word_length_dict[w[0]] = len(w)
        i = 0
        # Look the tokens up against the segmented word list.
        while i < len(tokens):
            token_id = tokens[i]
            token = vocab_id_to_token_dict[token_id]
            if len(token) == 0 or token_id == cls_id or token_id == sep_id:
                token_boundary[i] = 1
                i += 1
                continue
            word_max_length = 1
            if token[0] in word_length_dict:
                word_max_length = word_length_dict[token[0]]
            j = 0
            word = ''
            word_end = i + 1
            # Stay compatible with the old ## convention: if the following
            # pieces start with ##, merge them directly into the current word.
            old_style = False
            while word_end < len(tokens) and vocab_id_to_token_dict[tokens[word_end]].startswith('##'):
                old_style = True
                word_end += 1
            if not old_style:
                while j < word_max_length and i + j < len(tokens):
                    cur_token = tokens[i + j]
                    word += vocab_id_to_token_dict[cur_token]
                    j += 1
                    if word in word_list:
                        word_end = i + j
            cand_indexes.append([p for p in range(i, word_end)])
            token_boundary[i] = 1
            i = word_end
    output_tokens = list(tokens)

    # added by ganruyi
    if masking_style == 'bert-cn-wwm':
        # For Chinese whole word masking, strip the "##" prefix that was
        # added to non-leading pieces earlier, so the output tokens are the
        # original characters.
        new_token_ids = []
        for token_id in output_tokens:
            token = tokenizer.convert_ids_to_tokens([token_id])[0]
            if len(re.findall('##[\u4E00-\u9FA5]', token)) > 0:
                token = token[2:]
            new_token_id = tokenizer.convert_tokens_to_ids([token])[0]
            new_token_ids.append(new_token_id)
        output_tokens = new_token_ids

    masked_lm_positions = []
    masked_lm_labels = []

    if masked_lm_prob == 0:
        return (output_tokens, masked_lm_positions,
                masked_lm_labels, token_boundary)

    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(tokens) * masked_lm_prob))))

    ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
    if not geometric_dist:
        # Note(mingdachen):
        # By default, we set the probabilities to favor shorter ngram sequences.
        pvals = 1. / np.arange(1, max_ngrams + 1)
        pvals /= pvals.sum(keepdims=True)
        if favor_longer_ngram:
            pvals = pvals[::-1]

    # For each candidate word, record its ngram index groups (the word itself
    # plus the following words, up to max_ngrams).
    ngram_indexes = []
    for idx in range(len(cand_indexes)):
        ngram_index = []
        for n in ngrams:
            ngram_index.append(cand_indexes[idx:idx + n])
        ngram_indexes.append(ngram_index)

    np_rng.shuffle(ngram_indexes)
    (masked_lms, masked_spans) = ([], [])
    covered_indexes = set()
    for cand_index_set in ngram_indexes:
        if len(masked_lms) >= num_to_predict:
            break
        if not cand_index_set:
            continue
        # Note(mingdachen):
        # Skip the current piece if it is covered by LM masking or by a
        # previous ngram.
        for index_set in cand_index_set[0]:
            for index in index_set:
                if index in covered_indexes:
                    continue

        if not geometric_dist:
            n = np_rng.choice(ngrams[:len(cand_index_set)],
                              p=pvals[:len(cand_index_set)] /
                              pvals[:len(cand_index_set)].sum(keepdims=True))
        else:
            # Sampling "n" from the geometric distribution and clipping it to
            # the max_ngrams. Using p=0.2 default from the SpanBERT paper
            # https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1)
            n = min(np_rng.geometric(0.2), max_ngrams)

        index_set = sum(cand_index_set[n - 1], [])
        n -= 1
        # Note(mingdachen):
        # Repeatedly look for a candidate that does not exceed the
        # maximum number of predictions by trying shorter ngrams.
        while len(masked_lms) + len(index_set) > num_to_predict:
            if n == 0:
                break
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1

        # If adding a whole-word mask would exceed the maximum number of
        # predictions, then just skip this candidate.
        if len(masked_lms) + len(index_set) > num_to_predict:
            continue
        is_any_index_covered = False
        for index in index_set:
            if index in covered_indexes:
                is_any_index_covered = True
                break
        if is_any_index_covered:
            continue
        for index in index_set:
            covered_indexes.add(index)

            masked_token = None
            if masking_style == "bert":
                # 80% of the time, replace with [MASK]
                if np_rng.random() < 0.8:
                    masked_token = mask_id
                else:
                    # 10% of the time, keep original
                    if np_rng.random() < 0.5:
                        masked_token = tokens[index]
                    # 10% of the time, replace with random word
                    else:
                        masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]
            elif masking_style == 'bert-cn-wwm':
                # 80% of the time, replace with [MASK]
                if np_rng.random() < 0.8:
                    masked_token = mask_id
                else:
                    # 10% of the time, keep original
                    if np_rng.random() < 0.5:
                        # For Chinese whole word masking, strip the ## prefix
                        # from the kept token.
                        token_id = tokens[index]
                        token = tokenizer.convert_ids_to_tokens([token_id])[0]
                        if len(re.findall('##[\u4E00-\u9FA5]', token)) > 0:
                            token = token[2:]
                        new_token_id = tokenizer.convert_tokens_to_ids([token])[0]
                        masked_token = new_token_id
                    # 10% of the time, replace with random word
                    else:
                        masked_token = vocab_id_list[np_rng.randint(
                            0, len(vocab_id_list))]
            elif masking_style == "t5":
                masked_token = mask_id
            else:
                raise ValueError("invalid value of masking style")

            output_tokens[index] = masked_token
            masked_lms.append(MaskedLmInstance(
                index=index, label=tokens[index]))

        masked_spans.append(MaskedLmInstance(
            index=index_set,
            label=[tokens[index] for index in index_set]))
    assert len(masked_lms) <= num_to_predict

    np_rng.shuffle(ngram_indexes)

    select_indexes = set()
    if do_permutation:
        for cand_index_set in ngram_indexes:
            if len(select_indexes) >= num_to_predict:
                break
            if not cand_index_set:
                continue
            # Note(mingdachen):
            # Skip the current piece if it is covered by LM masking or by a
            # previous ngram.
            for index_set in cand_index_set[0]:
                for index in index_set:
                    if index in covered_indexes or index in select_indexes:
                        continue

            n = np.random.choice(ngrams[:len(cand_index_set)],
                                 p=pvals[:len(cand_index_set)] /
                                 pvals[:len(cand_index_set)].sum(keepdims=True))
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1

            while len(select_indexes) + len(index_set) > num_to_predict:
                if n == 0:
                    break
                index_set = sum(cand_index_set[n - 1], [])
                n -= 1

            # If adding a whole-word mask would exceed the maximum number of
            # predictions, then just skip this candidate.
            if len(select_indexes) + len(index_set) > num_to_predict:
                continue
            is_any_index_covered = False
            for index in index_set:
                if index in covered_indexes or index in select_indexes:
                    is_any_index_covered = True
                    break
            if is_any_index_covered:
                continue
            for index in index_set:
                select_indexes.add(index)
        assert len(select_indexes) <= num_to_predict

        select_indexes = sorted(select_indexes)
        permute_indexes = list(select_indexes)
        np_rng.shuffle(permute_indexes)
        orig_token = list(output_tokens)

        for src_i, tgt_i in zip(select_indexes, permute_indexes):
            output_tokens[src_i] = orig_token[tgt_i]
            masked_lms.append(MaskedLmInstance(
                index=src_i, label=orig_token[src_i]))

    masked_lms = sorted(masked_lms, key=lambda x: x.index)
    # Sort the spans by the index of the first span.
    masked_spans = sorted(masked_spans, key=lambda x: x.index[0])

    for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)
    return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary, masked_spans)
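
# Usage sketch (illustrative only; the tokenizer, ids and argument values are
# assumptions, not fixed by this module):
#   >>> vocab = tokenizer.vocab  # token -> id
#   >>> inv_vocab = {v: k for k, v in vocab.items()}
#   >>> np_rng = np.random.RandomState(1234)
#   >>> tokens = tokenizer.convert_tokens_to_ids(
#   ...     ['[CLS]', 'hello', 'world', '[SEP]'])
#   >>> (out_tokens, mask_pos, mask_labels, boundary, spans) = \
#   ...     create_masked_lm_predictions(
#   ...         tokens, list(vocab.values()), inv_vocab, masked_lm_prob=0.15,
#   ...         cls_id=vocab['[CLS]'], sep_id=vocab['[SEP]'],
#   ...         mask_id=vocab['[MASK]'], max_predictions_per_seq=20,
#   ...         np_rng=np_rng, tokenizer=tokenizer, masking_style='bert')
# Masked positions come back in `mask_pos` with their original ids in
# `mask_labels`, and the 80/10/10 replacement is already applied in
# `out_tokens`.
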
def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                             masked_labels, pad_id, max_seq_length):
    """Pad sequences and convert them to numpy."""

    # Some checks.
    num_tokens = len(tokens)
    padding_length = max_seq_length - num_tokens
    assert padding_length >= 0
    assert len(tokentypes) == num_tokens
    assert len(masked_positions) == len(masked_labels)

    # Tokens and token types.
    filler = [pad_id] * padding_length
    tokens_np = np.array(tokens + filler, dtype=np.int64)
    tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)

    # Padding mask.
    padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
                               dtype=np.int64)

    # Labels and loss mask.
    labels = [-1] * max_seq_length
    loss_mask = [0] * max_seq_length
    for i in range(len(masked_positions)):
        assert masked_positions[i] < num_tokens
        labels[masked_positions[i]] = masked_labels[i]
        loss_mask[masked_positions[i]] = 1

    labels_np = np.array(labels, dtype=np.int64)
    loss_mask_np = np.array(loss_mask, dtype=np.int64)
    return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
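
# Example (illustrative): with tokens=[101, 7, 8, 102], tokentypes=[0, 0, 0, 0],
# masked_positions=[2], masked_labels=[9], pad_id=0 and max_seq_length=6, the
# returned arrays are
#   tokens_np       == [101, 7, 8, 102, 0, 0]
#   padding_mask_np == [1, 1, 1, 1, 0, 0]
#   labels_np       == [-1, -1, 9, -1, -1, -1]
#   loss_mask_np    == [0, 0, 1, 0, 0, 0]
# (tokentypes_np is padded with pad_id in the same way as tokens_np).
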
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                    train_valid_test_num_samples,
                                    max_seq_length,
                                    masked_lm_prob, short_seq_prob, seed,
                                    tokenizer,
                                    skip_warmup, binary_head=False,
                                    max_seq_length_dec=None,
                                    dataset_type='standard_bert',
                                    zh_tokenizer=None,
                                    span=None):

    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(data_prefix[0],
                                                data_impl, splits_string,
                                                train_valid_test_num_samples,
                                                max_seq_length, masked_lm_prob,
                                                short_seq_prob, seed,
                                                skip_warmup,
                                                binary_head,
                                                max_seq_length_dec,
                                                tokenizer,
                                                dataset_type=dataset_type,
                                                zh_tokenizer=zh_tokenizer,
                                                span=span)

    # Blending dataset.
    # Parse the values.
    output = get_datasets_weights_and_num_samples(data_prefix,
                                                  train_valid_test_num_samples)
    prefixes, weights, datasets_train_valid_test_num_samples = output

    # Build individual datasets.
    train_datasets = []
    valid_datasets = []
    test_datasets = []
    for i in range(len(prefixes)):
        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
            prefixes[i], data_impl, splits_string,
            datasets_train_valid_test_num_samples[i],
            max_seq_length, masked_lm_prob, short_seq_prob,
            seed, skip_warmup, binary_head, max_seq_length_dec,
            tokenizer, dataset_type=dataset_type, zh_tokenizer=zh_tokenizer)
        if train_ds:
            train_datasets.append(train_ds)
        if valid_ds:
            valid_datasets.append(valid_ds)
        if test_ds:
            test_datasets.append(test_ds)

    # Blend.
    blending_train_dataset = None
    if train_datasets:
        blending_train_dataset = BlendableDataset(train_datasets, weights)
    blending_valid_dataset = None
    if valid_datasets:
        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
    blending_test_dataset = None
    if test_datasets:
        blending_test_dataset = BlendableDataset(test_datasets, weights)

    return (blending_train_dataset, blending_valid_dataset,
            blending_test_dataset)
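
# Usage sketch (illustrative; the prefixes, tokenizer and sample counts here
# are assumptions):
#   >>> train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
#   ...     data_prefix=['0.3', '/data/corpusA_text_sentence',
#   ...                  '0.7', '/data/corpusB_text_sentence'],
#   ...     data_impl='mmap', splits_string='949,50,1',
#   ...     train_valid_test_num_samples=[1000000, 10000, 1000],
#   ...     max_seq_length=512, masked_lm_prob=0.15, short_seq_prob=0.1,
#   ...     seed=1234, tokenizer=tokenizer, skip_warmup=True,
#   ...     binary_head=True, dataset_type='standard_bert')
# With more than one prefix, the per-corpus datasets are wrapped in a
# BlendableDataset using the normalized weights.
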
def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     max_seq_length,
                                     masked_lm_prob, short_seq_prob, seed,
                                     skip_warmup, binary_head,
                                     max_seq_length_dec,
                                     tokenizer,
                                     dataset_type='standard_bert',
                                     zh_tokenizer=None,
                                     span=None):

    if dataset_type not in DSET_TYPES:
        raise ValueError("Invalid dataset_type: {}".format(dataset_type))

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix,
                                           data_impl,
                                           skip_warmup)

    # Get start and end indices of train/valid/test into doc-idx
    # Note that doc-idx is designed to be num-docs + 1 so we can
    # easily iterate over it.
    total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

    # Print stats about the splits.
    print_rank_0(' > dataset split:')

    def print_split_stats(name, index):
        print_rank_0(' {}:'.format(name))
        print_rank_0(' document indices in [{}, {}) total of {} '
                     'documents'.format(splits[index], splits[index + 1],
                                        splits[index + 1] - splits[index]))
        start_index = indexed_dataset.doc_idx[splits[index]]
        end_index = indexed_dataset.doc_idx[splits[index + 1]]
        print_rank_0(' sentence indices in [{}, {}) total of {} '
                     'sentences'.format(start_index, end_index,
                                        end_index - start_index))
    print_split_stats('train', 0)
    print_split_stats('validation', 1)
    print_split_stats('test', 2)
    def build_dataset(index, name):
        from fengshen.data.megatron_dataloader.bert_dataset import BertDataset
        from fengshen.data.megatron_dataloader.bart_dataset import BartDataset
        from fengshen.data.megatron_dataloader.cocolm_dataset import COCOLMDataset
        dataset = None
        if splits[index + 1] > splits[index]:
            # Get the pointer to the original doc-idx so we can set it later.
            doc_idx_ptr = indexed_dataset.get_doc_idx()
            # Slice the doc-idx
            start_index = splits[index]
            # Add +1 so we can index into the dataset to get the upper bound.
            end_index = splits[index + 1] + 1
            # New doc_idx view.
            indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
            # Build the dataset accordingly.
            kwargs = dict(
                name=name,
                data_prefix=data_prefix,
                num_epochs=None,
                max_num_samples=train_valid_test_num_samples[index],
                max_seq_length=max_seq_length,
                seed=seed,
            )

            if dataset_type == DSET_TYPE_BERT or dataset_type == DSET_TYPE_BERT_CN_WWM:
                dataset = BertDataset(
                    indexed_dataset=indexed_dataset,
                    masked_lm_prob=masked_lm_prob,
                    short_seq_prob=short_seq_prob,
                    binary_head=binary_head,
                    # Extra arguments to distinguish bert from bert-cn-wwm.
                    tokenizer=tokenizer,
                    masking_style='bert' if dataset_type == DSET_TYPE_BERT else 'bert-cn-wwm',
                    **kwargs
                )
            elif dataset_type == DSET_TYPE_BART:
                dataset = BartDataset(
                    indexed_dataset=indexed_dataset,
                    masked_lm_prob=masked_lm_prob,
                    short_seq_prob=short_seq_prob,
                    tokenizer=tokenizer,
                    zh_tokenizer=zh_tokenizer,
                    **kwargs
                )
            elif dataset_type == DSET_TYPE_COCOLM:
                dataset = COCOLMDataset(
                    indexed_dataset=indexed_dataset,
                    masked_lm_prob=masked_lm_prob,
                    short_seq_prob=short_seq_prob,
                    tokenizer=tokenizer,
                    masking_style='bert',
                    span=span,
                    **kwargs
                )
            else:
                raise NotImplementedError(
                    "Dataset type not fully implemented.")

            # Set the original pointer so dataset remains the main dataset.
            indexed_dataset.set_doc_idx(doc_idx_ptr)
            # Checks.
            assert indexed_dataset.doc_idx[0] == 0
            assert indexed_dataset.doc_idx.shape[0] == \
                (total_num_of_documents + 1)
        return dataset

    train_dataset = build_dataset(0, 'train')
    valid_dataset = build_dataset(1, 'valid')
    test_dataset = build_dataset(2, 'test')

    return (train_dataset, valid_dataset, test_dataset)
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):

    print_rank_0(' > building dataset index ...')

    start_time = time.time()
    indexed_dataset = make_indexed_dataset(data_prefix,
                                           data_impl,
                                           skip_warmup)
    assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
    print_rank_0(' > finished creating indexed dataset in {:4f} '
                 'seconds'.format(time.time() - start_time))

    print_rank_0(' > indexed dataset stats:')
    print_rank_0(' number of documents: {}'.format(
        indexed_dataset.doc_idx.shape[0] - 1))
    print_rank_0(' number of sentences: {}'.format(
        indexed_dataset.sizes.shape[0]))

    return indexed_dataset
def get_train_valid_test_split_(splits_string, size):
    """ Get dataset splits from comma or '/' separated string list."""

    splits = []
    if splits_string.find(',') != -1:
        splits = [float(s) for s in splits_string.split(',')]
    elif splits_string.find('/') != -1:
        splits = [float(s) for s in splits_string.split('/')]
    else:
        splits = [float(splits_string)]
    while len(splits) < 3:
        splits.append(0.)
    splits = splits[:3]
    splits_sum = sum(splits)
    assert splits_sum > 0.0
    splits = [split / splits_sum for split in splits]
    splits_index = [0]
    for index, split in enumerate(splits):
        splits_index.append(splits_index[index] +
                            int(round(split * float(size))))
    diff = splits_index[-1] - size
    for index in range(1, len(splits_index)):
        splits_index[index] -= diff
    assert len(splits_index) == 4
    assert splits_index[-1] == size
    return splits_index
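
# Example (illustrative): get_train_valid_test_split_('949,50,1', 1000)
# normalizes the weights to [0.949, 0.05, 0.001] and returns the document
# boundaries [0, 949, 999, 1000].
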
def get_samples_mapping(indexed_dataset,
                        data_prefix,
                        num_epochs,
                        max_num_samples,
                        max_seq_length,
                        short_seq_prob,
                        seed,
                        name,
                        binary_head):
    """Get a list that maps a sample index to a starting
    sentence index, end sentence index, and length"""

    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        num_epochs = np.iinfo(np.int32).max - 1
    if not max_num_samples:
        max_num_samples = np.iinfo(np.int64).max - 1

    # Filename of the index mapping
    indexmap_filename = data_prefix
    indexmap_filename += '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
    indexmap_filename += '_{}s'.format(seed)
    indexmap_filename += '.npy'

    # This should be a barrier but nccl barrier assumes
    # device_index=rank which is not the case for model
    # parallel case
    # ganruyi comment
    # counts = torch.cuda.LongTensor([1])
    # torch.distributed.all_reduce(
    #     counts, group=mpu.get_data_parallel_group())
    # torch.distributed.all_reduce(
    #     counts, group=mpu.get_pipeline_model_parallel_group())
    # assert counts[0].item() == (
    #     torch.distributed.get_world_size() //
    #     torch.distributed.get_world_size(
    #         group=mpu.get_tensor_model_parallel_group()))

    # Load indexed dataset.
    print_rank_0(' > loading indexed mapping from {}'.format(
        indexmap_filename))
    start_time = time.time()
    samples_mapping = np.load(
        indexmap_filename, allow_pickle=True, mmap_mode='r')
    print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0(' total number of samples: {}'.format(
        samples_mapping.shape[0]))

    return samples_mapping
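
# Example (illustrative): for data_prefix='/data/corpusA_text_sentence',
# name='train', max_num_samples=1000000, max_seq_length=512,
# short_seq_prob=0.1, seed=1234 and num_epochs left unset, the file loaded is
#   /data/corpusA_text_sentence_train_indexmap_1000000mns_512msl_0.10ssp_1234s.npy
# which is expected to exist already; unlike upstream Megatron-LM, this
# version does not build the sample mapping on the fly.
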