import argparse
import logging
import os
import re
from itertools import repeat
from multiprocessing import Pool, cpu_count
from pathlib import Path

import ftfy
import tensorflow as tf
from lm_dataformat import Reader
from tokenizers import Tokenizer
from tqdm import tqdm
from transformers import GPT2TokenizerFast

logging.getLogger("transformers").setLevel(logging.ERROR)

parser = argparse.ArgumentParser()
parser.add_argument("--input_dir", type=str, help="Path to where your files are located. Files ending in .zst are "
                                                  "treated as archives, all others as raw text.")
parser.add_argument("--files_per", type=int, default=100000, help="Text files per tfrecord")
parser.add_argument("--name", type=str, default="openwebtext",
                    help="Name of output files will be name_i.tfrecords where i is the number of the file")
parser.add_argument("--output_dir", type=str, default="./tfrecords", help="Where to put tfrecords")
parser.add_argument("--encoder_path", type=str,
                    help="Path to encoder files, or leave unspecified to use the GPT2 tokenizer")
parser.add_argument("--minimum_size", type=int, default=100,
                    help="Minimum size (in tokens) a document has to be to be included")
parser.add_argument("--ftfy", action="store_true", help="Normalize text with ftfy")
parser.add_argument("--wikitext-detokenize", action="store_true", help="Use the wikitext detokenizer")
parser.add_argument("--separator", nargs="+", type=int, default=[50256],
                    help="Separator token to place between files in chunk mode")
parser.add_argument("--chunk_size", type=int, default=2048, help="How big a chunk should be in chunk mode. "
                                                                 "Should equal your model's context size")
parser.add_argument("--write_dataset_config", action="store_true", help="Write the dataset config file on completion")
parser.add_argument("--processes", type=int, default=0, help="Number of processes to use. Defaults to cpu count.")

args = parser.parse_args()
if not args.output_dir.endswith("/"):
    args.output_dir = args.output_dir + "/"
if not args.input_dir.endswith("/"):
    args.input_dir = args.input_dir + "/"
assert len(args.separator) == 1, "separator must be a single token id"
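
# Example invocation (illustrative only; the script filename and the paths below
# are assumptions, not taken from this file):
#
#   python create_tfrecords.py --input_dir ./openwebtext --output_dir ./tfrecords \
#       --name openwebtext --files_per 100000 --chunk_size 2048 --ftfy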


def wikitext_detokenizer(string):
    """
    Reverses the extra spacing that the WikiText datasets insert around
    punctuation, brackets and special tokens.
    """
    # contractions
    string = string.replace("s '", "s'")
    string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)

    # number separators
    string = string.replace(" @-@ ", "-")
    string = string.replace(" @,@ ", ",")
    string = string.replace(" @.@ ", ".")

    # punctuation
    string = string.replace(" : ", ": ")
    string = string.replace(" ; ", "; ")
    string = string.replace(" . ", ". ")
    string = string.replace(" ! ", "! ")
    string = string.replace(" ? ", "? ")
    string = string.replace(" , ", ", ")

    # brackets and quotes
    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)

    # miscellaneous
    string = string.replace("= = = =", "====")
    string = string.replace("= = =", "===")
    string = string.replace("= =", "==")
    string = string.replace(" " + chr(176) + " ", chr(176))
    string = string.replace(" \n", "\n")
    string = string.replace("\n ", "\n")
    string = string.replace(" N ", " 1 ")
    string = string.replace(" 's", "'s")

    return string
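
# For example (illustrative): wikitext_detokenizer("anti @-@ war , he said")
# returns "anti-war, he said".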


def _int64_feature(value):
    """
    Returns an int64_list from a bool / enum / int / uint.
    """
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def write_to_file(writer, data):
    """
    Writes one tokenized document to a tfrecord file as an int64 "text" feature.
    """
    feature = {
        "text": _int64_feature(data)
    }
    tf_example = tf.train.Example(features=tf.train.Features(feature=feature))
    writer.write(tf_example.SerializeToString())
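
# A minimal sketch (assumed, not part of the original pipeline) of how records
# written by write_to_file could be read back with tf.data. The feature name
# "text" and the int64 type match the writer above; the filename is illustrative.
#
#   def _parse_example(example_proto):
#       parsed = tf.io.parse_single_example(
#           example_proto, {"text": tf.io.VarLenFeature(tf.int64)})
#       return tf.sparse.to_dense(parsed["text"])
#
#   dataset = tf.data.TFRecordDataset("openwebtext_0_100000.tfrecords").map(_parse_example)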


def get_tokenizer(args):
    """
    Returns the tokenizer: a tokenizers.Tokenizer loaded from args.encoder_path
    if one was given, otherwise the pretrained GPT2 tokenizer.
    """
    if args.encoder_path is None:
        return GPT2TokenizerFast.from_pretrained('gpt2')
    else:
        return Tokenizer.from_file(args.encoder_path)


def split_list(l, n):
    """
    Splits list l into chunks of length n; the final chunk may be shorter.
    """
    return [l[i:i + n] for i in range(0, len(l), n)]
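
# e.g. split_list([1, 2, 3, 4, 5], 2) == [[1, 2], [3, 4], [5]]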


def archive_to_tokens(f, encoder, args):
    """
    Reads documents from an lm_dataformat archive, optionally cleans them with
    ftfy and/or the wikitext detokenizer, tokenizes them, appends the separator
    token and yields each document split into chunks of args.chunk_size tokens.
    """
    reader = Reader(f)
    for doc in reader.stream_data(threaded=False):
        if args.ftfy:
            doc = ftfy.fix_text(doc, normalization='NFKC')
        if args.wikitext_detokenize:
            doc = wikitext_detokenizer(doc)
        doc = encoder.encode(doc) + args.separator  # append the separator token to every document
        yield split_list(doc, args.chunk_size)  # split into chunk_size sized chunks


def write_files(files, files_per, output_dir, out_name, start_no, write_remainder=False, process_no=None):
    """
    Writes a list of tokenized chunks to tfrecord files named
    {output_dir}/{out_name}_{start_no}[_{process_no}]_{files_per}.tfrecords,
    with files_per chunks per file. Returns the next start_no and any leftover
    chunks that were not written.
    """
    if files is None:
        return
    chunks = split_list(files, files_per)

    if len(chunks[-1]) != files_per and not write_remainder:
        # hand the undersized last chunk back to the caller
        remainder = chunks.pop(-1)
    else:
        remainder = None
        # when writing the remainder, the last chunk may be smaller than files_per
        files_per = len(chunks[-1])

    for files in chunks:
        fp = f"{output_dir}/{out_name}_{start_no}"
        if process_no is not None:
            fp += f"_{process_no}"
        fp += f"_{files_per}"  # record how many chunks are in this tfrecord
        fp += ".tfrecords"
        with tf.io.TFRecordWriter(fp) as writer:
            for f in files:
                write_to_file(writer, f)
        start_no += 1
    return start_no, remainder


def get_files(input_dir, filetypes=None):
    """
    Returns the paths of all files in input_dir matching the given filetypes
    (by default: .jsonl.zst, .txt, .xz and .tar.gz).
    """
    if filetypes is None:
        filetypes = ["jsonl.zst", ".txt", ".xz", ".tar.gz"]
    files = [list(Path(input_dir).glob(f"*{ft}")) for ft in filetypes]
    # flatten the per-filetype lists into a single list of path strings
    return [str(item) for sublist in files for item in sublist]


def read_checkpoint(checkpoint_path, resume_from_checkpoint=True):
    """
    Reads "{files_processed}, {tfrecord_count}" from the checkpoint file if it
    exists and resuming is enabled; otherwise returns (0, 0).
    """
    if resume_from_checkpoint and os.path.isfile(checkpoint_path):
        try:
            with open(checkpoint_path, "r") as checkpoint_file:
                resume_files_processed, tfrecord_count = [int(i) for i in checkpoint_file.read().split(", ")]
            print(f"\nResuming from tfrecord no. {tfrecord_count} / file no. {resume_files_processed}")
            return resume_files_processed, tfrecord_count
        except (ValueError, OSError):
            # malformed or unreadable checkpoint: start from scratch
            pass
    return 0, 0


def create_tfrecords(params, write_remainder=True, write_every_n_files=1, save_checkpoints=False,
                     resume_from_checkpoint=False, display_pbar=False):
    """
    Tokenizes every input file in params and writes the resulting fixed-size
    chunks to tfrecords. Returns a dict with counts of discarded, processed and
    successfully written documents.
    """
    files, args, process_no = params
    enc = get_tokenizer(args)

    discarded_files = 0
    files_processed = 0
    pbar = tqdm(desc=f"Writing TFRecord Files to {args.output_dir}. Parsed 0 input files. files_written ",
                disable=not display_pbar)
    checkpoint_path = f"{args.output_dir}/checkpoint.txt"
    resume_files_processed, tfrecord_count = read_checkpoint(checkpoint_path, resume_from_checkpoint)

    data_to_prepend = []
    tokenized_files_array = []

    for f in files:
        for tokenized_files in archive_to_tokens(f, enc, args):
            files_processed += 1
            if files_processed < resume_files_processed:
                continue  # already written to a tfrecord before the checkpoint

            # if the document's last chunk is shorter than chunk_size, either
            # carry it over into the next document or discard it entirely
            n_tokens = len(tokenized_files[-1])
            if n_tokens < args.chunk_size:
                data = tokenized_files.pop(-1)
                if n_tokens >= args.minimum_size:
                    data_to_prepend.extend(data)
                else:
                    discarded_files += 1

            if len(data_to_prepend) >= args.chunk_size:
                # carried-over data has reached a full chunk, queue it for writing
                tokenized_files_array.append(data_to_prepend[:args.chunk_size])
                data_to_prepend = data_to_prepend[args.chunk_size:]

            tokenized_files_array.extend(tokenized_files)

            if len(tokenized_files_array) >= args.files_per * write_every_n_files:
                # enough chunks accumulated, flush them to tfrecords
                _tfrecord_count, remainder = write_files(tokenized_files_array, files_per=args.files_per,
                                                         output_dir=args.output_dir, out_name=args.name,
                                                         start_no=tfrecord_count, process_no=process_no)
                pbar.update(_tfrecord_count - tfrecord_count)
                pbar.set_description(
                    f"Writing TFRecord Files to {args.output_dir}. Parsed {files_processed} input files. files_written ")
                tfrecord_count = _tfrecord_count
                tokenized_files_array = remainder if remainder is not None else []
                with open(checkpoint_path, "w") as checkpoint_file:
                    checkpoint_file.write(f"{files_processed}, {tfrecord_count}")

    if len(tokenized_files_array) >= args.files_per:
        # flush any remaining full tfrecords accumulated after the loop
        _tfrecord_count, remainder = write_files(tokenized_files_array, files_per=args.files_per,
                                                 output_dir=args.output_dir, out_name=args.name,
                                                 start_no=tfrecord_count, process_no=process_no)
        pbar.update(_tfrecord_count - tfrecord_count)
        pbar.set_description(
            f"Writing TFRecord Files to {args.output_dir}. Parsed {files_processed} input files. files_written ")
        tfrecord_count = _tfrecord_count
        with open(checkpoint_path, "w") as checkpoint_file:
            checkpoint_file.write(f"{files_processed}, {tfrecord_count}")
    else:
        remainder = tokenized_files_array  # not enough for a full tfrecord, leave for the remainder write

    if write_remainder:
        # write out whatever is left into one final, possibly smaller, tfrecord
        write_files(remainder, files_per=args.files_per, output_dir=args.output_dir, out_name=args.name,
                    start_no=tfrecord_count, write_remainder=True)

    successful_files = files_processed - discarded_files
    return {"discarded": discarded_files, "processed": files_processed, "successful": successful_files}


def create_tfrecords_mp(files, args):
    """
    Splits the input files across args.processes workers and runs
    create_tfrecords in each, aggregating the per-process counts.
    """
    # chunk size of at least 1, in case there are fewer files than processes
    files = split_list(files, max(1, len(files) // args.processes))
    with Pool(processes=args.processes) as pool:
        pbar = tqdm(pool.imap(create_tfrecords, zip(files, repeat(args), range(len(files)))))
        meta = {"discarded": 0, "processed": 0, "successful": 0}
        for results in pbar:
            # iterating the tqdm-wrapped iterator already advances the bar
            for k, v in results.items():
                meta[k] += v
        return meta


if __name__ == "__main__":
    os.makedirs(args.output_dir, exist_ok=True)
    files = get_files(args.input_dir)
    # chunks are one token longer than the model context so that inputs and
    # targets can be offset by one token at training time
    args.chunk_size += 1

    if args.processes == 0:
        args.processes = cpu_count()
    if args.processes > 1:
        results = create_tfrecords_mp(files, args)
    else:
        results = create_tfrecords((files, args, 0), display_pbar=True)
    print(results)