File size: 5,547 Bytes
ee6e328
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
from collections import Counter
import datasets
import transformers
from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS

from transformers.utils import logging

# Raise the transformers logging verbosity to INFO for the whole run.
logging.set_verbosity_info()

# One (slow class, fast class) pair per tokenizer name that has a slow->fast
# converter; both classes are looked up on the top-level ``transformers`` module,
# the fast one under the name with a "Fast" suffix.
TOKENIZER_CLASSES = dict(
    (converter_name, (getattr(transformers, converter_name), getattr(transformers, converter_name + "Fast")))
    for converter_name in SLOW_TO_FAST_CONVERTERS
)

# Evaluation corpus: the XNLI test and validation splits concatenated.
# NOTE(review): test_tokenizer expects rows whose "premise" is a mapping of
# texts and whose "hypothesis" has a "translation" list (the multilingual
# layout); confirm the default config loaded here provides that schema.
dataset = datasets.load_dataset("xnli", split="test+validation")

# Running counters: updated by test_string, reset before each checkpoint.
total = perfect = imperfect = wrong = 0
wrong = 0


def check_diff(spm_diff, tok_diff, slow, fast):
    """Tell whether a differing window of token ids is a tolerated discrepancy.

    Args:
        spm_diff: the differing ids produced by the slow tokenizer.
        tok_diff: the differing ids produced by the fast tokenizer.
        slow: slow tokenizer; its ``encode``/``decode`` are used in the last check.
        fast: fast tokenizer; its ``encode``/``decode`` are used.

    Returns:
        True when the difference matches one of three known-benign patterns,
        False otherwise.
    """
    # Pattern 1: the same pieces in mirrored order (AAA -> AA+A vs A+AA).
    mirrored = [*reversed(tok_diff)]
    if spm_diff == mirrored:
        return True

    # Pattern 2 ("second order OK"): equal token counts and identical decoded
    # text, just split at different points (Barrich -> Barr + ich vs Bar + rich).
    # Decoding only happens when the lengths agree.
    if len(spm_diff) == len(tok_diff) and fast.decode(spm_diff) == fast.decode(tok_diff):
        return True

    # Pattern 3 ("type 3 error"): re-encoding the slow output's text with the
    # slow tokenizer does not even give back the slow ids, yet it agrees with
    # what the fast tokenizer produces for that same text, e.g.
    #   Snehagatha -> Sne, h, aga, th, a   vs   Sne, ha, gat, ha
    text = slow.decode(spm_diff)
    slow_roundtrip = slow.encode(text)
    fast_roundtrip = fast.encode(fast.decode(spm_diff))
    return bool(slow_roundtrip != spm_diff and slow_roundtrip == fast_roundtrip)


def check_LTR_mark(line, idx, fast):
    """Return True if the fast token at ``idx`` or at ``idx - 1`` is a U+200F mark.

    Despite the function name, the character checked is U+200F (RIGHT-TO-LEFT
    MARK). Such invisible marks are a known source of slow/fast mismatches, so a
    difference next to one is tolerated.

    Args:
        line: the original text that was tokenized.
        idx: index of the first differing token.
        fast: fast tokenizer; ``encode_plus(line)[0].offsets`` gives the
            (start, end) character span of each token in ``line``.

    Returns:
        True if either inspected token spans exactly "\u200f", else False.
    """
    enc = fast.encode_plus(line)[0]
    offsets = enc.offsets
    for pos in (idx, idx - 1):
        # Skip positions outside the sequence: a negative index (idx == 0)
        # would otherwise wrap around to the last token and yield a false
        # positive, and an index past the end would raise IndexError.
        if not 0 <= pos < len(offsets):
            continue
        span = offsets[pos]
        if span is not None and line[span[0] : span[1]] == "\u200f":
            return True
    return False


def _diff_window(spm_ids, tok_ids):
    """Return ``(first, spm_last, tok_last)`` bounding where two id lists differ.

    ``first`` is the length of the common prefix. ``spm_last`` and ``tok_last``
    are the exclusive ends of the differing region in each list after removing
    the common suffix. The suffix is capped so it never overlaps the prefix, so
    identical (or empty) lists give an empty window.
    """
    first = 0
    for spm_id, tok_id in zip(spm_ids, tok_ids):
        if spm_id != tok_id:
            break
        first += 1

    max_suffix = min(len(spm_ids), len(tok_ids)) - first
    suffix = 0
    for spm_id, tok_id in zip(reversed(spm_ids), reversed(tok_ids)):
        if suffix == max_suffix or spm_id != tok_id:
            break
        suffix += 1

    # Each list keeps its own end index: the lengths may differ.
    return first, len(spm_ids) - suffix, len(tok_ids) - suffix


def check_details(line, spm_ids, tok_ids, slow, fast):
    """Decide whether differing slow/fast encodings of ``line`` are tolerable.

    Encodings can legitimately differ (e.g. AAA -> A + AA vs AA + A). The
    differing window is checked with ``check_diff`` and ``check_LTR_mark``.
    If the window is long (more than 5 slow tokens), it is split around a
    3-token run that both encodings share, and each side is checked
    recursively.

    Args:
        line: the text that was encoded.
        spm_ids: ids from the slow tokenizer.
        tok_ids: ids from the fast tokenizer.
        slow: slow tokenizer.
        fast: fast tokenizer (also used to decode ids for the diagnostic print).

    Returns:
        True if the difference is tolerated. Otherwise prints the differing
        pieces and returns False.
    """
    first, spm_last, tok_last = _diff_window(spm_ids, tok_ids)
    spm_diff = spm_ids[first:spm_last]
    tok_diff = tok_ids[first:tok_last]

    if check_diff(spm_diff, tok_diff, slow, fast):
        return True

    # Only probe for a direction mark when there is a fast token at ``first``.
    # NOTE(review): in recursive calls ``first`` is relative to the sliced ids,
    # while check_LTR_mark computes offsets on the full ``line``.
    if first < len(tok_ids) and check_LTR_mark(line, first, fast):
        return True

    if spm_last - first > 5:
        # The window may contain two independent problems: split it at a run
        # of tokens that appear equally often in both encodings.
        spms = Counter(spm_diff)
        toks = Counter(tok_diff)

        removable_tokens = {token for (token, count) in spms.items() if toks.get(token, 0) == count}
        min_width = 3
        for i in range(spm_last - first - min_width):
            window = spm_ids[first + i : first + i + min_width]
            if all(token in removable_tokens for token in window):
                possible_matches = [
                    k
                    for k in range(tok_last - first - min_width)
                    if tok_ids[first + k : first + k + min_width] == window
                ]
                for j in possible_matches:
                    if check_diff(spm_ids[first : first + i], tok_ids[first : first + j], slow, fast) and check_details(
                        line,
                        spm_ids[first + i : spm_last],
                        tok_ids[first + j : tok_last],
                        slow,
                        fast,
                    ):
                        return True

    # Report the unexplained window; each list is sliced by its own bounds, so
    # no index can run past the end of the shorter encoding.
    print(f"Spm: {[fast.decode([token]) for token in spm_diff]}")
    print(f"Tok: {[fast.decode([token]) for token in tok_diff]}")
    print()
    print(fast.decode(spm_diff))
    return False


def test_string(slow, fast, text):
    """Compare slow and fast encodings of ``text`` and update the global counters.

    Increments ``total`` and exactly one of ``perfect`` (identical ids),
    ``imperfect`` (ids differ but ``check_details`` tolerates it) or ``wrong``.
    Every 10000 strings, prints a progress line. A ``wrong`` result is then
    asserted on, so it raises AssertionError with both id lists and both
    token lists.
    """
    global perfect, imperfect, wrong, total

    slow_ids = slow.encode(text)
    fast_ids = fast.encode(text)
    total += 1

    tolerated = False
    if slow_ids == fast_ids:
        perfect += 1
    elif check_details(text, slow_ids, fast_ids, slow, fast):
        imperfect += 1
        tolerated = True
    else:
        wrong += 1

    if total % 10000 == 0:
        print(f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})")

    if not tolerated:
        assert (
            slow_ids == fast_ids
        ), f"line {text} : \n\n{slow_ids}\n{fast_ids}\n\n{slow.tokenize(text)}\n{fast.tokenize(text)}"


def test_tokenizer(slow, fast):
    """Run ``test_string`` on every premise and hypothesis text in ``dataset``.

    For each row, every value of ``row["premise"]`` and every entry of
    ``row["hypothesis"]["translation"]`` is checked with the given slow/fast
    tokenizer pair.
    """
    # Iterate rows directly: ``dataset[i]`` was materialised twice per row before.
    for row in dataset:
        # premise, all languages
        for text in row["premise"].values():
            test_string(slow, fast, text)

        # hypothesis, all languages
        for text in row["hypothesis"]["translation"]:
            test_string(slow, fast, text)


if __name__ == "__main__":
    # Check every known checkpoint of every slow/fast tokenizer pair.
    # NOTE(review): relies on the slow class exposing ``max_model_input_sizes``;
    # confirm the installed transformers version still provides it.
    for name, (slow_class, fast_class) in TOKENIZER_CLASSES.items():
        checkpoint_names = list(slow_class.max_model_input_sizes.keys())
        for checkpoint in checkpoint_names:
            # Reset the module-level counters that test_string updates.
            imperfect = 0
            perfect = 0
            wrong = 0
            total = 0

            print(f"========================== Checking {name}: {checkpoint} ==========================")
            slow = slow_class.from_pretrained(checkpoint, force_download=True)
            fast = fast_class.from_pretrained(checkpoint, force_download=True)
            test_tokenizer(slow, fast)
            # Guard against an empty run, which would divide by zero.
            if total:
                print(f"Accuracy {perfect * 100 / total:.2f}")
            else:
                print("Accuracy n/a (no strings tested)")