"""BART Style dataset. Modified from fairseq.""" | |
import numpy as np | |
import torch | |
import math | |
import re | |
from fengshen.data.megatron_dataloader.dataset_utils import ( | |
get_samples_mapping | |
) | |
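

# This module builds BART-style denoising training samples for Chinese text:
# documents from an indexed dataset are split into sentences at sentence-final
# punctuation, the sentence order is permuted, and whole words are masked
# (text infilling), loosely following fairseq's denoising dataset. Each item
# comes back as fixed-length `input_ids`, `labels` and `attention_mask`
# tensors padded to `max_seq_length`.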
class BartDataset(torch.utils.data.Dataset):
    def __init__(self, name, indexed_dataset, data_prefix,
                 num_epochs, max_num_samples, masked_lm_prob,
                 max_seq_length, short_seq_prob, seed, tokenizer, zh_tokenizer):

        # Params to store.
        self.name = name
        self.seed = seed
        self.masked_lm_prob = masked_lm_prob
        self.max_seq_length = max_seq_length

        # Dataset.
        self.indexed_dataset = indexed_dataset

        # Build the samples mapping.
        self.samples_mapping = get_samples_mapping(self.indexed_dataset,
                                                   data_prefix,
                                                   num_epochs,
                                                   max_num_samples,
                                                   self.max_seq_length - 3,  # account for added tokens
                                                   short_seq_prob,
                                                   self.seed,
                                                   self.name,
                                                   False)

        # Vocab stuff.
        self.vocab_size = tokenizer.vocab_size
        inv_vocab = {v: k for k, v in tokenizer.vocab.items()}
        self.vocab_id_list = list(inv_vocab.keys())
        self.vocab_id_to_token_dict = inv_vocab
        self.cls_id = tokenizer.cls_token_id
        self.sep_id = tokenizer.sep_token_id
        self.mask_id = tokenizer.mask_token_id
        self.pad_id = tokenizer.pad_token_id
        self.tokenizer = tokenizer

        # Sentence-final punctuation used to split documents into sentences.
        seg_tokens = ['。', ';', ';', '!', '!', '?', '?']
        seg_token_ids = []
        for t in seg_tokens:
            if t in tokenizer.vocab:
                seg_token_ids.append(tokenizer.vocab[t])
            else:
                print('seg_token "{}" not in vocab'.format(t))
        self.seg_token_ids = set(seg_token_ids)

        self.zh_tokenizer = zh_tokenizer

        # Denoising ratios
        self.permute_sentence_ratio = 1.0
        self.mask_ratio = masked_lm_prob  # 0.15
        self.random_ratio = 0.1
        self.insert_ratio = 0.0
        self.rotate_ratio = 0.0
        self.mask_whole_word = 1
        self.item_transform_func = None

        # Optional Poisson(lambda=3) span-length distribution for span masking;
        # disabled here by the `if False` guard, so each mask covers a single
        # word rather than a sampled span of words.
        self.mask_span_distribution = None
        if False:
            _lambda = 3  # Poisson lambda
            lambda_to_the_k = 1
            e_to_the_minus_lambda = math.exp(-_lambda)
            k_factorial = 1
            ps = []
            for k in range(0, 128):
                ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
                lambda_to_the_k *= _lambda
                k_factorial *= k + 1
                if ps[-1] < 0.0000001:
                    break
            ps = torch.FloatTensor(ps)
            self.mask_span_distribution = torch.distributions.Categorical(ps)

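    # Each row of `samples_mapping` is a (start_idx, end_idx, seq_length)
    # triple indexing a contiguous run of sentences in `indexed_dataset`.
    # __getitem__ materializes that run and turns it into one denoised training
    # sample, seeding a per-index numpy RNG so samples are reproducible.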
    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):
        start_idx, end_idx, seq_length = self.samples_mapping[idx]
        sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]
        # Note that this rng state should be numpy and not python since
        # python randint is inclusive whereas the numpy one is exclusive.
        # We % 2**32 since numpy requires the seed to be between 0 and 2**32 - 1.
        np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32))
        return self.build_training_sample(sample, self.max_seq_length, np_rng)

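    # Pipeline for one sample: (1) flatten the sentences into token ids,
    # inserting [SEP] after sentence-final punctuation, (2) permute the
    # resulting sentences, (3) apply whole word masking to the encoder input,
    # and (4) pad source / target / attention mask out to `max_seq_length`.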
    def build_training_sample(self, sample, max_seq_length, np_rng):
        """Build a training sample.

        Arguments:
            sample: A list of sentences in which each sentence is a list of token ids.
            max_seq_length: Desired sequence length.
            np_rng: Random number generator. Note that this rng state should be
                numpy and not python since python randint is inclusive for
                the upper bound whereas the numpy one is exclusive.
        """
        # permute sentences
        full_stops = []
        tokens = [self.cls_id]
        for sent in sample:
            for t in sent:
                token = self.vocab_id_to_token_dict[t]
                if len(re.findall('##[\u4E00-\u9FA5]', token)) > 0:
                    # Strip the '##' prefix so whole word masking stays
                    # compatible with the Erlangshen-style '##' subword vocab.
                    t = self.tokenizer.convert_tokens_to_ids(token[2:])
                tokens.append(t)
                if t in self.seg_token_ids:
                    tokens.append(self.sep_id)
            if tokens[-1] != self.sep_id:
                tokens.append(self.sep_id)

        if len(tokens) > max_seq_length:
            tokens = tokens[:max_seq_length]
            tokens[-1] = self.sep_id
        tokens = torch.LongTensor(tokens)
        # Boolean mask so that "~" in permute_sentences is a logical not
        # rather than a bitwise not on integers.
        full_stops = tokens == self.sep_id
        assert (max_seq_length - tokens.shape[0]) >= 0, (tokens.size(), tokens[-1], max_seq_length)

        source, target = tokens, tokens[1:].clone()
        use_decoder = 1
        # if torch.rand(1).item() < 0.5:
        #     use_decoder = 0

        if self.permute_sentence_ratio > 0.0 and use_decoder == 1:
            source = self.permute_sentences(source, full_stops, self.permute_sentence_ratio)

        if self.mask_ratio > 0.0:
            replace_length = 1 if use_decoder else -1
            mask_ratio = self.mask_ratio * 2 if use_decoder else self.mask_ratio
            source = self.add_whole_word_mask(source, mask_ratio, replace_length)

        if self.insert_ratio > 0.0:
            raise NotImplementedError
            source = self.add_insertion_noise(source, self.insert_ratio)

        if self.rotate_ratio > 0.0 and np.random.random() < self.rotate_ratio:
            raise NotImplementedError
            source = self.add_rolling_noise(source)

        # There can be additional changes to make:
        if self.item_transform_func is not None:
            source, target = self.item_transform_func(source, target)

        assert (source >= 0).all()
        # assert (source[1:-1] >= 1).all()
        assert (source <= self.vocab_size).all()
        assert source[0] == self.cls_id
        assert source[-1] == self.sep_id

        # tokenizer = get_tokenizer()
        # print(' '.join(tokenizer.tokenizer.convert_ids_to_tokens(source)))
        # print(tokenizer.detokenize(target))
        # print(tokenizer.detokenize(source))
        # print()

        prev_output_tokens = torch.zeros_like(target)
        prev_output_tokens[0] = self.sep_id  # match the preprocessing in fairseq
        prev_output_tokens[1:] = target[:-1]

        # src_padding_length = max_seq_length - source.shape[0]
        # tgt_padding_length = max_seq_length - target.shape[0]
        # assert src_padding_length >= 0, (source.size(), source[-1], max_seq_length)
        # assert tgt_padding_length >= 0, (target.size(), target[-1], max_seq_length)
        source_ = torch.full((max_seq_length,), self.pad_id, dtype=torch.long)
        source_[:source.shape[0]] = source
        target_ = torch.full((max_seq_length,), -100, dtype=torch.long)
        # The decoder target does not need bos at the front; -100 marks padding
        # positions to be ignored by the loss.
        target_[:target.shape[0]] = target
        prev_output_tokens_ = torch.full((max_seq_length,), self.pad_id, dtype=torch.long)
        prev_output_tokens_[:prev_output_tokens.shape[0]] = prev_output_tokens

        return {
            "input_ids": source_,
            "labels": target_,
            # "decoder_input_ids": prev_output_tokens_,
            "attention_mask": (source_ != self.pad_id).long()
        }

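    # Sentence permutation: `sentence_ends` holds the index one past each [SEP]
    # that closes a sentence (the +2 compensates for the shifted slice and for
    # including the [SEP] itself). A fraction `p` of the sentences is then
    # shuffled among themselves, leaving the leading [CLS] untouched.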
    def permute_sentences(self, source, full_stops, p=1.0):
        # Tokens that are full stops, where the previous token is not
        sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero(as_tuple=False) + 2
        result = source.clone()

        num_sentences = sentence_ends.size(0)
        num_to_permute = math.ceil((num_sentences * 2 * p) / 2.0)
        substitutions = torch.randperm(num_sentences)[:num_to_permute]
        ordering = torch.arange(0, num_sentences)
        ordering[substitutions] = substitutions[torch.randperm(num_to_permute)]

        # Ignore <bos> at start
        index = 1
        for i in ordering:
            sentence = source[(sentence_ends[i - 1] if i > 0 else 1): sentence_ends[i]]
            result[index: index + sentence.size(0)] = sentence
            index += sentence.size(0)
        return result

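    # English-style variant kept from fairseq: it expects `self.mask_whole_word`
    # to be a per-token-id 0/1 tensor that can be gathered by token id. It is
    # not used by the Chinese path below, where `mask_whole_word` is just the
    # integer flag 1.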
    def word_starts_en(self, source):
        if self.mask_whole_word is not None:
            is_word_start = self.mask_whole_word.gather(0, source)
        else:
            is_word_start = torch.ones(source.size())
        is_word_start[0] = 0
        is_word_start[-1] = 0
        return is_word_start

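    # Chinese whole-word starts: the wordpiece tokens are re-segmented with
    # `zh_tokenizer` (assumed to be a jieba-style segmenter such as jieba.lcut,
    # given the HMM=True argument), each raw token is aligned to the word that
    # contains it, and a token counts as a word start when it maps to a
    # different word than the token before it. `word_starts` additionally keeps
    # only starts of purely CJK words; the [CLS]/[SEP] edges are never starts.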
    def word_starts(self, source):
        if self.mask_whole_word is None:
            is_word_start = torch.ones(source.size())
            is_word_start[0] = 0
            is_word_start[-1] = 0
            return is_word_start

        raw_tokens = [self.vocab_id_to_token_dict[i] for i in source.tolist()]
        words = [raw_tokens[0]] + \
            self.zh_tokenizer(''.join(raw_tokens[1:-1]), HMM=True) + [raw_tokens[-1]]

        def _is_chinese_char(c):
            """Checks whether `c` consists only of CJK codepoints."""
            # This defines a "chinese character" as anything in the CJK Unicode block:
            # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
            #
            # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
            # despite its name. The modern Korean Hangul alphabet is a different block,
            # as is Japanese Hiragana and Katakana. Those alphabets are used to write
            # space-separated words, so they are not treated specially and handled
            # like all of the other languages.
            if len(c) > 1:
                return all([_is_chinese_char(c_i) for c_i in c])
            cp = ord(c)
            if ((cp >= 0x4E00 and cp <= 0x9FFF) or
                    (cp >= 0x3400 and cp <= 0x4DBF) or
                    (cp >= 0x20000 and cp <= 0x2A6DF) or
                    (cp >= 0x2A700 and cp <= 0x2B73F) or
                    (cp >= 0x2B740 and cp <= 0x2B81F) or
                    (cp >= 0x2B820 and cp <= 0x2CEAF) or
                    (cp >= 0xF900 and cp <= 0xFAFF) or
                    (cp >= 0x2F800 and cp <= 0x2FA1F)):
                return True
            return False

        def align_linear(atokens, btokens):
            # Map each token in `atokens` to the indices of the token(s) in
            # `btokens` that cover the same character span.
            a2c = []
            c2b = []
            a2b = []
            length = 0
            for tok in atokens:
                a2c.append([length + i for i in range(len(tok))])
                length += len(tok)
            for i, tok in enumerate(btokens):
                c2b.extend([i for _ in range(len(tok))])

            for i, amap in enumerate(a2c):
                bmap = [c2b[ci] for ci in amap]
                a2b.append(list(set(bmap)))
            return a2b

        raw_to_word_align = align_linear(raw_tokens, words)
        is_word_start = torch.zeros(source.size())
        word_starts = []
        skip_cur_word = True
        for i in range(1, len(raw_to_word_align)):
            if raw_to_word_align[i - 1] == raw_to_word_align[i]:
                # Not a word start, as they align to the same word.
                if not skip_cur_word and not _is_chinese_char(raw_tokens[i]):
                    word_starts.pop(-1)
                    skip_cur_word = True
                continue
            else:
                is_word_start[i] = 1
                if _is_chinese_char(raw_tokens[i]):
                    word_starts.append(i)
                    skip_cur_word = False
        is_word_start[0] = 0
        is_word_start[-1] = 0
        word_starts = torch.tensor(word_starts).long().view(-1, 1)
        return is_word_start, word_starts

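    # Whole word masking (BART-style text infilling): roughly `p` of the word
    # starts are masked as whole words, plus about 10% more positions masked as
    # single characters. Each selected start becomes [MASK] (or a random token
    # with probability `random_ratio`). With `replace_length == 1` the rest of
    # a masked word is dropped, collapsing it to a single [MASK]; with
    # `replace_length == -1` every token of the word is replaced by [MASK];
    # `replace_length == 0` deletes the selected spans outright.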
    def add_whole_word_mask(self, source, p, replace_length=1):
        is_word_start, word_starts = self.word_starts(source)
        num_to_mask_word = int(math.ceil(word_starts.size(0) * p))
        num_to_mask_char = int(math.ceil(word_starts.size(0) * p * 0.1))
        num_to_mask = num_to_mask_word + num_to_mask_char
        if num_to_mask > word_starts.size(0):
            word_starts = is_word_start.nonzero(as_tuple=False)
        num_inserts = 0
        if num_to_mask == 0:
            return source

        if self.mask_span_distribution is not None:
            lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))

            # Make sure we have enough to mask
            cum_length = torch.cumsum(lengths, 0)
            while cum_length[-1] < num_to_mask:
                lengths = torch.cat(
                    [
                        lengths,
                        self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
                    ],
                    dim=0,
                )
                cum_length = torch.cumsum(lengths, 0)

            # Trim to masking budget
            i = 0
            while cum_length[i] < num_to_mask:
                i += 1
            lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
            num_to_mask = i + 1
            lengths = lengths[:num_to_mask]

            # Handle 0-length mask (inserts) separately
            lengths = lengths[lengths > 0]
            num_inserts = num_to_mask - lengths.size(0)
            num_to_mask -= num_inserts
            if num_to_mask == 0:
                return self.add_insertion_noise(source, num_inserts / source.size(0))

            assert (lengths > 0).all()
        else:
            lengths = torch.ones((num_to_mask,)).long()
        assert is_word_start[-1] == 0
        indices = word_starts[
            torch.randperm(word_starts.size(0))[:num_to_mask]
        ].squeeze(1)
        mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio

        source_length = source.size(0)
        assert source_length - 1 not in indices
        to_keep = torch.ones(source_length, dtype=torch.bool)
        is_word_start[-1] = 255  # acts as a long length, so spans don't go over the end of doc

        if replace_length == 0:
            to_keep[indices] = 0
        else:
            # keep index, but replace it with [MASK]
            # print(source.size(), word_starts.size(), indices.size(), mask_random.size())
            source[indices] = self.mask_id
            source[indices[mask_random]] = torch.randint(
                1, self.vocab_size, size=(mask_random.sum(),)
            )
        # sorted_indices = torch.sort(indices)[0]
        # continue_mask_pos = ((sorted_indices + 1)[:-1] == sorted_indices[1:])
        # continue_mask_indices = sorted_indices[1:][continue_mask_pos]
        # to_keep[continue_mask_indices] = 0

        # Char-level indices are already masked above; the loop below extends
        # only the word-level masks across the rest of each word.
        indices = indices[:num_to_mask_word]
        mask_random = mask_random[:num_to_mask_word]
        if self.mask_span_distribution is not None:
            assert len(lengths.size()) == 1
            assert lengths.size() == indices.size()
            lengths -= 1
            while indices.size(0) > 0:
                assert lengths.size() == indices.size()
                lengths -= is_word_start[indices + 1].long()
                uncompleted = lengths >= 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                lengths = lengths[uncompleted]
                if replace_length != -1:
                    # delete token
                    to_keep[indices] = 0
                else:
                    # keep index, but replace it with [MASK]
                    source[indices] = self.mask_id
                    source[indices[mask_random]] = torch.randint(
                        1, self.vocab_size, size=(mask_random.sum(),)
                    )
        else:
            # A bit faster when all lengths are 1
            while indices.size(0) > 0:
                uncompleted = is_word_start[indices + 1] == 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                if replace_length != -1:
                    # delete token
                    to_keep[indices] = 0
                else:
                    # keep index, but replace it with [MASK]
                    source[indices] = self.mask_id
                    source[indices[mask_random]] = torch.randint(
                        1, self.vocab_size, size=(mask_random.sum(),)
                    )

        assert source_length - 1 not in indices
        source = source[to_keep]

        if num_inserts > 0:
            source = self.add_insertion_noise(source, num_inserts / source.size(0))

        return source

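    # Token-level permutation noise (not called anywhere in this dataset):
    # shuffles a fraction `p` of the interior tokens, never touching the first
    # or last position.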
    def add_permuted_noise(self, tokens, p):
        num_words = len(tokens)
        num_to_permute = math.ceil(((num_words * 2) * p) / 2.0)
        substitutions = torch.randperm(num_words - 2)[:num_to_permute] + 1
        tokens[substitutions] = tokens[substitutions[torch.randperm(num_to_permute)]]
        return tokens

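    # Document rotation: picks a random offset and rotates the interior tokens
    # so the sequence starts at that offset, keeping the first and last tokens
    # in place. Only reachable when `rotate_ratio > 0`.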
    def add_rolling_noise(self, tokens):
        offset = np.random.randint(1, max(1, tokens.size(-1) - 1) + 1)
        tokens = torch.cat(
            (tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]),
            dim=0,
        )
        return tokens

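    # Insertion noise: grows the sequence by ceil(len * p) positions and fills
    # the new interior slots with [MASK] (or random tokens for a `random_ratio`
    # share), keeping the original tokens in order. Only reachable when
    # `insert_ratio > 0` or when span masking yields zero-length spans.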
    def add_insertion_noise(self, tokens, p):
        if p == 0.0:
            return tokens

        num_tokens = len(tokens)
        n = int(math.ceil(num_tokens * p))

        noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
        noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
        noise_mask[noise_indices] = 1
        result = torch.LongTensor(n + len(tokens)).fill_(-1)

        num_random = int(math.ceil(n * self.random_ratio))
        result[noise_indices[num_random:]] = self.mask_id
        result[noise_indices[:num_random]] = torch.randint(
            low=1, high=self.vocab_size, size=(num_random,)
        )

        result[~noise_mask] = tokens

        assert (result >= 0).all()
        return result
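

# Illustrative usage (a sketch, not part of the original pipeline): the values
# below, in particular the tokenizer checkpoint and jieba.lcut as
# `zh_tokenizer`, are assumptions for the example only; `indexed_dataset` and
# the samples mapping come from the Megatron-style data pipeline in
# fengshen.data.megatron_dataloader.
#
#   import jieba
#   from transformers import BertTokenizer
#
#   tokenizer = BertTokenizer.from_pretrained(
#       "IDEA-CCNL/Erlangshen-Roberta-110M")  # assumed checkpoint
#   dataset = BartDataset(name="bart_cn", indexed_dataset=indexed_dataset,
#                         data_prefix="/path/to/prefix", num_epochs=None,
#                         max_num_samples=1_000_000, masked_lm_prob=0.15,
#                         max_seq_length=512, short_seq_prob=0.1, seed=1234,
#                         tokenizer=tokenizer, zh_tokenizer=jieba.lcut)
#   batch = dataset[0]  # dict with input_ids / labels / attention_mask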