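"""Decontaminate a training dataset against benchmark samples via n-gram matching.

The script builds an n-gram index over an evaluation dataset, flags every
training completion that shares an n-gram with a benchmark sample, scores the
overlap with difflib, and pushes a contamination report (and, optionally, a
decontaminated copy of the training data) to the Hugging Face Hub.

Example invocation (the script and dataset names below are placeholders):

    python decontaminate.py \
        --train_dataset org/my-train-data \
        --report_dataset_name org/my-contamination-report \
        --save_decontaminated \
        --decontaminated_dataset_name org/my-train-data-decontaminated
"""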
import argparse
import difflib
import re
import unicodedata
from pathlib import Path
from tqdm.auto import tqdm
from datasets import load_dataset, Dataset
import jieba


def tokenize(text):
    """Tokenize Chinese text using Jieba."""
    return jieba.lcut(text)


def get_ngrams(tokens, n):
    """Generate n-grams from tokens."""
    return set(zip(*[tokens[i:] for i in range(n)]))


def retrieve_ngrams_batch(batch, eval_ngrams, eval_datasets, eval_texts, ngram_len):
    """Find contaminated samples based on n-grams."""
    new_batch = {"completion": [], "ngram": [], "bench_name": [], "bench_text": []}
    for completion in batch["completion"]:
        tokens = tokenize(completion)
        ngrams = get_ngrams(tokens, ngram_len)
        for ngram in ngrams:
            if ngram in eval_ngrams:
                idx = eval_ngrams[ngram]
                new_batch["completion"].append(completion)
                new_batch["ngram"].append(ngram)
                new_batch["bench_name"].append(eval_datasets[idx])
                new_batch["bench_text"].append(eval_texts[idx])
                break
    return new_batch


def diff_strings(string1, string2):
    """Find matching parts between two strings."""
    matcher = difflib.SequenceMatcher(None, string1.lower(), string2.lower(), autojunk=False)
    matching_blocks = matcher.get_matching_blocks()
    matches = []
    for block in matching_blocks:
        start_a, start_b, length = block
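        # Ignore matching blocks of 5 characters or fewer; such short runs
        # are usually incidental overlap rather than contamination.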
        if length > 5:
            match = string1[start_a:start_a + length]
            matches.append(match)
    return matches


def add_match_stats(example):
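    """Compute overlap statistics between a flagged completion and its benchmark match."""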
    gen_text = " ".join(tokenize(example["completion"]))
    bench_text = " ".join(tokenize(example["bench_text"]))
    matching_parts = diff_strings(gen_text, bench_text)
    match = " ".join("".join(matching_parts).split())
    example["diff"] = matching_parts
    example["diff_ratio"] = len(match) / len(bench_text) if len(bench_text) > 0 else 0
    example["diff_length"] = len(match)
    example["longest_diff_part"] = max(matching_parts, key=len, default="")
    example["longest_diff_part_length"] = len(example["longest_diff_part"])
    return example


def main(args):
    # Load the evaluation data and build the n-gram lookup index
    eval_ngrams, eval_datasets, eval_texts = {}, [], []
    eval_data = load_dataset(args.eval_dataset, split="train")
    for example in tqdm(eval_data):
        tokens = tokenize(example["text"])
        ngrams = get_ngrams(tokens, args.ngram_length)
        if ngrams:
            idx = len(eval_texts)
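            # Map every n-gram of this sample to its index; an n-gram shared
            # by several benchmark samples keeps the last index seen.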
            eval_ngrams.update(zip(ngrams, [idx] * len(ngrams)))
            eval_datasets.append(example.get("task_name", "unknown"))
            eval_texts.append(example["text"])

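    # Load the training data: local .json/.csv files are read directly,
    # anything else is treated as a dataset name on the Hub.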
    train_dataset_path = Path(args.train_dataset)
    if train_dataset_path.exists() and train_dataset_path.suffix in [".json", ".csv"]:
        if train_dataset_path.suffix == ".json":
            train_data = Dataset.from_json(args.train_dataset)
        elif train_dataset_path.suffix == ".csv":
            train_data = Dataset.from_csv(args.train_dataset)
    else:
        train_data = load_dataset(args.train_dataset, split="train")

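    # Flag training completions that share an n-gram with any benchmark sample.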
    contamination_report = train_data.map(
        lambda batch: retrieve_ngrams_batch(batch, eval_ngrams, eval_datasets, eval_texts, args.ngram_length),
        batched=True, batch_size=1000, num_proc=args.num_proc, remove_columns=train_data.column_names
    )

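    # Score each flagged completion: how much of the benchmark text it matches.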
    contamination_report = contamination_report.map(add_match_stats, num_proc=args.num_proc)

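    # Push the full, unfiltered report so borderline overlaps stay inspectable.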
    contamination_report.push_to_hub(args.report_dataset_name, private=args.private)

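    # Only samples above the diff-ratio threshold count as contaminated when
    # building the decontaminated dataset.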
    contamination_report = contamination_report.filter(lambda x: x["diff_ratio"] > args.diff_threshold)

    if args.save_decontaminated:
        contaminated_completions = set(contamination_report["completion"])
        filtered_data = train_data.filter(lambda x: x["completion"] not in contaminated_completions)
        filtered_data.push_to_hub(args.decontaminated_dataset_name, private=args.private)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate a decontamination report for a dataset.")
    parser.add_argument("--eval_dataset", type=str,
                        default="HuggingFaceTB/phi2_eval_data_for_decontamination",
                        help="Name of the dataset with benchmark samples to use for decontamination.")
    parser.add_argument("--train_dataset", type=str, required=True,
                        help="Path or name of the training dataset to process.")
    parser.add_argument("--report_dataset_name", type=str, required=True,
                        help="Name for the output dataset with decontamination report.")
    parser.add_argument("--decontaminated_dataset_name", type=str, help="Name for the decontaminated dataset.")
    parser.add_argument("--private", action='store_true', help="Whether to make the output dataset private.")
    parser.add_argument("--ngram_length", type=int, default=10, help="Length of the n-grams to consider.")
    parser.add_argument("--diff_threshold", type=float, default=0.5,
                        help="Threshold for filtering based on difference ratio.")
    parser.add_argument("--num_proc", type=int, default=90, help="Number of processes to use for map operations.")
    parser.add_argument("--save_decontaminated", action='store_true',
                        help="Whether to save the decontaminated dataset.")

    args = parser.parse_args()
    main(args)