from transformers import AutoTokenizer
from datasets import load_dataset

# Base model whose tokenizer is reused, and the directory holding the raw
# Romanian JSON shards.
checkpoint = "google/gemma-2b"
data_dir = "dataset_ro_small_v1/"

# Maximum number of tokens per tokenized chunk.
seq_len = 2048

# Load the raw JSON dataset, drop the crawl metadata columns so only the
# text remains, and carve out a 10% test split.
raw_datasets = load_dataset("json", data_dir=data_dir, split='train')
raw_datasets = raw_datasets.remove_columns(['url', 'date_download', 'digest',
                                            'length', 'nlines', 'source_domain',
                                            'title', 'cc_segment',
                                            'original_nlines',
                                            'original_length', 'line_ids',
                                            'language', 'language_score'])
raw_datasets = raw_datasets.rename_column('raw_content', 'text')
raw_datasets = raw_datasets.train_test_split(test_size=0.1)

print(raw_datasets)

tokenizer = AutoTokenizer.from_pretrained(checkpoint)


def tokenize_fn(examples):
    # Tokenize each document; with return_overflowing_tokens=True, text longer
    # than seq_len is split into multiple chunks instead of being thrown away.
    return tokenizer(examples['text'],
                     max_length=seq_len,
                     return_overflowing_tokens=True,
                     truncation=True)


# Use the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Tokenize both splits, dropping the original columns so only the tokenizer
# outputs are kept.
tokenized_datasets = raw_datasets.map(
    tokenize_fn,
    batched=True,
    remove_columns=raw_datasets['train'].column_names
)

tokenized_datasets.save_to_disk(f'tokenized_{data_dir}')
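
# A reload sanity check (a sketch, not part of the original script): confirm
# that the tokenized splits written by save_to_disk can be read back with
# datasets.load_from_disk before launching training.
from datasets import load_from_disk

reloaded = load_from_disk(f'tokenized_{data_dir}')
print(reloaded)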