import gc
from typing import Union, Optional, Callable, Iterator

from datasets import load_dataset
from litgpt.tokenizer import Tokenizer


def _batch_text_iterator(path: str,
                         name: Optional[str]=None,
                         data_dir: Optional[str]=None,
                         data_files: Optional[str]=None,
                         keep_in_memory: bool=False,
                         revision: Optional[str]=None,
                         split: str='train',
                         num_proc: Optional[int]=None,
                         format: Optional[Union[str, Callable]]=None) -> Iterator[str]:
    # `format` is either a str template expanded with the row's fields
    # or a callable mapping a row dict to a string.
    assert isinstance(format, str) or callable(format), repr(format)

    dataset = load_dataset(path=path,
                           name=name,
                           data_dir=data_dir,
                           data_files=data_files,
                           keep_in_memory=keep_in_memory,
                           revision=revision,
                           split=split,
                           trust_remote_code=True,
                           num_proc=num_proc)

    if callable(format):
        for row in dataset:
            yield format(row)
    else:
        for row in dataset:
            yield format.format(**row)

    # Release the dataset once it has been fully iterated.
    del dataset
    gc.collect()


def batch_text_iterator(dataset_config: Union[list, dict]) -> Iterator[str]:
    # Accept a single dataset config (dict) or a list of configs; iterate them in order.
    assert isinstance(dataset_config, (dict, list)), dataset_config

    if isinstance(dataset_config, dict):
        yield from _batch_text_iterator(**dataset_config)
    else:
        for dc in dataset_config:
            yield from _batch_text_iterator(**dc)


def tokenize_text_fn(dataset_config: Union[list, dict],
                     tokenizer: Tokenizer,
                     min_len: Optional[int]=None,
                     max_len: Optional[int]=None) -> Iterator[list[int]]:
    # Tokenize each text and keep only those whose token count lies within
    # [min_len, max_len]; a bound left as None is not enforced.
    for text in batch_text_iterator(dataset_config):
        text_ids: list[int] = tokenizer.encode(text, bos=False, eos=True)
        n = len(text_ids)

        if (min_len is None or n >= min_len) and (max_len is None or n <= max_len):
            yield text_ids
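

# --- Usage sketch (illustrative only) ---
# A minimal example of how these iterators could be driven. The dataset paths,
# format templates, and tokenizer checkpoint directory below are hypothetical
# placeholders, not values used by this module.
if __name__ == '__main__':
    example_config = [
        # str format: expanded with the row's fields via str.format(**row)
        {'path': 'wikimedia/wikipedia', 'name': '20231101.en', 'format': '{title}\n\n{text}'},
        # callable format: maps a row dict to a string
        {'path': 'allenai/c4', 'name': 'en', 'split': 'train', 'format': lambda row: row['text']},
    ]

    tokenizer = Tokenizer('checkpoints/some-model')  # hypothetical checkpoint dir
    for text_ids in tokenize_text_fn(example_config, tokenizer, min_len=64, max_len=4096):
        print(len(text_ids))
        break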