import logging
import random

logger = logging.getLogger(__name__)


class Dataset:
    """This class constructs a dataset from multiple data files."""

    def __init__(self, name, instance_dict=dict()):
        """This function initializes a dataset with the given name. The
        dataset holds multiple instances, one per data file, each paired
        with its reader.

        Arguments:
            name {str} -- dataset name

        Keyword Arguments:
            instance_dict {dict} -- instance settings (default: {dict()})
        """

        self.dataset_name = name
        self.datasets = dict()
        # copy so the mutable default argument is never shared or mutated
        self.instance_dict = dict(instance_dict)
        # overridden via `set_wo_padding_namespace`; default to an empty
        # list so `get_batch` also works when it is never called
        self.wo_padding_namespace = list()

    def add_instance(self, name, instance, reader, is_count=False, is_train=False):
        """This function adds an instance to the dataset.

        Arguments:
            name {str} -- instance name
            instance {Instance} -- instance
            reader {DatasetReader} -- reader corresponding to the instance

        Keyword Arguments:
            is_count {bool} -- whether the instance participates in vocabulary counting (default: {False})
            is_train {bool} -- whether the instance is training data (default: {False})
        """

        self.instance_dict[name] = {
            'instance': instance,
            'reader': reader,
            'is_count': is_count,
            'is_train': is_train
        }
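
    # A minimal usage sketch; `EntityInstance` and `ConllReader` below are
    # hypothetical Instance / DatasetReader implementations, not part of
    # this module:
    #
    #   train_reader = ConllReader('data/train.conll')
    #   dataset = Dataset('demo')
    #   dataset.add_instance('train', EntityInstance(), train_reader,
    #                        is_count=True, is_train=True)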

    def build_dataset(self,
                      vocab,
                      counter=None,
                      min_count=dict(),
                      pretrained_vocab=None,
                      intersection_namespace=dict(),
                      no_pad_namespace=list(),
                      no_unk_namespace=list(),
                      contain_pad_namespace=dict(),
                      contain_unk_namespace=dict(),
                      tokens_to_add=None):
        """This function builds the dataset: it counts vocabulary items,
        extends the vocabulary and indexes all instances.

        Arguments:
            vocab {Vocabulary} -- vocabulary

        Keyword Arguments:
            counter {dict} -- counter (default: {None})
            min_count {dict} -- min count for each namespace (default: {dict()})
            pretrained_vocab {dict} -- pretrained vocabulary (default: {None})
            intersection_namespace {dict} -- maps a namespace to the pretrained
                vocabulary it is intersected with, used when the pretrained
                vocabulary is too large to keep in full (default: {dict()})
            no_pad_namespace {list} -- namespaces without a padding token (default: {list()})
            no_unk_namespace {list} -- namespaces without an unknown token (default: {list()})
            contain_pad_namespace {dict} -- namespaces with a custom padding token (default: {dict()})
            contain_unk_namespace {dict} -- namespaces with a custom unknown token (default: {dict()})
            tokens_to_add {dict} -- tokens to be added to the vocabulary (default: {None})
        """

        # count vocabulary items from the instances marked `is_count`,
        # then extend the vocabulary from the resulting counter
        if counter is not None:
            for instance_name, instance_setting in self.instance_dict.items():
                if instance_setting['is_count']:
                    instance_setting['instance'].count_vocab_items(counter,
                                                                   instance_setting['reader'])

            vocab.extend_from_counter(counter, min_count, no_pad_namespace, no_unk_namespace,
                                      contain_pad_namespace, contain_unk_namespace)

        if tokens_to_add is not None:
            for namespace, tokens in tokens_to_add.items():
                vocab.add_tokens_to_namespace(tokens, namespace)

        if pretrained_vocab is not None:
            vocab.extend_from_pretrained_vocab(pretrained_vocab, intersection_namespace,
                                               no_pad_namespace, no_unk_namespace,
                                               contain_pad_namespace, contain_unk_namespace)

        self.vocab = vocab

        # index every instance with the final vocabulary and cache its data,
        # size and namespace-to-vocabulary mapping
        for instance_name, instance_setting in self.instance_dict.items():
            instance_setting['instance'].index(self.vocab, instance_setting['reader'])
            self.datasets[instance_name] = instance_setting['instance'].get_instance()
            self.instance_dict[instance_name]['size'] = instance_setting['instance'].get_size()
            self.instance_dict[instance_name]['vocab_dict'] = instance_setting[
                'instance'].get_vocab_dict()

            logger.info("{} dataset size: {}.".format(instance_name,
                                                      self.instance_dict[instance_name]['size']))
            for key, seq_len in instance_setting['reader'].get_seq_lens().items():
                logger.info("{} dataset's {}: max_len={}, min_len={}.".format(
                    instance_name, key, max(seq_len), min(seq_len)))
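
    # A minimal sketch of a typical `build_dataset` call; the `Vocabulary`
    # constructor and the namespace names below are assumptions:
    #
    #   from collections import defaultdict
    #
    #   counter = defaultdict(lambda: defaultdict(int))
    #   vocab = Vocabulary()
    #   dataset.build_dataset(vocab,
    #                         counter=counter,
    #                         min_count={'tokens': 2},
    #                         tokens_to_add={'labels': ['O']})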

    def get_batch(self, instance_name, batch_size, sort_namespace=None):
        """get_batch yields padded batches of data.

        Arguments:
            instance_name {str} -- instance name
            batch_size {int} -- batch size

        Keyword Arguments:
            sort_namespace {str} -- if not None, sort the samples of each batch
                by their sequence length in this namespace; None means no
                sorting (default: {None})

        Yields:
            int -- epoch
            dict -- batch data
        """

        if instance_name not in self.instance_dict:
            logger.error('can not find instance name {} in datasets.'.format(instance_name))
            return

        dataset = self.datasets[instance_name]

        if sort_namespace is not None and sort_namespace not in dataset:
            logger.error('can not find sort namespace {} in datasets instance {}.'.format(
                sort_namespace, instance_name))
            return

        size = self.instance_dict[instance_name]['size']
        vocab_dict = self.instance_dict[instance_name]['vocab_dict']
        ids = list(range(size))
        if self.instance_dict[instance_name]['is_train']:
            random.shuffle(ids)
        epoch = 1
        cur = 0

        while True:
            # after one full pass, start a new epoch for training data and
            # stop for evaluation data
            if cur >= size:
                epoch += 1
                if not self.instance_dict[instance_name]['is_train'] and epoch > 1:
                    break
                random.shuffle(ids)
                cur = 0

            sample_ids = ids[cur:cur + batch_size]
            cur += batch_size

            # optionally sort the batch by sequence length, longest first
            if sort_namespace is not None:
                sample_ids = [(idx, len(dataset[sort_namespace][idx])) for idx in sample_ids]
                sample_ids = sorted(sample_ids, key=lambda x: x[1], reverse=True)
                sorted_ids = [idx for idx, _ in sample_ids]
            else:
                sorted_ids = sample_ids

            batch = {}

            for namespace in dataset:
                batch[namespace] = []

                if namespace in self.wo_padding_namespace:
                    # fields in `wo_padding_namespace` are copied as-is
                    for idx in sorted_ids:
                        batch[namespace].append(dataset[namespace][idx])
                else:
                    if namespace in vocab_dict:
                        padding_idx = self.vocab.get_padding_index(vocab_dict[namespace])
                    else:
                        padding_idx = 0

                    batch_namespace_len = [len(dataset[namespace][idx]) for idx in sorted_ids]
                    max_namespace_len = max(batch_namespace_len)
                    batch[namespace + '_lens'] = batch_namespace_len
                    batch[namespace + '_mask'] = []

                    if isinstance(dataset[namespace][0][0], list):
                        # nested sequences (e.g. characters per token): pad the
                        # inner dimension to the longest item, then pad the
                        # outer dimension to the longest sequence
                        max_char_len = 0
                        for idx in sorted_ids:
                            max_char_len = max(max_char_len,
                                               max(len(item) for item in dataset[namespace][idx]))
                        for idx in sorted_ids:
                            padding_sent = []
                            mask = []
                            for item in dataset[namespace][idx]:
                                padding_sent.append(item + [padding_idx] * (max_char_len - len(item)))
                                mask.append([1] * len(item) + [0] * (max_char_len - len(item)))
                            padding_sent = padding_sent + [[padding_idx] * max_char_len] * (
                                max_namespace_len - len(dataset[namespace][idx]))
                            mask = mask + [[0] * max_char_len] * (
                                max_namespace_len - len(dataset[namespace][idx]))
                            batch[namespace].append(padding_sent)
                            batch[namespace + '_mask'].append(mask)
                    else:
                        # flat sequences: pad to the longest sequence in the batch
                        for idx in sorted_ids:
                            batch[namespace].append(
                                dataset[namespace][idx] + [padding_idx] *
                                (max_namespace_len - len(dataset[namespace][idx])))
                            batch[namespace + '_mask'].append(
                                [1] * len(dataset[namespace][idx]) + [0] *
                                (max_namespace_len - len(dataset[namespace][idx])))

            yield epoch, batch
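
    # A minimal sketch of a training loop over `get_batch`; the 'train'
    # instance, the 'tokens' namespace and `train_step` are assumptions:
    #
    #   for epoch, batch in dataset.get_batch('train', batch_size=32,
    #                                         sort_namespace='tokens'):
    #       if epoch > 10:
    #           break
    #       train_step(batch['tokens'], batch['tokens_lens'], batch['tokens_mask'])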

    def get_dataset_size(self, instance_name):
        """This function gets the dataset size.

        Arguments:
            instance_name {str} -- instance name

        Returns:
            int -- dataset size
        """

        return self.instance_dict[instance_name]['size']

    def set_wo_padding_namespace(self, wo_padding_namespace):
        """set_wo_padding_namespace sets the namespaces that are batched without padding.

        Arguments:
            wo_padding_namespace {list} -- namespaces without padding
        """

        self.wo_padding_namespace = wo_padding_namespace
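
    # A minimal sketch; the namespace name is an assumption. Fields listed
    # here are copied into batches verbatim, with no padding or mask:
    #
    #   dataset.set_wo_padding_namespace(['raw_tokens'])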