import logging
import pickle

from bidict import bidict

logger = logging.getLogger(__name__)


class Vocabulary:
    """Map strings to integer indices, grouped into independent namespaces.

    Each namespace is a ``bidict`` (token -> index; ``.inv`` gives
    index -> token). When a namespace is (re)initialised it is seeded with
    ``DEFAULT_PAD_TOKEN`` and then ``DEFAULT_UNK_TOKEN`` at the next free
    index, unless the namespace is listed in ``no_pad_namespace`` /
    ``no_unk_namespace`` or maps to its own token in
    ``contain_pad_namespace`` / ``contain_unk_namespace``.
    """

    DEFAULT_PAD_TOKEN = '*@PAD@*'
    DEFAULT_UNK_TOKEN = '*@UNK@*'

    def __init__(self,
                 counters=None,
                 min_count=None,
                 pretrained_vocab=None,
                 intersection_namespace=None,
                 no_pad_namespace=None,
                 no_unk_namespace=None,
                 contain_pad_namespace=None,
                 contain_unk_namespace=None):
        """Initialize the vocabulary.

        Keyword Arguments:
            counters {dict} -- namespace -> token counter (default: {None})
            min_count {dict} -- namespace -> minimum count a token needs to be
                kept; namespaces not listed use 1 (default: {None})
            pretrained_vocab {dict} -- namespace -> {token: index} mapping
                copied into the vocabulary (default: {None})
            intersection_namespace {dict} -- pretrained namespace -> existing
                namespace; only pretrained tokens also present in that
                namespace are kept, useful when the pretrained vocabulary is
                too large (default: {None})
            no_pad_namespace {list} -- namespaces without a padding token
                (default: {None})
            no_unk_namespace {list} -- namespaces without an unknown token
                (default: {None})
            contain_pad_namespace {dict} -- namespace -> token to use as the
                padding token instead of the default (default: {None})
            contain_unk_namespace {dict} -- namespace -> token to use as the
                unknown token instead of the default (default: {None})
        """
        # ``None`` defaults avoid sharing mutable default objects between calls.
        self.min_count = dict(min_count or {})
        self.intersection_namespace = dict(intersection_namespace or {})
        self.no_pad_namespace = set(no_pad_namespace or ())
        self.no_unk_namespace = set(no_unk_namespace or ())
        self.contain_pad_namespace = dict(contain_pad_namespace or {})
        self.contain_unk_namespace = dict(contain_unk_namespace or {})
        self.vocab = dict()

        # Counters are processed first so pretrained namespaces can intersect
        # with vocabularies built from them.
        self.extend_from_counter(counters, self.min_count,
                                 self.no_pad_namespace, self.no_unk_namespace)
        self.extend_from_pretrained_vocab(pretrained_vocab,
                                          self.intersection_namespace,
                                          self.no_pad_namespace,
                                          self.no_unk_namespace)
        logger.info("Initialize vocabulary successfully.")

    def extend_from_pretrained_vocab(self,
                                     pretrained_vocab,
                                     intersection_namespace=None,
                                     no_pad_namespace=None,
                                     no_unk_namespace=None,
                                     contain_pad_namespace=None,
                                     contain_unk_namespace=None):
        """Extend the vocabulary from pretrained token -> index mappings.

        Every namespace in ``pretrained_vocab`` is (re)initialised, so any
        tokens it already held are discarded before the pretrained entries
        are copied in with their original indices.

        Arguments:
            pretrained_vocab {dict} -- namespace -> {token: index}

        Keyword Arguments:
            intersection_namespace {dict} -- pretrained namespace -> existing
                namespace to intersect with; merged into the stored settings
                (default: {None})
            no_pad_namespace {list} -- namespaces without a padding token
                (default: {None})
            no_unk_namespace {list} -- namespaces without an unknown token
                (default: {None})
            contain_pad_namespace {dict} -- namespace -> custom padding token
                (default: {None})
            contain_unk_namespace {dict} -- namespace -> custom unknown token
                (default: {None})
        """
        self.intersection_namespace.update(dict(intersection_namespace or {}))
        self.no_pad_namespace.update(no_pad_namespace or ())
        self.no_unk_namespace.update(no_unk_namespace or ())
        self.contain_pad_namespace.update(dict(contain_pad_namespace or {}))
        self.contain_unk_namespace.update(dict(contain_unk_namespace or {}))

        for namespace, vocab in (pretrained_vocab or {}).items():
            is_intersection = namespace in self.intersection_namespace
            # Resolve the intersection target *before* re-initialising this
            # namespace, so a namespace intersected with itself is not wiped
            # before it is consulted.
            intersection_vocab = (
                self.vocab[self.intersection_namespace[namespace]]
                if is_intersection else ())
            self.__namespace_init(namespace)
            namespace_vocab = self.vocab[namespace]
            for key, value in vocab.items():
                if not is_intersection or key in intersection_vocab:
                    namespace_vocab[key] = value
            logger.info(
                "Vocabulay {} (size: {}) was constructed successfully from pretrained_vocab.".
                format(namespace, len(self.vocab[namespace])))

    def extend_from_counter(self,
                            counters,
                            min_count=None,
                            no_pad_namespace=None,
                            no_unk_namespace=None,
                            contain_pad_namespace=None,
                            contain_unk_namespace=None):
        """Extend the vocabulary from token counters.

        Every namespace in ``counters`` is (re)initialised, so any tokens it
        already held are discarded. Tokens whose count reaches the
        namespace's minimum count get consecutive indices in counter
        iteration order.

        Arguments:
            counters {dict} -- namespace -> token counter (mapping of
                token -> count)

        Keyword Arguments:
            min_count {dict} -- namespace -> minimum count; merged into the
                stored thresholds, which are then used (default: {None})
            no_pad_namespace {list} -- namespaces without a padding token
                (default: {None})
            no_unk_namespace {list} -- namespaces without an unknown token
                (default: {None})
            contain_pad_namespace {dict} -- namespace -> custom padding token
                (default: {None})
            contain_unk_namespace {dict} -- namespace -> custom unknown token
                (default: {None})
        """
        self.no_pad_namespace.update(no_pad_namespace or ())
        self.no_unk_namespace.update(no_unk_namespace or ())
        self.contain_pad_namespace.update(dict(contain_pad_namespace or {}))
        self.contain_unk_namespace.update(dict(contain_unk_namespace or {}))
        self.min_count.update(dict(min_count or {}))

        for namespace, counter in (counters or {}).items():
            self.__namespace_init(namespace)
            namespace_vocab = self.vocab[namespace]
            # Use the merged thresholds so settings given earlier still apply.
            threshold = self.min_count.get(namespace, 1)
            for key, count in counter.items():
                # A token already present (e.g. the counter also counts the
                # pad/unk token) must keep its index; re-assigning it would
                # leave a gap that the next new token collides with.
                if key in namespace_vocab:
                    continue
                if count >= threshold:
                    namespace_vocab[key] = len(namespace_vocab)
            logger.info("Vocabulay {} (size: {}) was constructed successfully from counter.".format(
                namespace, len(self.vocab[namespace])))

    def add_tokens_to_namespace(self, tokens, namespace):
        """Add tokens to one namespace, creating the namespace if needed.

        Tokens already present keep their index; new tokens get the next
        consecutive index.

        Arguments:
            tokens {list} -- token list
            namespace {str} -- namespace name
        """
        if namespace not in self.vocab:
            self.__namespace_init(namespace)
            # Implicit namespace creation is worth flagging, but is not an error.
            logger.warning('Add Namespace {} into vocabulary.'.format(namespace))
        namespace_vocab = self.vocab[namespace]
        for token in tokens:
            if token not in namespace_vocab:
                namespace_vocab[token] = len(namespace_vocab)

    def get_token_index(self, token, namespace):
        """Get the index of a token in one namespace.

        Unseen tokens map to the namespace's unknown index unless the
        namespace has no unknown token.

        Arguments:
            token {str} -- token
            namespace {str} -- namespace name

        Raises:
            RuntimeError: namespace does not exist, or the token is unseen in
                a namespace without an unknown token

        Returns:
            int -- token index
        """
        if namespace not in self.vocab:
            raise RuntimeError("Namespace {} doesn't exist.".format(namespace))
        namespace_vocab = self.vocab[namespace]
        if token in namespace_vocab:
            return namespace_vocab[token]
        if namespace not in self.no_unk_namespace:
            return self.get_unknown_index(namespace)
        message = "Can not find the index of {} from a no unknown token namespace {}.".format(
            token, namespace)
        logger.error(message)
        raise RuntimeError(message)

    def get_token_from_index(self, index, namespace):
        """Get the token stored at an index in one namespace.

        Arguments:
            index {int} -- index
            namespace {str} -- namespace name

        Raises:
            RuntimeError: namespace does not exist, or no token has this index

        Returns:
            str -- token
        """
        if namespace not in self.vocab:
            raise RuntimeError("Namespace {} doesn't exist.".format(namespace))
        index_to_token = self.vocab[namespace].inv
        # Membership rather than ``index < len``: rejects negative indices and
        # handles non-contiguous indices copied from pretrained vocabularies.
        if index in index_to_token:
            return index_to_token[index]
        logger.error("The index {} is out of vocabulary {} range.".format(index, namespace))
        raise RuntimeError("The index {} is out of vocabulary {} range.".format(index, namespace))

    def get_vocab_size(self, namespace):
        """Get the number of entries in one namespace.

        Arguments:
            namespace {str} -- namespace name

        Returns:
            int -- vocabulary size
        """
        return len(self.vocab[namespace])

    def get_all_namespaces(self):
        """Get all namespace names.

        Returns:
            set -- names of all namespaces in the vocabulary
        """
        return set(self.vocab)

    def get_padding_index(self, namespace):
        """Get the padding token index of one namespace.

        Arguments:
            namespace {str} -- namespace name

        Raises:
            RuntimeError: namespace does not exist or has no padding token

        Returns:
            int -- padding index
        """
        if namespace not in self.vocab:
            raise RuntimeError("Namespace {} doesn't exist.".format(namespace))
        if namespace not in self.no_pad_namespace:
            if namespace not in self.contain_pad_namespace:
                return self.vocab[namespace][Vocabulary.DEFAULT_PAD_TOKEN]
            return self.vocab[namespace][self.contain_pad_namespace[namespace]]
        logger.error("Namespace {} doesn't has paddding token.".format(namespace))
        raise RuntimeError("Namespace {} doesn't has paddding token.".format(namespace))

    def get_unknown_index(self, namespace):
        """Get the unknown token index of one namespace.

        Arguments:
            namespace {str} -- namespace name

        Raises:
            RuntimeError: namespace does not exist or has no unknown token

        Returns:
            int -- unknown index
        """
        if namespace not in self.vocab:
            raise RuntimeError("Namespace {} doesn't exist.".format(namespace))
        if namespace not in self.no_unk_namespace:
            if namespace not in self.contain_unk_namespace:
                return self.vocab[namespace][Vocabulary.DEFAULT_UNK_TOKEN]
            return self.vocab[namespace][self.contain_unk_namespace[namespace]]
        logger.error("Namespace {} doesn't has unknown token.".format(namespace))
        raise RuntimeError("Namespace {} doesn't has unknown token.".format(namespace))

    def get_namespace_tokens(self, namesapce):
        """Return the token -> index mapping of one namespace.

        Arguments:
            namesapce {str} -- namespace name (parameter name kept as-is for
                keyword-argument compatibility)

        Returns:
            bidict -- the namespace's live token -> index mapping
        """
        return self.vocab[namesapce]

    def save(self, file_path):
        """Pickle the vocabulary to a file.

        Arguments:
            file_path {str} -- file path
        """
        with open(file_path, 'wb') as fout:
            pickle.dump(self, fout)

    @classmethod
    def load(cls, file_path):
        """Load a pickled vocabulary from a file.

        Arguments:
            file_path {str} -- file path

        Returns:
            Vocabulary -- vocabulary
        """
        # SECURITY: unpickling can execute arbitrary code; only load files
        # from a trusted source.
        with open(file_path, 'rb') as fin:
            return pickle.load(fin, encoding='utf-8')

    def __namespace_init(self, namespace):
        """Create (or reset) a namespace as an empty bidict.

        The default pad token and then the default unk token are added,
        unless the namespace opts out of them or declares its own.

        Arguments:
            namespace {str} -- namespace
        """
        self.vocab[namespace] = bidict()
        if namespace not in self.no_pad_namespace and namespace not in self.contain_pad_namespace:
            self.vocab[namespace][Vocabulary.DEFAULT_PAD_TOKEN] = len(self.vocab[namespace])
        if namespace not in self.no_unk_namespace and namespace not in self.contain_unk_namespace:
            self.vocab[namespace][Vocabulary.DEFAULT_UNK_TOKEN] = len(self.vocab[namespace])