# scripts/prepare_contrain_dataset.py
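"""Prepare the "contrain" (continued-training) instruct dataset for
tangled-llama-33m-32k-instruct-v0.1.

Each source dataset is rendered into ChatML-style text, tokenized with the
litgpt tokenizer, and written as litdata-optimized token chunks to ../data/.
"""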
import gc
from functools import partial

from datasets import load_dataset
from litdata import optimize, TokensLoader
from litgpt.tokenizer import Tokenizer

def batch_iterator(name=None):
    # Replete-AI/Everything_Instruct_Multilingual: Alpaca-style rows with
    # 'instruction', 'input', and 'output' fields, rendered as ChatML turns
    # (instruction -> system, input -> user, output -> assistant).
    if name in (None, 'Replete-AI/Everything_Instruct_Multilingual'):
        dataset = load_dataset('Replete-AI/Everything_Instruct_Multilingual', split='train')

        for row in dataset:
            text = []

            if row['instruction']:
                text.append(
                    '<|im_start|>system\n'
                    f"{row['instruction']}<|im_end|>"
                )

            if row['input']:
                text.append(
                    '<|im_start|>user\n'
                    f"{row['input']}<|im_end|>"
                )

            if row['output']:
                text.append(
                    '<|im_start|>assistant\n'
                    f"{row['output']}<|im_end|>"
                )

            text = '\n'.join(text) + '\n'
            yield text

        del dataset
        gc.collect()
    # HuggingFaceH4/ultrachat_200k: multi-turn chats stored under 'messages'.
    if name in (None, 'HuggingFaceH4/ultrachat_200k'):
        dataset = load_dataset('HuggingFaceH4/ultrachat_200k', split='train_sft')

        for row in dataset:
            text = [
                f"<|im_start|>{n['role']}\n{n['content']}<|im_end|>"
                for n in row['messages']
            ]
            text = '\n'.join(text) + '\n'
            yield text

        del dataset
        gc.collect()
    # HuggingFaceH4/no_robots: same 'messages' layout as ultrachat_200k.
    if name in (None, 'HuggingFaceH4/no_robots'):
        dataset = load_dataset('HuggingFaceH4/no_robots', split='train')

        for row in dataset:
            text = [
                f"<|im_start|>{n['role']}\n{n['content']}<|im_end|>"
                for n in row['messages']
            ]
            text = '\n'.join(text) + '\n'
            yield text

        del dataset
        gc.collect()
    # datatab/ultrafeedback_binarized_serbian: only the 'chosen' conversation
    # of each preference pair is kept.
    if name in (None, 'datatab/ultrafeedback_binarized_serbian'):
        dataset = load_dataset('datatab/ultrafeedback_binarized_serbian', split='train_sft')

        for row in dataset:
            text = [
                f"<|im_start|>{n['role']}\n{n['content']}<|im_end|>"
                for n in row['chosen']
            ]
            text = '\n'.join(text) + '\n'
            yield text

        del dataset
        gc.collect()
    # datatab/alpaca-cleaned-serbian-full: same Alpaca-style layout as the
    # first block above.
    if name in (None, 'datatab/alpaca-cleaned-serbian-full'):
        dataset = load_dataset('datatab/alpaca-cleaned-serbian-full', split='train')

        for row in dataset:
            text = []

            if row['instruction']:
                text.append(
                    '<|im_start|>system\n'
                    f"{row['instruction']}<|im_end|>"
                )

            if row['input']:
                text.append(
                    '<|im_start|>user\n'
                    f"{row['input']}<|im_end|>"
                )

            if row['output']:
                text.append(
                    '<|im_start|>assistant\n'
                    f"{row['output']}<|im_end|>"
                )

            text = '\n'.join(text) + '\n'
            yield text

        del dataset
        gc.collect()
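
# For reference, an Alpaca-style row with all three fields produces one
# document of the form:
#
#   <|im_start|>system
#   {instruction}<|im_end|>
#   <|im_start|>user
#   {input}<|im_end|>
#   <|im_start|>assistant
#   {output}<|im_end|>
#
# followed by a trailing newline; chat-style rows produce one such block per
# message, using each message's own role.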

def tokenize_fn(dataset_name, tokenizer=None):
    for text in batch_iterator(dataset_name):
        # Append an EOS token (no BOS) so document boundaries survive packing.
        text_ids = tokenizer.encode(text, bos=False, eos=True)
        yield text_ids

datasets_names = [
    'Replete-AI/Everything_Instruct_Multilingual',
    'HuggingFaceH4/ultrachat_200k',
    'HuggingFaceH4/no_robots',
    'datatab/ultrafeedback_binarized_serbian',
    'datatab/alpaca-cleaned-serbian-full',
]
outputs = optimize(
    # The tokenizer is loaded from '..' (the directory one level up, which is
    # expected to hold the tokenizer files) and bound into the worker fn.
    fn=partial(tokenize_fn, tokenizer=Tokenizer('..')),
    # Each dataset name is handed to a worker and tokenized independently.
    inputs=datasets_names,
    output_dir='../data/',
    # Tokens per chunk: (32768 + 1) tokens per block (context length plus one
    # token for the shifted target) * 500 blocks ~= 16.4M tokens, which is
    # roughly 64 MB per chunk when ids are stored as 4-byte integers.
    chunk_size=((32768 + 1) * 500),
    num_workers=16,
)
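
# The TokensLoader imported above is not used when writing; it is the litdata
# item loader typically used to stream these chunks back during training.
# A minimal sketch of such a reader (illustrative only, assuming litdata's
# StreamingDataset API; nothing in this script calls it):
def _example_streaming_dataset(data_dir: str = '../data/'):
    from litdata import StreamingDataset

    return StreamingDataset(
        input_dir=data_dir,
        item_loader=TokensLoader(block_size=32768 + 1),
        shuffle=True,
        drop_last=True,
    )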