"""Prepares the datasets for calibration. Original code gently shared by TheBloke""" | |
import os | |
from abc import ABC | |
import time | |
from typing import Dict, List, Optional | |
from datasets import load_dataset, Dataset | |
from transformers import PreTrainedTokenizerBase | |


class CalibrationDataset(ABC):
    tokenizer: Optional[PreTrainedTokenizerBase] = None
    num_samples: int = 128
    seqlen: int = 4096

    dataset_config: dict
    dataset: str
    dataset_name: str

    # Defines the field to extract from the HF dataset.
    # If specified, just this field will be returned, and no transformation will be done.
    dataset_field: Optional[str] = None

    # Default parameters for a dataset which requires a transformation.
    # Only used if dataset_field is None.
    # The fields to extract from the original dataset:
    transform_fields: List[str] = []

    # A format string describing how the fields should be joined.
    # Can use {field1}, {field2}, etc. as placeholders for the field names,
    # or can use the actual field names, e.g. "{input} {output}".
    # See the commented-out example subclass below the built-in datasets.
    transform_join: str = "{field1} {field2}"

    # Optional override for the dataset URL.
    # By default this is automatically derived from the dataset name and config.
    dataset_url: Optional[str] = None

    data: Optional[Dataset] = None
    samples: List[str] = []
    tokenized_samples: List[Dict[str, List[int]]] = []

    randomize: bool = False
    randomize_seed: int = 42

    def __init__(
        self,
        num_samples: int = 128,
        seqlen: int = 4096,
        tokenizer: Optional[PreTrainedTokenizerBase] = None
    ):
        self.num_samples = num_samples
        self.seqlen = seqlen
        self.tokenizer = tokenizer

    @classmethod
    def get_dataset(cls, dataset_name: str, **kwargs) -> "CalibrationDataset":
        """Instantiate the subclass whose `dataset` attribute matches `dataset_name`."""
        for subclass in cls.__subclasses__():
            if hasattr(subclass, "dataset") and subclass.dataset == dataset_name:
                return subclass(**kwargs)
        raise ValueError(f"No dataset class found for name: {dataset_name}")
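
    # Example usage (illustrative; `tokenizer` is any loaded PreTrainedTokenizerBase):
    #   dataset = CalibrationDataset.get_dataset("wikitext", tokenizer=tokenizer)
    #   samples = dataset.tokenize_dataset()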

    def tokenize_dataset(self, samples: Optional[List[str]] = None) -> List[Dict[str, List[int]]]:
        """
        Tokenize the dataset and return a list of samples of `seqlen` length.

        First tokenize the List[str] of samples, as a batch.
        Then flatten the batch, and split it into `num_samples` rows of `seqlen` length.
        """
        if not self.tokenizer:
            raise ValueError("No tokenizer provided to tokenize_dataset()")

        if not samples:
            if not self.samples:
                self.get_samples()
            samples = self.samples

        print(f"Tokenizing {self.dataset_name} of length {len(samples)}")
        start_time = time.time()

        # Tokenize the list of samples. We don't use return_tensors="pt",
        # as that requires the samples to be the same length, or padding to be used.
        tokenized = self.tokenizer(samples)

        # Output of the tokenizer will be:
        # {"input_ids": [[1,2,3], [4,5], [6,7]], "attention_mask": [[1,1,1], [1,1], [1,1]]}
        # Flatten that so as to concatenate the samples into a single input_ids and attention_mask
        flattened = {
            key: [
                item for sublist in value
                for item in sublist
            ]
            for key, value in tokenized.items()
        }
        print(f"Tokenized length: {len(flattened['input_ids'])} tokens.")

        # Slice our single input_ids list into num_samples samples of seqlen length
        tokenized_samples = []
        for i in range(0, self.num_samples * self.seqlen, self.seqlen):
            if i + self.seqlen >= len(flattened["input_ids"]):
                break
            sample = {
                "input_ids": flattened["input_ids"][i:i + self.seqlen],
                "attention_mask": flattened["attention_mask"][i:i + self.seqlen]
            }
            tokenized_samples.append(sample)

        print(
            f"Returning {len(tokenized_samples)} samples of {self.seqlen} length. "
            f"Time taken: {time.time() - start_time:.2f}s."
        )
        self.tokenized_samples = tokenized_samples
        return self.tokenized_samples
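
    # For example, with the defaults num_samples=128 and seqlen=4096, tokenize_dataset()
    # yields up to 128 dicts, each holding 4096 input_ids and 4096 attention_mask values
    # (fewer samples if the source text tokenizes to less than 128 * 4096 tokens).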

    def get_hf_dataset(
        self,
        path: str,
        limit: Optional[int] = None,
        **kwargs
    ) -> Dataset:
        """Load the Hugging Face dataset at `path`, using the provided kwargs.

        The dataset is streamed, shuffled, and truncated to `limit` rows, then
        materialised as a regular Dataset so it can be indexed, iterated and mapped.
        """
        print(f"Loading HF dataset {path} with params: {kwargs}")
        streamed = load_dataset(path=path, streaming=True, **kwargs)
        return Dataset.from_list(list(streamed.shuffle().take(limit)))

    @staticmethod
    def list_with_nls(samples: List[str]) -> List[str]:
        """
        Return a List[str] with each sample ending in a newline.
        Also filters the list by stripping, then removing any empty samples.
        """
        return [
            x.rstrip() + '\n'
            for x in samples
            if x and len(x.strip()) > 0
        ]
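
    # For example, list_with_nls(["foo  ", "", "  bar"]) returns ["foo\n", "  bar\n"]:
    # empty samples are dropped and trailing whitespace is replaced by a single newline.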

    def get_samples(self) -> List[str]:
        """
        Return a list of samples for the dataset.

        If the subclass implements `dataset_field`, this is used to filter the HF dataset.
        Otherwise, the subclass must implement `process_samples()`, for custom filtering.
        Samples are returned as a List[str], each ending in a newline.
        """
        # Load the HF dataset. Subclasses provide HF dataset details in `dataset_config`
        if not self.data:
            self.data = self.get_hf_dataset(**self.dataset_config, limit=self.num_samples * 10)

        if not self.samples:
            if hasattr(self, "dataset_field") and self.dataset_field:
                samples = [data[self.dataset_field] for data in self.data]
            else:
                try:
                    samples = self.process_samples()
                except NotImplementedError:
                    raise ValueError(
                        f"No dataset field specified for class {self.__class__}, "
                        f"and process_samples() method not defined."
                    )
            if self.randomize:
                import random
                random.seed(self.randomize_seed)
                random.shuffle(samples)
            self.samples = self.list_with_nls(samples)

        return self.samples

    def process_samples(self) -> List[str]:
        if not self.transform_fields or not isinstance(self.transform_fields, list):
            raise ValueError("transform_fields must be a List[str], defined in the subclass")

        if not self.transform_join or not isinstance(self.transform_join, str):
            raise ValueError("transform_join must be a str, defined in the subclass")

        def transform_sample(sample):
            field_values = {field: sample[field] for field in self.transform_fields}
            # We support both:
            #   generic numbered fields: "{field1} {field2}"
            #   and named fields: "{input} {output}"
            # Create a combined dictionary to handle both specific field names and generic placeholders
            combined_dict = {
                **field_values,
                **{f"field{i + 1}": value for i, value in enumerate(field_values.values())}
            }
            output = self.transform_join.format_map(combined_dict)
            return {"output": output}

        return self.data.map(transform_sample)["output"]
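
    # Illustration (hypothetical field names): with transform_fields = ["instruction", "response"]
    # and transform_join = "{instruction}\n{response}" (or, equivalently, "{field1}\n{field2}"),
    # the row {"instruction": "Hi", "response": "Hello"} is joined into "Hi\nHello".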

    def generate_checksum(self) -> str:
        # Create a sha256 checksum of the joined samples.
        # Can be used to confirm that code updates haven't changed the output.
        import hashlib
        samples = self.get_samples()
        combined_samples = ''.join(samples)
        checksum = hashlib.sha256(combined_samples.encode()).hexdigest()
        return checksum

    @classmethod
    def get_dataset_url(cls) -> str:
        """Return the Hugging Face dataset URL for this dataset."""
        if hasattr(cls, "dataset_url") and cls.dataset_url:
            return cls.dataset_url
        else:
            return "https://huggingface.co./datasets/{}/viewer/{}".format(
                cls.dataset_config["path"],
                cls.dataset_config.get("name", "")
            )
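
    # For example, with WikitextDataset's config below, the derived URL is
    # "https://huggingface.co./datasets/wikitext/viewer/wikitext-103-raw-v1".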


class WikitextDataset(CalibrationDataset):
    dataset = "wikitext"
    dataset_field = "text"
    dataset_config = {
        "path": "wikitext",
        "name": "wikitext-103-raw-v1",
        "split": "train"
    }
    dataset_name = "Wikitext103 Full"

    def process_samples(self) -> List[str]:
        return [
            "\n" if len(item) == 0 else item
            for item in self.data["text"]
        ]


class C4Dataset(CalibrationDataset):
    dataset = "c4"
    dataset_field = "text"
    dataset_config = {
        "path": "allenai/c4",
        "data_files": {
            "train": [
                "en/c4-train.00000-of-01024.json.gz",
                "en/c4-train.00001-of-01024.json.gz",
                "en/c4-train.00002-of-01024.json.gz",
                "en/c4-train.00003-of-01024.json.gz",
                "en/c4-train.00004-of-01024.json.gz",
                "en/c4-train.00005-of-01024.json.gz",
                "en/c4-train.00006-of-01024.json.gz",
                "en/c4-train.00007-of-01024.json.gz",
                "en/c4-train.00008-of-01024.json.gz",
                "en/c4-train.00009-of-01024.json.gz",
                "en/c4-train.00010-of-01024.json.gz",
                "en/c4-train.00011-of-01024.json.gz",
                "en/c4-train.00012-of-01024.json.gz",
                "en/c4-train.00013-of-01024.json.gz",
                "en/c4-train.00014-of-01024.json.gz",
                "en/c4-train.00015-of-01024.json.gz",
                "en/c4-train.00016-of-01024.json.gz",
                "en/c4-train.00017-of-01024.json.gz",
                "en/c4-train.00018-of-01024.json.gz",
                "en/c4-train.00019-of-01024.json.gz",
            ],
        },
        "split": "train"
    }
    dataset_name = "C4"


class CodeDataset(CalibrationDataset):
    dataset = "code"
    dataset_field = "content"
    dataset_config = {
        "path": "bigcode/the-stack",
        "split": "train"
    }
    dataset_name = "The Stack"


def validate_dataset(dataset_name: str, **kwargs) -> bool:
    """Return True if a CalibrationDataset subclass exists for `dataset_name`."""
    for cls in CalibrationDataset.__subclasses__():
        if hasattr(cls, "dataset") and cls.dataset == dataset_name:
            return True
    return False


# FIXME: a temp function put in for AutoAWQ, pending full refactor where it won't be necessary
def get_dataset_url(dataset_name: str) -> str:
    for cls in CalibrationDataset.__subclasses__():
        if hasattr(cls, "dataset") and cls.dataset == dataset_name:
            return cls.get_dataset_url()
    raise ValueError(f"No dataset class found for name: {dataset_name}")


def get_dataset_name(dataset_name: str) -> str:
    for cls in CalibrationDataset.__subclasses__():
        if hasattr(cls, "dataset") and cls.dataset == dataset_name:
            return cls.dataset_name
    raise ValueError(f"No dataset class found for name: {dataset_name}")


def test_datasets(datasets: Optional[List[str]] = None, checksum_only: bool = False):
    import sys
    from transformers import AutoTokenizer

    try:
        failed = []
        for cls in CalibrationDataset.__subclasses__():
            if not hasattr(cls, "dataset") or not cls.dataset:
                failed.append(cls.__name__)
        if failed:
            print(f"The following classes have no 'dataset' attribute: {failed}")
            sys.exit(-1)
        else:
            print("All classes have a 'dataset' attribute.")

        print("Enumerating CalibrationDataset classes")
        classes = CalibrationDataset.__subclasses__()
        dataset_names = [
            cls.dataset
            for cls in classes
            if cls.dataset and (not datasets or cls.dataset in datasets)
        ]
        print(f"Found {len(classes)} total dataset classes: {[c.dataset for c in classes]}")
        if datasets:
            print(f"Will test {len(dataset_names)} datasets: {dataset_names}")

        print("Starting test: loading Llama-2 tokenizer")
        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", use_fast=True)

        for name in dataset_names:
            print(f"{name} test: loading dataset.")
            dataset = CalibrationDataset.get_dataset(name, tokenizer=tokenizer)
            if not checksum_only:
                print(f"{name} test: running tokenize_dataset.")
                toks = dataset.tokenize_dataset()
                print(f"{name} test: getting dataset_url.")
                url = dataset.get_dataset_url()
                print(f"{name} - randomized? {dataset.randomize}")
                print(
                    f"{name} - result: cls.data: length: {len(dataset.data)}, "
                    f"first row length: {len(dataset.data[0])}, "
                    f"first row data: '{dataset.data[0]}'."
                )
                print(
                    f"{name} - result: cls.samples: length: {len(dataset.samples)}, "
                    f"first row length: {len(dataset.samples[0])}, "
                    f"first row sample: '{dataset.samples[0]}'."
                )
                print(
                    f"{name} - result: tokenize_dataset result: length: {len(toks)}, "
                    f"length of first row input_ids: {len(toks[0]['input_ids'])}."
                )
                print(f"{name} - result: dataset_url: {url}")
            checksum = dataset.generate_checksum()
            print(f"{name} - result: sha256 checksum: {checksum}")
    except KeyboardInterrupt:
        print("Test aborted")
    except Exception as e:
        print(
            f"Received an exception during test. Test failed. "
            f"Exception: {e}"
        )
        raise


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Test calibration datasets")
    parser.add_argument("--datasets", "-d", "-n", nargs="*", type=str, help="Dataset(s) to check; default is all")
    parser.add_argument("--checksum_only", "-co", action="store_true", help="Only output the checksums for the datasets")
    args = parser.parse_args()

    test_datasets(args.datasets, checksum_only=args.checksum_only)
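
# Example invocations (the filename prepare_datasets.py is assumed; use whatever name
# this module is saved under):
#   python prepare_datasets.py --datasets wikitext c4
#   python prepare_datasets.py --checksum_only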