File size: 7,454 Bytes
d991264 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 |
# Note that there are more actual tests; they're just not currently public :-)
from typing import Callable
import hypothesis
import hypothesis.strategies as st
import pytest
import tiktoken
from .test_helpers import ENCODING_FACTORIES, MAX_EXAMPLES
def test_simple():
    """Smoke-test encode/decode on fixed strings and single-token round-trips."""
    # (encoding name, tokens for "hello world", tokens for "hello <|endoftext|>")
    known_cases = (
        ("gpt2", [31373, 995], [31373, 220, 50256]),
        ("cl100k_base", [15339, 1917], [15339, 220, 100257]),
    )
    for name, plain_tokens, special_tokens in known_cases:
        enc = tiktoken.get_encoding(name)
        assert enc.encode("hello world") == plain_tokens
        assert enc.decode(plain_tokens) == "hello world"
        assert enc.encode("hello <|endoftext|>", allowed_special="all") == special_tokens

    # For every encoding reported by list_encoding_names(), the first 10,000
    # token ids must map to bytes and back to the same id.
    for name in tiktoken.list_encoding_names():
        enc = tiktoken.get_encoding(name)
        for token_id in range(10_000):
            token_bytes = enc.decode_single_token_bytes(token_id)
            assert enc.encode_single_token(token_bytes) == token_id
def test_simple_repeated():
    """Check gpt2 token ids for runs of "0" of length 1 through 17."""
    enc = tiktoken.get_encoding("gpt2")
    # Entry i holds the expected encoding of "0" repeated (i + 1) times.
    expected_by_length = [
        [15],
        [405],
        [830],
        [2388],
        [20483],
        [10535],
        [24598],
        [8269],
        [10535, 830],
        [8269, 405],
        [8269, 830],
        [8269, 2388],
        [8269, 20483],
        [8269, 10535],
        [8269, 24598],
        [25645],
        [8269, 10535, 830],
    ]
    for length, expected in enumerate(expected_by_length, start=1):
        assert enc.encode("0" * length) == expected
def test_simple_regex():
    """Check cl100k_base token ids for apostrophe and newline/space inputs."""
    enc = tiktoken.get_encoding("cl100k_base")
    assert enc.encode("rer") == [38149]
    assert enc.encode("'rer") == [2351, 81]
    assert enc.encode("today\n ") == [31213, 198, 220]
    # One space between the two newlines.
    assert enc.encode("today\n \n") == [31213, 27907]
    # Two spaces between the two newlines. This input must differ from the
    # previous one, otherwise the two assertions contradict each other.
    assert enc.encode("today\n  \n") == [31213, 14211]
def test_basic_encode():
    """Check "hello world" token ids across three encodings, plus a \\x85 case."""
    hello_world_tokens = {
        "r50k_base": [31373, 995],
        "p50k_base": [31373, 995],
        "cl100k_base": [15339, 1917],
    }
    for name, expected in hello_world_tokens.items():
        assert tiktoken.get_encoding(name).encode("hello world") == expected

    # A space, the character \x85, then "0", encoded with cl100k_base.
    cl100k = tiktoken.get_encoding("cl100k_base")
    assert cl100k.encode(" \x850") == [220, 126, 227, 15]
def test_encode_empty():
    """Encoding the empty string with r50k_base yields an empty token list."""
    tokens = tiktoken.get_encoding("r50k_base").encode("")
    assert tokens == []
def test_encode_bytes():
    """Check `_encode_bytes` on a fixed byte string that ends with byte 0xED."""
    raw = b" \xec\x8b\xa4\xed"
    enc = tiktoken.get_encoding("cl100k_base")
    assert enc._encode_bytes(raw) == [62085]
def test_encode_surrogate_pairs():
    """Check how cl100k_base encodes a surrogate pair and a lone surrogate."""
    enc = tiktoken.get_encoding("cl100k_base")
    thumbs_up_tokens = [9468, 239, 235]

    assert enc.encode("👍") == thumbs_up_tokens
    # surrogate pair gets converted to codepoint
    assert enc.encode("\ud83d\udc4d") == thumbs_up_tokens

    # lone surrogate just gets replaced
    lone_tokens = enc.encode("\ud83d")
    replacement_tokens = enc.encode("�")
    assert lone_tokens == replacement_tokens
# ====================
# Roundtrip
# ====================
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
def test_basic_roundtrip(make_enc):
    """Check that decode inverts encode and encode_ordinary on sample strings.

    Args:
        make_enc: Zero-argument callable returning the encoding under test.
    """
    enc = make_enc()
    # Each sample is distinct: the whitespace variants cover one and two
    # trailing spaces, with and without a leading space.
    for value in (
        "hello",
        "hello ",
        "hello  ",  # two trailing spaces
        " hello",
        " hello ",
        " hello  ",  # one leading space, two trailing spaces
        "hello world",
        "请考试我的软件!12345",
    ):
        assert value == enc.decode(enc.encode(value))
        assert value == enc.decode(enc.encode_ordinary(value))
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
@hypothesis.given(text=st.text())
@hypothesis.settings(deadline=None)
def test_hyp_roundtrip(make_enc: Callable[[], tiktoken.Encoding], text):
    """Property test: decoding the encoding of generated text returns that text."""
    enc = make_enc()
    tokens = enc.encode(text)
    restored = enc.decode(tokens)
    assert text == restored
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
def test_single_token_roundtrip(make_enc: Callable[[], tiktoken.Encoding]):
    """Round-trip every token id below n_vocab through its byte representation."""
    enc = make_enc()
    for token_id in range(enc.n_vocab):
        try:
            as_bytes = enc.decode_single_token_bytes(token_id)
        except KeyError:
            # Ids for which decode_single_token_bytes raises KeyError are skipped.
            pass
        else:
            assert enc.encode_single_token(as_bytes) == token_id
# ====================
# Special tokens
# ====================
def test_special_token():
    """Exercise allowed_special / disallowed_special handling on cl100k_base."""
    enc = tiktoken.get_encoding("cl100k_base")

    eot = enc.encode_single_token("<|endoftext|>")
    assert eot == enc.eot_token
    fip = enc.encode_single_token("<|fim_prefix|>")
    fim = enc.encode_single_token("<|fim_middle|>")

    text = "<|endoftext|> hello <|fim_prefix|>"
    assert eot not in enc.encode(text, disallowed_special=())

    # Each of these keyword sets must make encode() raise ValueError for text.
    raising_kwargs = (
        {},
        {"disallowed_special": "all"},
        {"disallowed_special": {"<|endoftext|>"}},
        {"disallowed_special": {"<|fim_prefix|>"}},
    )
    for kwargs in raising_kwargs:
        with pytest.raises(ValueError):
            enc.encode(text, **kwargs)

    text = "<|endoftext|> hello <|fim_prefix|> there <|fim_middle|>"
    special_ids = {
        "<|endoftext|>": eot,
        "<|fim_prefix|>": fip,
        "<|fim_middle|>": fim,
    }
    # (encode keyword arguments, special tokens expected in the output)
    membership_cases = (
        ({"disallowed_special": ()}, set()),
        ({"allowed_special": "all", "disallowed_special": ()}, set(special_ids)),
        ({"allowed_special": "all", "disallowed_special": "all"}, set(special_ids)),
        (
            {"allowed_special": {"<|fim_prefix|>"}, "disallowed_special": ()},
            {"<|fim_prefix|>"},
        ),
        (
            {"allowed_special": {"<|endoftext|>"}, "disallowed_special": ()},
            {"<|endoftext|>"},
        ),
        (
            {"allowed_special": {"<|fim_middle|>"}, "disallowed_special": ()},
            {"<|fim_middle|>"},
        ),
    )
    for kwargs, expected_present in membership_cases:
        tokens = enc.encode(text, **kwargs)
        for special, token_id in special_ids.items():
            assert (token_id in tokens) == (special in expected_present)
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
@hypothesis.given(text=st.text())
@hypothesis.settings(deadline=None, max_examples=MAX_EXAMPLES)
def test_hyp_special_ordinary(make_enc, text: str):
    """Property test: encode_ordinary equals encode with disallowed_special=()."""
    enc = make_enc()
    ordinary = enc.encode_ordinary(text)
    unchecked = enc.encode(text, disallowed_special=())
    assert ordinary == unchecked
# ====================
# Batch encoding
# ====================
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
def test_batch_encode(make_enc: Callable[[], tiktoken.Encoding]):
    """Check that the batch encoders agree with their per-string counterparts."""
    enc = make_enc()
    first, second = "hello world", "goodbye world"
    batches = ([first], [first, second])

    for batch in batches:
        assert enc.encode_batch(batch) == [enc.encode(item) for item in batch]

    for batch in batches:
        assert enc.encode_ordinary_batch(batch) == [
            enc.encode_ordinary(item) for item in batch
        ]
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
@hypothesis.given(batch=st.lists(st.text()))
@hypothesis.settings(deadline=None)
def test_hyp_batch_roundtrip(make_enc: Callable[[], tiktoken.Encoding], batch):
    """Property test: encode_batch matches per-item encode, and decode_batch inverts it."""
    enc = make_enc()

    encoded = enc.encode_batch(batch)
    per_item = list(map(enc.encode, batch))
    assert encoded == per_item

    assert enc.decode_batch(encoded) == batch
|