import random
import logging

logger = logging.getLogger(__name__)


class Dataset:
    """This class constructs a dataset from multiple data files.
    """
    def __init__(self, name, instance_dict=None):
        """This function initializes a dataset: it sets the dataset name and
        holds multiple readers, one per data file.
        Arguments:
            name {str} -- dataset name
        Keyword Arguments:
            instance_dict {dict} -- instance settings (default: {None})
        """
        self.dataset_name = name
        self.datasets = dict()
        self.instance_dict = dict(instance_dict) if instance_dict is not None else dict()
        # by default every namespace is padded; overwrite via set_wo_padding_namespace
        self.wo_padding_namespace = list()

    def add_instance(self, name, instance, reader, is_count=False, is_train=False):
        """This function adds an instance to the dataset
        Arguments:
            name {str} -- instance name
            instance {Instance} -- instance
            reader {DatasetReader} -- reader corresponding to the instance
        Keyword Arguments:
            is_count {bool} -- whether the instance participates in vocabulary counting (default: {False})
            is_train {bool} -- whether the instance is training data (default: {False})
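
        Example (illustrative sketch; `train_instance` and `train_reader` are
        assumed to be Instance and DatasetReader objects built elsewhere):
            dataset = Dataset('demo')
            dataset.add_instance('train', train_instance, train_reader,
                                 is_count=True, is_train=True)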
"""
self.instance_dict[name] = {
'instance': instance,
'reader': reader,
'is_count': is_count,
'is_train': is_train
}

    def build_dataset(self,
                      vocab,
                      counter=None,
                      min_count=None,
                      pretrained_vocab=None,
                      intersection_namespace=None,
                      no_pad_namespace=None,
                      no_unk_namespace=None,
                      contain_pad_namespace=None,
                      contain_unk_namespace=None,
                      tokens_to_add=None):
"""This function bulids dataset
Arguments:
vocab {Vocabulary} -- vocabulary
Keyword Arguments:
counter {dict} -- counter (default: {None})
min_count {dict} -- min count for each namespace (default: {dict()})
pretrained_vocab {dict} -- pretrained vocabulary (default: {None})
intersection_namespace {dict} -- intersection vocabulary namespace correspond to
pretrained vocabulary in case of too large pretrained vocabulary (default: {dict()})
no_pad_namespace {list} -- no padding vocabulary namespace (default: {list()})
no_unk_namespace {list} -- no unknown vocabulary namespace (default: {list()})
contain_pad_namespace {dict} -- contain padding token vocabulary namespace (default: {dict()})
contain_unk_namespace {dict} -- contain unknown token vocabulary namespace (default: {dict()})
tokens_to_add {dict} -- tokens need to be added to vocabulary (default: {None})
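
        Example (sketch; `vocab` is a Vocabulary, `counter` a nested dict of
        token counts, and `glove_vocab` a hypothetical pretrained
        token-to-index dict; 'tokens' is an illustrative namespace):
            dataset.build_dataset(vocab,
                                  counter=counter,
                                  min_count={'tokens': 2},
                                  pretrained_vocab={'tokens': glove_vocab})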
"""
        # avoid shared mutable default arguments
        min_count = min_count if min_count is not None else dict()
        intersection_namespace = intersection_namespace if intersection_namespace is not None else dict()
        no_pad_namespace = no_pad_namespace if no_pad_namespace is not None else list()
        no_unk_namespace = no_unk_namespace if no_unk_namespace is not None else list()
        contain_pad_namespace = contain_pad_namespace if contain_pad_namespace is not None else dict()
        contain_unk_namespace = contain_unk_namespace if contain_unk_namespace is not None else dict()
        # construct counter
if counter is not None:
            for instance_name, instance_setting in self.instance_dict.items():
                if instance_setting['is_count']:
                    instance_setting['instance'].count_vocab_items(counter,
                                                                   instance_setting['reader'])
# construct vocabulary from counter
vocab.extend_from_counter(counter, min_count, no_pad_namespace, no_unk_namespace,
contain_pad_namespace, contain_unk_namespace)
        # add extra tokens; this operation must be executed before adding the pretrained vocabulary
if tokens_to_add is not None:
for namespace, tokens in tokens_to_add.items():
vocab.add_tokens_to_namespace(tokens, namespace)
        # construct vocabulary from pretrained vocabulary
if pretrained_vocab is not None:
vocab.extend_from_pretrained_vocab(pretrained_vocab, intersection_namespace,
no_pad_namespace, no_unk_namespace,
contain_pad_namespace, contain_unk_namespace)
self.vocab = vocab
        for instance_name, instance_setting in self.instance_dict.items():
            instance_setting['instance'].index(self.vocab, instance_setting['reader'])
            self.datasets[instance_name] = instance_setting['instance'].get_instance()
            self.instance_dict[instance_name]['size'] = instance_setting['instance'].get_size()
            self.instance_dict[instance_name]['vocab_dict'] = instance_setting[
                'instance'].get_vocab_dict()
            logger.info("{} dataset size: {}.".format(instance_name,
                                                      self.instance_dict[instance_name]['size']))
            for key, seq_len in instance_setting['reader'].get_seq_lens().items():
                logger.info("{} dataset's {}: max_len={}, min_len={}.".format(
                    instance_name, key, max(seq_len), min(seq_len)))

    def get_batch(self, instance_name, batch_size, sort_namespace=None):
        """get_batch gets a batch of samples and pads them
        Arguments:
            instance_name {str} -- instance name
            batch_size {int} -- batch size
        Keyword Arguments:
            sort_namespace {str} -- namespace by whose sequence length each batch is
            sorted in descending order; None means no sorting (default: {None})
Yields:
int -- epoch
dict -- batch data
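
        Example (sketch; for training instances the generator cycles forever,
        so the caller breaks on the yielded epoch counter; 'tokens' is an
        illustrative namespace):
            for epoch, batch in dataset.get_batch('train', 32, sort_namespace='tokens'):
                if epoch > 10:
                    break
                tokens, lens = batch['tokens'], batch['tokens_lens']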
"""
        if instance_name not in self.instance_dict:
            logger.error('cannot find instance name {} in datasets.'.format(instance_name))
            return
        dataset = self.datasets[instance_name]
        if sort_namespace is not None and sort_namespace not in dataset:
            logger.error('cannot find sort namespace {} in dataset instance {}.'.format(
                sort_namespace, instance_name))
            return
size = self.instance_dict[instance_name]['size']
vocab_dict = self.instance_dict[instance_name]['vocab_dict']
ids = list(range(size))
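        # shuffle sample order only for training data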
if self.instance_dict[instance_name]['is_train']:
random.shuffle(ids)
epoch = 1
cur = 0
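        # training data is cycled indefinitely; other data stops after one pass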
while True:
if cur >= size:
epoch += 1
if not self.instance_dict[instance_name]['is_train'] and epoch > 1:
break
random.shuffle(ids)
cur = 0
sample_ids = ids[cur:cur + batch_size]
cur += batch_size
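            # optionally sort the batch by descending sequence length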
if sort_namespace is not None:
sample_ids = [(idx, len(dataset[sort_namespace][idx])) for idx in sample_ids]
sample_ids = sorted(sample_ids, key=lambda x: x[1], reverse=True)
sorted_ids = [idx for idx, _ in sample_ids]
else:
sorted_ids = sample_ids
batch = {}
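            # namespaces in wo_padding_namespace are collected as-is; all others
            # are padded and get parallel '_lens' and '_mask' entries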
for namespace in dataset:
batch[namespace] = []
if namespace in self.wo_padding_namespace:
for id in sorted_ids:
batch[namespace].append(dataset[namespace][id])
else:
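                    # namespaces without a vocabulary fall back to padding index 0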
if namespace in vocab_dict:
padding_idx = self.vocab.get_padding_index(vocab_dict[namespace])
else:
padding_idx = 0
batch_namespace_len = [len(dataset[namespace][id]) for id in sorted_ids]
max_namespace_len = max(batch_namespace_len)
batch[namespace + '_lens'] = batch_namespace_len
batch[namespace + '_mask'] = []
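                    # nested fields (e.g. character-level tokens) are padded on
                    # both the sequence axis and the inner character axis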
if isinstance(dataset[namespace][0][0], list):
max_char_len = 0
for id in sorted_ids:
max_char_len = max(max_char_len,
max(len(item) for item in dataset[namespace][id]))
for id in sorted_ids:
padding_sent = []
mask = []
for item in dataset[namespace][id]:
padding_sent.append(item + [padding_idx] *
(max_char_len - len(item)))
mask.append([1] * len(item) + [0] * (max_char_len - len(item)))
padding_sent = padding_sent + [[padding_idx] * max_char_len] * (
max_namespace_len - len(dataset[namespace][id]))
mask = mask + [[0] * max_char_len
] * (max_namespace_len - len(dataset[namespace][id]))
batch[namespace].append(padding_sent)
batch[namespace + '_mask'].append(mask)
else:
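                        # flat fields: pad each sequence to the batch maximum length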
for id in sorted_ids:
batch[namespace].append(
dataset[namespace][id] + [padding_idx] *
(max_namespace_len - len(dataset[namespace][id])))
batch[namespace +
'_mask'].append([1] * len(dataset[namespace][id]) + [0] *
(max_namespace_len - len(dataset[namespace][id])))
yield epoch, batch

    def get_dataset_size(self, instance_name):
        """This function gets the dataset size
Arguments:
instance_name {str} -- instance name
Returns:
int -- dataset size
"""
return self.instance_dict[instance_name]['size']

    def set_wo_padding_namespace(self, wo_padding_namespace):
        """set_wo_padding_namespace sets the namespaces that skip padding
        Arguments:
            wo_padding_namespace {list} -- namespaces without padding
        """
self.wo_padding_namespace = wo_padding_namespace