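"""Compare every slow tokenizer in `transformers` that has a fast counterpart
(via SLOW_TO_FAST_CONVERTERS) on the XNLI test+validation sets.

For each line, the two encodings are compared: identical outputs count as
`perfect`; differing outputs that `check_details` judges equivalent (e.g. the
same substring split differently) count as `imperfect`; anything else counts as
`wrong` and trips the assert in `test_string`.
"""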
from collections import Counter

import datasets
import transformers
from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS
from transformers.utils import logging


logging.set_verbosity_info()

TOKENIZER_CLASSES = {
    name: (getattr(transformers, name), getattr(transformers, name + "Fast")) for name in SLOW_TO_FAST_CONVERTERS
}

dataset = datasets.load_dataset("xnli", split="test+validation")

total = 0
perfect = 0
imperfect = 0
wrong = 0


def check_diff(spm_diff, tok_diff, slow, fast):
    if spm_diff == list(reversed(tok_diff)):
        # AAA -> AA+A vs A+AA case.
        return True
    elif len(spm_diff) == len(tok_diff) and fast.decode(spm_diff) == fast.decode(tok_diff):
        # Second-order OK: same number of tokens and same decoded string.
        # Barrich -> Barr + ich vs Bar + rich
        return True
    spm_reencoded = slow.encode(slow.decode(spm_diff))
    tok_reencoded = fast.encode(fast.decode(spm_diff))
    if spm_reencoded != spm_diff and spm_reencoded == tok_reencoded:
        # Type 3 error.
        # Snehagatha ->
        #       Sne, h, aga, th, a
        #       Sne, ha, gat, ha
        # Re-encoding the diverging span with the slow tokenizer does not even recover
        # what it originally produced, but it does match the fast tokenizer's output.
        return True
    return False


def check_LTR_mark(line, idx, fast):
    # U+200F is the Unicode RIGHT-TO-LEFT MARK. If the first diverging token (or the one
    # just before it) covers that character, check_details treats the mismatch as acceptable.
    enc = fast.encode_plus(line)[0]
    offsets = enc.offsets
    curr, prev = offsets[idx], offsets[idx - 1]
    if curr is not None and line[curr[0] : curr[1]] == "\u200f":
        return True
    if prev is not None and line[prev[0] : prev[1]] == "\u200f":
        return True
    return False


def check_details(line, spm_ids, tok_ids, slow, fast):
    # Encodings can differ yet decode to the same string: AAA -> A + AA vs AA + A.
    # Strip the common prefix and suffix to isolate the diverging window.
    for i, (spm_id, tok_id) in enumerate(zip(spm_ids, tok_ids)):
        if spm_id != tok_id:
            break
    first = i
    for i, (spm_id, tok_id) in enumerate(zip(reversed(spm_ids), reversed(tok_ids))):
        if spm_id != tok_id:
            break
    last = len(spm_ids) - i

    spm_diff = spm_ids[first:last]
    tok_diff = tok_ids[first:last]

    if check_diff(spm_diff, tok_diff, slow, fast):
        return True

    if check_LTR_mark(line, first, fast):
        return True

    if last - first > 5:
        # The window may contain two independent problems; attempt to subdivide the
        # disjointed tokens into smaller problems.
        spms = Counter(spm_ids[first:last])
        toks = Counter(tok_ids[first:last])

        removable_tokens = {spm_ for (spm_, si) in spms.items() if toks.get(spm_, 0) == si}
        min_width = 3
        for i in range(last - first - min_width):
            if all(spm_ids[first + i + j] in removable_tokens for j in range(min_width)):
                possible_matches = [
                    k
                    for k in range(last - first - min_width)
                    if tok_ids[first + k : first + k + min_width] == spm_ids[first + i : first + i + min_width]
                ]
                for j in possible_matches:
                    if check_diff(
                        spm_ids[first : first + i], tok_ids[first : first + j], slow, fast
                    ) and check_details(
                        line,
                        spm_ids[first + i : last],
                        tok_ids[first + j : last],
                        slow,
                        fast,
                    ):
                        return True

    print(f"Spm: {[fast.decode([spm_ids[i]]) for i in range(first, last)]}")
    try:
        print(f"Tok: {[fast.decode([tok_ids[i]]) for i in range(first, last)]}")
    except Exception:
        pass

    # Decoded context around the diverging window (currently unused, kept for debugging).
    ok_start = fast.decode(spm_ids[:first])
    ok_end = fast.decode(spm_ids[last:])
    wrong = fast.decode(spm_ids[first:last])
    print()
    print(wrong)
    return False


def test_string(slow, fast, text):
    global perfect
    global imperfect
    global wrong
    global total

    slow_ids = slow.encode(text)
    fast_ids = fast.encode(text)

    skip_assert = False
    total += 1

    if slow_ids != fast_ids:
        if check_details(text, slow_ids, fast_ids, slow, fast):
            skip_assert = True
            imperfect += 1
        else:
            wrong += 1
    else:
        perfect += 1

    if total % 10000 == 0:
        print(f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})")

    if skip_assert:
        return

    assert (
        slow_ids == fast_ids
    ), f"line {text} : \n\n{slow_ids}\n{fast_ids}\n\n{slow.tokenize(text)}\n{fast.tokenize(text)}"


def test_tokenizer(slow, fast):
    for i in range(len(dataset)):
        # premise, all languages
        for text in dataset[i]["premise"].values():
            test_string(slow, fast, text)

        # hypothesis, all languages
        for text in dataset[i]["hypothesis"]["translation"]:
            test_string(slow, fast, text)


if __name__ == "__main__":
    for name, (slow_class, fast_class) in TOKENIZER_CLASSES.items():
        checkpoint_names = list(slow_class.max_model_input_sizes.keys())
        for checkpoint in checkpoint_names:
            imperfect = 0
            perfect = 0
            wrong = 0
            total = 0

            print(f"========================== Checking {name}: {checkpoint} ==========================")
            slow = slow_class.from_pretrained(checkpoint, force_download=True)
            fast = fast_class.from_pretrained(checkpoint, force_download=True)
            test_tokenizer(slow, fast)
            print(f"Accuracy {perfect * 100 / total:.2f}")