from transformers import AutoTokenizer
from datasets import load_dataset

checkpoint = "google/gemma-2b"
data_dir = "dataset_ro_small_v1/"
seq_len = 2048

# load the raw JSON corpus and keep only the text column
raw_datasets = load_dataset("json", data_dir=data_dir, split='train')
raw_datasets = raw_datasets.remove_columns(['url', 'date_download', 'digest', 'length',
                                            'nlines', 'source_domain', 'title', 'cc_segment',
                                            'original_nlines', 'original_length', 'line_ids',
                                            'language', 'language_score'])
raw_datasets = raw_datasets.rename_column('raw_content', 'text')
raw_datasets = raw_datasets.train_test_split(test_size=0.1)
print(raw_datasets)

# load tokenizer from checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.pad_token = tokenizer.eos_token

# tokenize into chunks of at most seq_len tokens; overflowing tokens are
# returned as extra examples instead of being discarded
def tokenize_fn(examples):
    return tokenizer(examples['text'],
                     max_length=seq_len,
                     return_overflowing_tokens=True,
                     truncation=True)

tokenized_datasets = raw_datasets.map(
    tokenize_fn,
    batched=True,
    remove_columns=raw_datasets['train'].column_names
)
tokenized_datasets.save_to_disk(f'tokenized_{data_dir}')
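
# Optional follow-up sketch (not part of the original script): reload the saved
# dataset and sanity-check one example. Assumes the script above has already
# run, so the save path (here 'tokenized_dataset_ro_small_v1/') exists on disk.
from datasets import load_from_disk

tokenized = load_from_disk(f'tokenized_{data_dir}')
print(tokenized)                                   # DatasetDict with 'train' and 'test' splits
sample = tokenized['train'][0]
print(len(sample['input_ids']))                    # at most seq_len (2048) tokens per chunk
print(tokenizer.decode(sample['input_ids'][:50]))  # decode a short prefix to eyeball the text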