"""Extend a BERT vocabulary with new subwords learned from a corpus.

The vocabulary is grown in steps of `interval` subwords until the relative
improvement of the corpus language-model score falls below `threshold`; the
learned subwords are then merged with the original vocabulary.
"""
from __future__ import absolute_import
from __future__ import division

import argparse
import math
import os
import random
import tempfile

import numpy as np
from transformers import BertTokenizer

from text_encoder import SubwordTextEncoder
import tokenizer


def merge_output_file_with_bert_vocab(output_filename, bert_vocab, temp_path):
    """Write the original BERT vocabulary to `output_filename`, then append
    every subword from `temp_path` that is not already in it."""
    writer = open(output_filename, 'w', encoding='utf-8')
    _set = set()
    with open(bert_vocab, 'r', encoding='utf-8') as reader:
        for line in reader:
            writer.write(line)
            _set.add(line.strip())
    print(temp_path)
    with open(temp_path, 'r', encoding='utf-8') as reader:
        for line in reader:
            if line.strip() not in _set:
                writer.write(line)
    writer.close()
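
# Because the original BERT vocabulary is written first and in its original
# order, existing tokens keep their ids in the merged file; only the newly
# learned subwords are appended after it.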


def build_target_size_vocab(token_counts, reserved_tokens, target_size):
    """Build a SubwordTextEncoder with roughly `target_size` subwords and
    store it in a temporary vocabulary file."""
    # Bounds for the binary search over the minimum subword count.
    min_val = 1
    max_val = len(token_counts) // (target_size ** 0.5)
    encoder = SubwordTextEncoder.build_to_target_size(
        target_size, token_counts, min_val, max_val, num_iterations=5,
        reserved_tokens=reserved_tokens, max_subtoken_length=None)
    fd, temp_vocab = tempfile.mkstemp()
    os.close(fd)  # only the path is needed; close the open descriptor
    encoder.store_to_file(temp_vocab, add_single_quotes=False)
    return encoder, temp_vocab


def compute_language_model(documents, vocab_file):
    """Return the mean per-document log-probability of `documents` under a
    unigram language model built from the vocabulary in `vocab_file`."""
    all_tokens = 0
    tokenized_documents = []
    bert_tokenizer = BertTokenizer(vocab_file, do_lower_case=True)
    # Start with a zero count for every vocabulary entry.
    words = bert_tokenizer.vocab
    for word in words.keys():
        words[word] = 0
    # Tokenize every document and count token occurrences.
    for doc in documents:
        tokens = bert_tokenizer.tokenize(doc)
        all_tokens += len(tokens)
        for token in tokens:
            words[token] += 1
        tokenized_documents.append(tokens)
    # Turn the counts into relative frequencies.
    for word in words.keys():
        words[word] /= all_tokens
    # Score each document by the sum of the log-probabilities of its tokens.
    probs = []
    for doc in tokenized_documents:
        p = 0.0
        for token in doc:
            p += math.log(words[token])
        probs.append(p)
    return np.mean(probs)
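
# Under this unigram model a document D = (t_1, ..., t_n) is scored as
# log P(D) = sum_i log p(t_i), where p(t) is the relative frequency of token t
# in the tokenized corpus. vocab_extend below uses the mean of this score over
# all documents as the quality measure P(D) of a candidate vocabulary.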


def vocab_extend(corpus, raw_vocab, output_filename, interval=10000, threshold=0.01):
    """
    @description : Incrementally extend the vocabulary in `raw_vocab` with
                   subwords learned from `corpus` until the corpus
                   language-model score stops improving noticeably.
    @param       : corpus - path of the corpus file, one document per line
                   raw_vocab - path of the original (BERT) vocabulary
                   output_filename - path of the merged output vocabulary
                   interval - step by which the target vocabulary size grows
                   threshold - stop once the relative improvement of the mean
                               log P(D) is no larger than this value
    @Returns     : None; the merged vocabulary is written to `output_filename`.
    """
    # Read the corpus, one document per line, skipping very short lines.
    documents = []
    with open(corpus, "r", encoding='utf-8') as reader:
        for line in reader:
            line = line.replace('\n', '')
            if len(line) < 5:
                continue
            documents.append(line)
    print("documents: " + str(len(documents)))
    token_counts = tokenizer.corpus_token_counts(
        corpus, corpus_max_lines=4400000,
        split_on_newlines=True, additional_chars="", do_lower_case=True)
    # Every token of the original vocabulary is kept as a reserved token.
    with open(raw_vocab, 'r', encoding='utf-8') as reader:
        lines = reader.readlines()
    lines = [s.strip() for s in lines if len(s) > 0]
    reserved_tokens = lines
    random.shuffle(documents)
    # Round the starting size down to a multiple of `interval`.
    origin_size = (len(reserved_tokens) // interval) * interval
    pre_lm = compute_language_model(documents, raw_vocab)
    print("origin_size: " + str(origin_size))
    print("pre_lm: " + str(pre_lm))
    target_size = origin_size
    while True:
        # Grow the vocabulary by one interval and re-score the corpus.
        target_size = target_size + interval
        _, temp_vocab = build_target_size_vocab(token_counts, reserved_tokens, target_size)
        now_lm = compute_language_model(documents, temp_vocab)
        print('now_lm: ' + str(now_lm))
        # Relative improvement of the mean log-probability.
        delta = (pre_lm - now_lm) / pre_lm
        print('delta: ' + str(delta))
        if delta <= threshold:
            # The improvement is too small: merge the last vocabulary with
            # the original one and stop.
            merge_output_file_with_bert_vocab(output_filename, raw_vocab, temp_vocab)
            break
        pre_lm = now_lm


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--corpus", default=None, type=str, required=True,
                        help="The file of the corpus used to train the vocabulary.")
    parser.add_argument("--raw_vocab", default=None, type=str, required=True,
                        help="The path to the file of the original vocabulary.")
    parser.add_argument("--output_file", default=None, type=str, required=True,
                        help="The output file of the final vocabulary.")
    parser.add_argument('--interval', type=int, default=10000,
                        help="The interval of the vocabulary size.")
    parser.add_argument('--threshold', type=float, default=0.01,
                        help="The final threshold of the P(D)'s relative increase.")
    args = parser.parse_args()
    return args


def main():
    args = get_args()
    vocab_extend(args.corpus, args.raw_vocab, args.output_file, args.interval, args.threshold)


if __name__ == '__main__':
    main()
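
# Example invocation (the script and file names below are placeholders):
#
#   python vocab_extend.py \
#       --corpus corpus.txt \
#       --raw_vocab bert_vocab.txt \
#       --output_file extended_vocab.txt \
#       --interval 10000 \
#       --threshold 0.01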