"""Prepares the datasets for calibration. Original code gently shared by TheBloke"""
import hashlib
import itertools
import os
import random
import time
from abc import ABC
from typing import Dict, Iterator, List, Optional, Union

from datasets import Dataset, IterableDataset, load_dataset
from transformers import PreTrainedTokenizerBase


class CalibrationDataset(ABC):
    """Base class for calibration datasets loaded from the Hugging Face Hub.

    Subclasses describe a dataset declaratively through class attributes
    (`dataset`, `dataset_name`, `dataset_config`, and either `dataset_field`
    or `transform_fields`/`transform_join`). An instance loads the data,
    turns it into newline-terminated text samples, and tokenizes those
    samples into at most `num_samples` rows of exactly `seqlen` tokens.
    """

    tokenizer: Optional[PreTrainedTokenizerBase] = None
    num_samples: int = 128
    seqlen: int = 4096
    # Keyword arguments forwarded to `get_hf_dataset` / `load_dataset` (path, name, split, ...)
    dataset_config: dict
    # Short identifier used to look a subclass up, e.g. "c4"
    dataset: str
    # Human-readable name used in log messages
    dataset_name: str
    # Maximum number of rows taken from the streamed HF dataset
    dataset_limit: int = int(1e7)

    # Defines the field to extract from the HF dataset
    # If specified, just this field will be returned, and no transformation will be done.
    dataset_field: Optional[str] = None

    # Define the default parameters for a dataset which requires a transformation
    # Only used if dataset_field is None.
    # The fields to extract from the original dataset
    transform_fields: List[str] = []
    # A format string describing how the fields should be joined
    # Can use {field1}, {field2}, etc. as placeholders for the field names
    # Or can use actual names, eg "{input} {output}"
    transform_join: str = "{field1} {field2}"

    # Optional override for the dataset URL
    # By default this is automatically derived from the dataset name and config
    dataset_url: Optional[str] = None

    # Loaded HF dataset (an IterableDataset, since get_hf_dataset streams)
    data: Optional[Union[Dataset, IterableDataset]] = None
    # Newline-terminated text samples produced by get_samples().
    # These attributes are only ever reassigned, never mutated in place, so the
    # shared class-level defaults are safe.
    samples: List[str] = []
    # Output of tokenize_dataset(): rows of `seqlen` input_ids / attention_mask
    tokenized_samples: List[Dict[str, List[int]]] = []

    # When True, get_samples() shuffles the samples using `randomize_seed`
    randomize: bool = False
    # Seed for the sample shuffle and for the streaming shuffle in get_hf_dataset()
    randomize_seed: int = 42

    def __init__(
        self,
        num_samples: int = 128,
        seqlen: int = 4096,
        tokenizer: Optional[PreTrainedTokenizerBase] = None
    ):
        """Configure how many rows of how many tokens tokenize_dataset() produces."""
        self.num_samples = num_samples
        self.seqlen = seqlen
        self.tokenizer = tokenizer

    @classmethod
    def _iter_subclasses(cls) -> Iterator[type]:
        """Yield every subclass of `cls` (recursively), depth-first in definition order, once each."""
        seen = set()
        stack = list(reversed(cls.__subclasses__()))
        while stack:
            sub = stack.pop()
            if sub in seen:
                continue
            seen.add(sub)
            yield sub
            # Push children reversed so they are popped in definition order.
            stack.extend(reversed(sub.__subclasses__()))

    @classmethod
    def _find_dataset_class(cls, dataset_name: str) -> type:
        """Return the first subclass whose `dataset` equals `dataset_name`.

        Raises:
            ValueError: if no subclass declares that dataset name.
        """
        for sub in cls._iter_subclasses():
            if getattr(sub, "dataset", None) == dataset_name:
                return sub
        raise ValueError(f"No dataset class found for name: {dataset_name}")

    @classmethod
    def get_dataset(cls, dataset_name, **kwargs):
        """Instantiate the subclass registered under `dataset_name`, passing `kwargs` to its constructor.

        Raises:
            ValueError: if no subclass declares that dataset name.
        """
        return cls._find_dataset_class(dataset_name)(**kwargs)

    def tokenize_dataset(self, samples: Optional[List[str]] = None) -> List[Dict[str, List[int]]]:
        """
        Tokenize the dataset and return a list of tokens of `seqlen` length

        First tokenize the List[str] of samples, as a batch.
        Then flatten the batch, and split it into `num_samples` rows of `seqlen` length.

        Args:
            samples: texts to tokenize. When falsy, the dataset's own samples
                (from get_samples()) are used.

        Returns:
            At most `num_samples` dicts, each holding "input_ids" and
            "attention_mask" lists of exactly `seqlen` entries. Fewer rows are
            returned when the samples do not contain enough tokens. The result
            is also stored in `self.tokenized_samples`.

        Raises:
            ValueError: if no tokenizer was provided.
        """
        if not self.tokenizer:
            raise ValueError("No tokenizer provided to tokenize_dataset()")

        if not samples:
            # get_samples() loads and caches self.samples on first use.
            samples = self.get_samples()

        print(f"Tokenizing {self.dataset_name} of length {len(samples)}")
        start_time = time.time()

        # Tokenize the list of samples. We don't use return_tensors="pt",
        # as that requires the samples to be the same length, or padding to be used.
        tokenized = self.tokenizer(samples)

        # Output of tokenizer will be:
        # {"input_ids": [[1,2,3], [4,5], [6,7]], "attention_mask": [[1,1,1], [1,1], [1,1]]}
        # Concatenate the per-sample lists into single input_ids / attention_mask streams.
        input_ids = list(itertools.chain.from_iterable(tokenized["input_ids"]))
        attention_mask = list(itertools.chain.from_iterable(tokenized["attention_mask"]))
        total_tokens = len(input_ids)
        print(f"Tokenized length: {total_tokens} tokens.")

        # Slice the streams into at most num_samples rows of exactly seqlen tokens.
        tokenized_samples = []
        for start in range(0, self.num_samples * self.seqlen, self.seqlen):
            end = start + self.seqlen
            # Stop when a full-length row can no longer be formed. Use `>` (not
            # `>=`) so that a row ending exactly on the last token is kept.
            if end > total_tokens:
                break
            tokenized_samples.append({
                "input_ids": input_ids[start:end],
                "attention_mask": attention_mask[start:end],
            })

        print(
            f"Return {len(tokenized_samples)} samples of {self.seqlen} length. "
            f"Time taken: {time.time() - start_time:.2f}s."
        )
        self.tokenized_samples = tokenized_samples
        return self.tokenized_samples

    def get_hf_dataset(
        self,
        path: str,
        limit: Optional[int] = None,
        *,
        seed: Optional[int] = None,
        **kwargs
    ) -> Union[Dataset, IterableDataset]:
        """Load the Hugging Face dataset at `path` in streaming mode, using the provided kwargs.

        Args:
            path: HF dataset path, forwarded to `load_dataset`.
            limit: when not None, keep only the first `limit` rows after shuffling.
            seed: seed for the streaming shuffle. None gives a random order.
            **kwargs: extra `load_dataset` arguments (e.g. `name`, `split`).

        Returns:
            The shuffled and optionally truncated streamed dataset.
        """
        print(f"Loading HF dataset {path} with params: {kwargs}")
        data = load_dataset(path=path, streaming=True, **kwargs)
        # A fixed seed makes the row order, and hence generate_checksum(), reproducible.
        data = data.shuffle(seed=seed)
        if limit is not None:
            data = data.take(limit)
        return data

    @staticmethod
    def list_with_nls(samples: List[str]) -> List[str]:
        """
        Return a List[str] with each sample ending in a newline.

        Trailing whitespace is stripped before the newline is appended; samples
        that are None/empty or whitespace-only are removed.
        """
        return [x.rstrip() + '\n' for x in samples if x and x.strip()]

    def get_samples(self) -> List[str]:
        """
        Return a list of samples for the dataset.

        If the subclass implements `dataset_field`, this is used to filter the HF Dataset.
        Otherwise, the subclass must implement `process_samples()`, for custom filtering.

        Samples are returned as a List[str], each ending in a newline.
        The result is cached in `self.samples`.

        Raises:
            ValueError: if neither `dataset_field` nor a usable `process_samples()` is available.
        """
        # Load HF dataset. Subclasses provide HF dataset details in `dataset_config`.
        # NOTE(review): with the default dataset_limit (1e7) this may stream a very
        # large number of rows before returning. Consider a smaller per-class limit.
        if self.data is None:
            self.data = self.get_hf_dataset(
                **self.dataset_config, limit=self.dataset_limit, seed=self.randomize_seed
            )

        if not self.samples:
            if self.dataset_field:
                # Iterate rows rather than indexing a column: streamed
                # (iterable) datasets do not support column indexing.
                samples = [row[self.dataset_field] for row in self.data]
            else:
                try:
                    samples = self.process_samples()
                except NotImplementedError as err:
                    raise ValueError(
                        f"No dataset field specified for class {self.__class__}, "
                        f"and process_samples() method not defined."
                    ) from err

            samples = list(samples)
            if self.randomize:
                # Local RNG: same order as random.seed()+random.shuffle(), but does
                # not reset the process-wide random state.
                random.Random(self.randomize_seed).shuffle(samples)

            self.samples = self.list_with_nls(samples)

        return self.samples

    def process_samples(self) -> List[str]:
        """Build one string per row of `self.data` by formatting `transform_fields` with `transform_join`.

        Raises:
            ValueError: if `transform_fields` is not a non-empty list, or
                `transform_join` is not a non-empty str.
        """
        if not self.transform_fields or not isinstance(self.transform_fields, list):
            raise ValueError("transform_fields must be a List[str], defined in the subclass")
        if not self.transform_join or not isinstance(self.transform_join, str):
            raise ValueError("transform_join must be a str defined in the subclass")

        def transform_sample(row) -> str:
            named = {field: row[field] for field in self.transform_fields}
            # We support both:
            #   generic numbered fields: "{field1} {field2}"
            #   and named fields: "{input} {output}"
            numbered = {f"field{i}": value for i, value in enumerate(named.values(), start=1)}
            return self.transform_join.format_map({**named, **numbered})

        # Iterate rows so this works for both map-style and streamed datasets.
        return [transform_sample(row) for row in self.data]

    def generate_checksum(self) -> str:
        """Return the sha256 hex digest of the concatenated samples.

        Can be used to confirm that code updates haven't changed the output.
        """
        combined_samples = ''.join(self.get_samples())
        return hashlib.sha256(combined_samples.encode("utf-8")).hexdigest()

    @classmethod
    def get_dataset_url(cls) -> str:
        """Return the Hugging Face dataset URL for this dataset."""
        if cls.dataset_url:
            return cls.dataset_url
        return "https://huggingface.co/datasets/{}/viewer/{}".format(
            cls.dataset_config["path"], cls.dataset_config.get("name", "")
        )


class WikitextDataset(CalibrationDataset):
    """WikiText-103 (raw) train split, using the "text" column."""

    dataset = "wikitext"
    dataset_field = "text"
    dataset_config = {
        "path": "wikitext",
        "name": "wikitext-103-raw-v1",
        "split": "train"
    }
    dataset_name = "Wikitext103 Full"


class C4Dataset(CalibrationDataset):
    """allenai/c4 train split, using the "text" column."""

    dataset = "c4"
    dataset_field = "text"
    dataset_config = {
        "path": "allenai/c4",
        "split": "train"
    }
    dataset_name = "C4"


class CodeDataset(CalibrationDataset):
    """bigcode/the-stack train split, using the "content" column."""

    dataset = "code"
    dataset_field = "content"
    dataset_config = {
        "path": "bigcode/the-stack",
        "split": "train"
    }
    dataset_name = "The Stack"


def validate_dataset(dataset_name: str, **kwargs) -> bool:
    """Return True if some CalibrationDataset subclass is registered under `dataset_name`."""
    return any(
        getattr(cls, "dataset", None) == dataset_name
        for cls in CalibrationDataset._iter_subclasses()
    )


# FIXME: a temp function put in for AutoAWQ, pending full refactor where it won't be necessary
def get_dataset_url(dataset_name: str):
    """Return the HF URL of the dataset registered under `dataset_name`.

    Raises:
        ValueError: if no dataset class is registered under that name.
    """
    return CalibrationDataset._find_dataset_class(dataset_name).get_dataset_url()


def get_dataset_name(dataset_name: str):
    """Return the human-readable name of the dataset registered under `dataset_name`.

    Raises:
        ValueError: if no dataset class is registered under that name.
    """
    return CalibrationDataset._find_dataset_class(dataset_name).dataset_name


def test_datasets(datasets: Optional[List[str]] = None, checksum_only=False):
    """Smoke-test the registered calibration datasets.

    First checks that every CalibrationDataset subclass declares `dataset` and
    exits with -1 otherwise. Then, for each selected dataset, it loads the
    data with the Llama-2 tokenizer and prints diagnostics plus a sha256
    checksum of the samples.

    Args:
        datasets: dataset identifiers to test. None or empty tests all of them.
        checksum_only: skip tokenization and diagnostics; only print checksums.
    """
    import sys
    from transformers import AutoTokenizer

    try:
        classes = list(CalibrationDataset._iter_subclasses())
        failed = [cls.__name__ for cls in classes if not getattr(cls, "dataset", None)]
        if failed:
            print(f"The following classes have no 'dataset' attribute: {failed}")
            sys.exit(-1)
        print("All classes have 'dataset' attribute.")

        print("Enumerating CalibrationDataset classes")
        dataset_names = [
            cls.dataset
            for cls in classes
            if cls.dataset and (not datasets or cls.dataset in datasets)
        ]
        print(f"Found {len(classes)} total dataset classes: {[c.dataset for c in classes]}")
        if datasets:
            print(f"Will test {len(dataset_names)} datasets: {dataset_names}")

        print("Starting test: loading Llama-2 tokenizer")
        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", use_fast=True)

        for name in dataset_names:
            print(f"{name} test: loading dataset.")
            dataset = CalibrationDataset.get_dataset(name, tokenizer=tokenizer)
            if not checksum_only:
                print(f"{name} test: running tokenize_dataset.")
                toks = dataset.tokenize_dataset()
                print(f"{name} test: getting dataset_url.")
                url = dataset.get_dataset_url()
                print(f"{name} - randomized? {dataset.randomize}")

                try:
                    data_len = len(dataset.data)
                except TypeError:
                    # Streamed (iterable) datasets do not support len().
                    data_len = "unknown (streaming)"
                first_row = next(iter(dataset.data), None)
                first_row_len = len(first_row) if first_row is not None else 0
                print(
                    f"{name} - result: cls.data: length: {data_len}, "
                    f"first row length: {first_row_len}, "
                    f"first row data: '{first_row}'."
                )

                first_sample = dataset.samples[0] if dataset.samples else ""
                print(
                    f"{name} - result: cls.samples: length: {len(dataset.samples)}, "
                    f"first row length: {len(first_sample)}, "
                    f"first row sample: '{first_sample}'."
                )

                first_tok_len = len(toks[0]['input_ids']) if toks else 0
                print(
                    f"{name} - result: tokenize_dataset result: length: {len(toks)}, "
                    f"length first row input_ids: {first_tok_len}."
                )
                print(
                    f"{name} - result: dataset_url: {url}"
                )

            checksum = dataset.generate_checksum()
            print(
                f"{name} - result: sha256 checksum: {checksum}"
            )
    except KeyboardInterrupt:
        print("Test aborted")
    except Exception as e:
        print(
            f"Received an exception during test. Test failed. "
            f"Exception: {e}"
        )
        raise


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Test calibration datasets")
    parser.add_argument("--datasets", "-d", "-n", nargs="*", type=str,
                        help="Dataset(s) to check; default is all")
    parser.add_argument("--checksum_only", "-co", action="store_true",
                        help="Only output the checksums for the datasets")
    args = parser.parse_args()

    test_datasets(args.datasets, checksum_only=args.checksum_only)