austindavis committed
Commit 28f21ff · verified · 1 Parent(s): bb83eaf

Update agents/uci_tokenizers.py

Files changed (1):
  1. agents/uci_tokenizers.py +133 -89
agents/uci_tokenizers.py CHANGED
@@ -1,40 +1,40 @@
 from typing import List
 
 import chess
-import tiktoken
+
+# import tiktoken
 import tokenizers
 from tokenizers import models, pre_tokenizers, processors
 from torch import Tensor as TT
 from transformers import PreTrainedTokenizerFast
 from transformers.tokenization_utils_fast import BatchEncoding
 
-
-def getTiktokenizer() -> tiktoken.Encoding:
-    """
-    Defines a tiktoken-based BPE encoder for UCI chess moves. This
-    tokenizer effectively tokenizes UCI moves by the square names.
-    One notable variation is that promotions must be in upper-case.
-
-    Vocabulary:
-        Special Tokens (4): "\<|pad|\>", "\<|startoftext|\>", "\<|endoftext|\>", "\<|unknown|\>"
-        Square Tokens (64): a1 through h8
-        Promote Tokens (4): Q, B, R, N
-        UNUSED (8120): Need 8192-4-64-4=8120 unused tokens of the form <|unused####|>
-    """
-    special_tokens = ["<|pad|>", "<|startoftext|>", "<|endoftext|>", "<|unknown|>"]
-    unused_tokens = [f"<|unused{i:04d}|>" for i in range(8120)]
-    chess_vocab = special_tokens + chess.SQUARE_NAMES + list("QBRN") + unused_tokens
-    mergeable_ranks = {k.encode():v for (v,k) in enumerate(chess_vocab)}
-    chess_pat_str = r'[a-h][1-8]|[QBRN]'
-
-    enc = tiktoken.Encoding(
-        name="chess_enc",
-        pat_str=chess_pat_str, # or \d|\s
-        mergeable_ranks=mergeable_ranks,
-        special_tokens={k:v for (v,k) in enumerate(special_tokens)},
-    )
-
-    return enc
+# def getTiktokenizer() -> tiktoken.Encoding:
+#     """
+#     Defines a tiktoken-based BPE encoder for UCI chess moves. This
+#     tokenizer effectively tokenizes UCI moves by the square names.
+#     One notable variation is that promotions must be in upper-case.
+
+#     Vocabulary:
+#         Special Tokens (4): "\<|pad|\>", "\<|startoftext|\>", "\<|endoftext|\>", "\<|unknown|\>"
+#         Square Tokens (64): a1 through h8
+#         Promote Tokens (4): Q, B, R, N
+#         UNUSED (8120): Need 8192-4-64-4=8120 unused tokens of the form <|unused####|>
+#     """
+#     special_tokens = ["<|pad|>", "<|startoftext|>", "<|endoftext|>", "<|unknown|>"]
+#     unused_tokens = [f"<|unused{i:04d}|>" for i in range(8120)]
+#     chess_vocab = special_tokens + chess.SQUARE_NAMES + list("QBRN") + unused_tokens
+#     mergeable_ranks = {k.encode():v for (v,k) in enumerate(chess_vocab)}
+#     chess_pat_str = r'[a-h][1-8]|[QBRN]'
+
+#     enc = tiktoken.Encoding(
+#         name="chess_enc",
+#         pat_str=chess_pat_str, # or \d|\s
+#         mergeable_ranks=mergeable_ranks,
+#         special_tokens={k:v for (v,k) in enumerate(special_tokens)},
+#     )
+
+#     return enc
 
 
 class UciTokenizer(PreTrainedTokenizerFast):
@@ -42,7 +42,6 @@ class UciTokenizer(PreTrainedTokenizerFast):
     _UNK_TOKEN: str
     _EOS_TOKEN: str
    _BOS_TOKEN: str
-
 
     stoi: dict[str, int]
    """String to Integer mapping"""
@@ -59,11 +58,11 @@ class UciTokenizer(PreTrainedTokenizerFast):
         bos_token,
         eos_token,
         name_or_path,
-        **kwargs
+        **kwargs,
     ):
         self.stoi = stoi
         self.itos = itos
-
+
         self._PAD_TOKEN = pad_token
         self._UNK_TOKEN = unk_token
         self._EOS_TOKEN = eos_token
@@ -81,8 +80,8 @@ class UciTokenizer(PreTrainedTokenizerFast):
             pair=None,
             special_tokens=[(bos_token, 1)],
         )
-        slow_tokenizer.post_processor=post_proc
-
+        slow_tokenizer.post_processor = post_proc
+
         super().__init__(
             tokenizer_object=slow_tokenizer,
             unk_token=self._UNK_TOKEN,
@@ -90,7 +89,7 @@ class UciTokenizer(PreTrainedTokenizerFast):
             eos_token=self._EOS_TOKEN,
             pad_token=self._PAD_TOKEN,
             name_or_path=name_or_path,
-            **kwargs
+            **kwargs,
         )
 
         # Override the decode behavior to ensure spaces are correctly handled
@@ -108,47 +107,48 @@ class UciTokenizer(PreTrainedTokenizerFast):
 
             if isinstance(token_ids, TT):
                 token_ids = token_ids.tolist()
-
+
             if isinstance(token_ids, list):
                 tokens_str = [self.itos.get(xi, self._UNK_TOKEN) for xi in token_ids]
                 processed_tokens = self._process_str_tokens(tokens_str)
 
                 return " ".join(processed_tokens)
-
-            raise ValueError(f"Unknown input type to decode() for argument 'token_ids'. Received: {type(token_ids)} ")
+
+            raise ValueError(
+                f"Unknown input type to decode() for argument 'token_ids'. Received: {type(token_ids)} "
+            )
 
         self._decode = _decode
 
     def _init_pretokenizer(self) -> pre_tokenizers.PreTokenizer:
         raise NotImplementedError
 
-    def _process_str_tokens(self, tokens_str: list[str], return_player_ids: bool) -> list[str]:
+    def _process_str_tokens(
+        self, tokens_str: list[str], return_player_ids: bool
+    ) -> list[str]:
         raise NotImplementedError
-
+
     def get_id2square_list() -> list[int]:
         raise NotImplementedError
 
 
 class UciTileTokenizer(UciTokenizer):
-    """ Uci tokenizer converting start/end tiles and promotion types each into individual tokens"""
+    """UCI tokenizer converting start/end tiles and promotion types each into individual tokens"""
 
-    SPECIAL_TOKENS = ["<|pad|>", "<|startoftext|>", "<|endoftext|>", "<|unknown|>"]
+    SPECIAL_TOKENS = (_PAD_TOKEN, _BOS_TOKEN, _EOS_TOKEN, _UNK_TOKEN) = [
+        "<|pad|>",
+        "<|startoftext|>",
+        "<|endoftext|>",
+        "<|unknown|>",
+    ]
 
-    stoi = {
-        tok: idx
-        for tok, idx in list(
-            zip(SPECIAL_TOKENS + chess.SQUARE_NAMES + list("QRBN"), range(72))
-        )
-    }
+    stoi: dict[str, int]
+    itos: dict[int, str]
 
-    itos = {
-        idx: tok
-        for tok, idx in list(
-            zip(SPECIAL_TOKENS + chess.SQUARE_NAMES + list("QRBN"), range(72))
-        )
-    }
+    _split_regex: str
+    _promote_chars: str
 
-    id2square:List[int] = list(range(4,68))
+    id2square: List[int] = list(range(4, 68))
     """
     List mapping token IDs to squares on the chess board. Order is file then rank, i.e.:
     `A1, B1, C1, ..., F8, G8, H8`
@@ -157,29 +157,63 @@ class UciTileTokenizer(UciTokenizer):
     def get_id2square_list(self) -> List[int]:
         return self.id2square
 
-    def __init__(self, **kwargs):
-        # Remove conflicting arguments from kwargs if they exist
+    def __init__(self, *, upper_promotions: bool = True, **kwargs):
+        # Remove conflicting arguments from kwargs if they exist
         kwargs.pop("pad_token", None)
         kwargs.pop("unk_token", None)
         kwargs.pop("bos_token", None)
         kwargs.pop("eos_token", None)
         kwargs.pop("clean_up_tokenization_spaces", None)
         kwargs.pop("name_or_path", None)
+
+        self.upper_promotions = upper_promotions
+
+        if upper_promotions:
+            self._promote_chars = "QRBN"
+            self._split_regex = r"[a-h][1-8]|[QRBN]"
+        else:
+            self._promote_chars = "qrbn"
+            self._split_regex = r"[a-h][1-8]|[qrnb]"
+
+        self.stoi = {
+            tok: idx
+            for tok, idx in list(
+                zip(
+                    self.SPECIAL_TOKENS
+                    + chess.SQUARE_NAMES
+                    + list(self._promote_chars),
+                    range(72),
+                )
+            )
+        }
+
+        self.itos = {
+            idx: tok
+            for tok, idx in list(
+                zip(
+                    self.SPECIAL_TOKENS
+                    + chess.SQUARE_NAMES
+                    + list(self._promote_chars),
+                    range(72),
+                )
+            )
+        }
+
         super().__init__(
             self.stoi,
             self.itos,
-            pad_token="<|pad|>",
-            unk_token="<|unknown|>",
-            bos_token="<|startoftext|>",
-            eos_token="<|endoftext|>",
+            pad_token=self._PAD_TOKEN,
+            unk_token=self._UNK_TOKEN,
+            bos_token=self._BOS_TOKEN,
+            eos_token=self._EOS_TOKEN,
             name_or_path="austindavis/uci_tile_tokenizer",
             clean_up_tokenization_spaces=False,
-            **kwargs
+            **kwargs,
         )
 
     def _init_pretokenizer(self):
         # Pre-tokenizer to split input into UCI moves
-        pattern = tokenizers.Regex(r"\d|[QBRN]")
+        pattern = tokenizers.Regex(self._split_regex)
         pre_tokenizer = pre_tokenizers.Sequence(
             [
                 pre_tokenizers.Whitespace(),
@@ -214,16 +248,16 @@ class UciTileTokenizer(UciTokenizer):
         return moves
 
     @staticmethod
-    def compute_players(encoding: BatchEncoding, according_to='output'):
+    def compute_players(encoding: BatchEncoding, according_to="output"):
         """
-        Determines which player (white=True, black=False) is associated with each token in the sequence.
+        Determines which player (white=True, black=False) is associated with each token in the sequence.
         This method works based on chess move sequences tokenized using the UciTileTokenizer.
 
         # Parameters:
         ----------
         **`encoding`** : BatchEncoding
             Tokenized input of a chess game, where each token represents a move or special token.
-
+
         **`according_to`** : str (optional, default='output')
             Specifies the perspective for associating players:
             - 'output': Returns the player whose next move is predicted by the sequence (the output move).
@@ -233,12 +267,12 @@ class UciTileTokenizer(UciTokenizer):
         -------
         List[bool]
             A list of boolean values indicating the player for each token:
-            - True for white (player 1),
+            - True for white (player 1),
             - False for black (player 2).
-
+
             The list length corresponds to the number of tokens in the sequence, including special tokens if any.
 
-        # Example Usage:
+        # Example Usage:
         ```
         >>> tok = UciTileTokenizer()
        >>> encoding = tok('e2e4 d7d5 e4d5 e7e6 d5e6 d8g5 e6e7 g5f6 e7f8Q')
@@ -246,7 +280,7 @@ class UciTileTokenizer(UciTokenizer):
         [1, 16, 32, 55, 39, 32, 39, 56, 48, 39, 48, 63, 42, 48, 56, 42, 49, 56, 65, 68]
         >>> tok.compute_players(encoding)
         [True, True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, True, True, False]
-        >>> tok.compute_players(encoding, according_to='input')
+        >>> tok.compute_players(encoding, according_to='input')
         [True, True, True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, True, True]
         ```
 
@@ -256,29 +290,30 @@ class UciTileTokenizer(UciTokenizer):
         using `according_to='output'`, it cannot reliably predict which player is
         responsible for selecting the final token of the sequence. For instance,
         if a pawn is moved to the back rank (e.g., 'e7e8'), then white must select
-        the promotion class on the next token; however, this algorithm will predict
-        that black is responsible for selecting the next token instead of white.
+        the promotion class on the next token; however, this algorithm will predict
+        that black is responsible for selecting the next token instead of white.
         """
-
-        return [UciTileTokenizer._compute_players_single(encoding[i].ids) for i in range(len(encoding['input_ids']))]
-
-
-
+
+        return [
+            UciTileTokenizer._compute_players_single(encoding[i].ids, according_to)
+            for i in range(len(encoding["input_ids"]))
+        ]
+
     @staticmethod
-    def _compute_players_single(input_ids: list[int], according_to: str='output'):
+    def _compute_players_single(input_ids: list[int], according_to: str = "output"):
         players = [] if according_to == "output" else [True]
         current_player = False
         num_tokens_in_ply = 0
         has_specials = False
-
+
         for i, token_id in enumerate(input_ids):
             if token_id == 1:
                 has_specials = True
                 continue
-
+
             if num_tokens_in_ply == 0:
                 # check if promotion OR unknown token ID
-                if token_id > 67 or token_id == 3:
+                if token_id > 67 or token_id == 3:
                     players.append(current_player)
                     num_tokens_in_ply = 0
                 else:
@@ -304,17 +339,26 @@ class UciTileTokenizer(UciTokenizer):
 
         return players if has_specials else players[1:]
 
+
 if __name__ == "__main__":
     tok = UciTileTokenizer()
-    encoding = tok('e2e4Q b7b8N e2e7 a1',add_special_tokens=True)
-    print(f"{encoding['input_ids']=}\n{tok.compute_players(encoding, according_to='output')=}")
-    print(f"{encoding['input_ids']=}\n{tok.compute_players(encoding, according_to='input')=}")
+    encoding = tok("e2e4Q b7b8N e2e7 a1", add_special_tokens=True)
+    print(
+        f"{encoding['input_ids']=}\n{tok.compute_players(encoding, according_to='output')=}"
+    )
+    print(
+        f"{encoding['input_ids']=}\n{tok.compute_players(encoding, according_to='input')=}"
+    )
 
-    encoding = tok('e2e4Q b7b8N e2e7 a1',add_special_tokens=False)
-    print(f"{encoding['input_ids']=}\n{tok.compute_players(encoding, according_to='output')=}")
-    print(f"{encoding['input_ids']=}\n{tok.compute_players(encoding, according_to='input')=}")
+    encoding = tok("e2e4Q b7b8N e2e7 a1", add_special_tokens=False)
+    print(
+        f"{encoding['input_ids']=}\n{tok.compute_players(encoding, according_to='output')=}"
+    )
+    print(
+        f"{encoding['input_ids']=}\n{tok.compute_players(encoding, according_to='input')=}"
+    )
 
-    encoding = tok('e2e4 d7d5 e4d5 e7e6 d5e6 d8g5 e6e7 g5f6 e7f8Q')
-    print(encoding['input_ids'])
+    encoding = tok("e2e4 d7d5 e4d5 e7e6 d5e6 d8g5 e6e7 g5f6 e7f8Q")
+    print(encoding["input_ids"])
     print(tok.compute_players(encoding))
-    print(tok.compute_players(encoding, according_to='input'))
+    print(tok.compute_players(encoding, according_to="input"))
 
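The new `__init__` builds the same 72-token vocabulary the class previously stored in static dict literals: IDs 0-3 are the special tokens, 4-67 the squares (hence `id2square = list(range(4, 68))`), and 68-71 the promotion pieces (hence the `token_id > 67` promotion check in `_compute_players_single`). A minimal sketch, not part of the commit, that reconstructs the layout and checks those boundaries (`PROMOTE_CHARS` assumes the `upper_promotions=True` branch):

```
# Standalone sketch: rebuilds the UciTileTokenizer vocabulary to make the
# token-ID ranges used throughout the file explicit.
import chess  # python-chess; SQUARE_NAMES is a1, b1, ..., h8 (file then rank)

SPECIAL_TOKENS = ["<|pad|>", "<|startoftext|>", "<|endoftext|>", "<|unknown|>"]
PROMOTE_CHARS = "QRBN"  # assumption: the upper_promotions=True branch

vocab = SPECIAL_TOKENS + chess.SQUARE_NAMES + list(PROMOTE_CHARS)
stoi = {tok: idx for idx, tok in enumerate(vocab)}

assert len(vocab) == 72
assert stoi["<|startoftext|>"] == 1          # the BOS check in _compute_players_single
assert stoi["<|unknown|>"] == 3              # the token_id == 3 unknown check
assert stoi["a1"] == 4 and stoi["h8"] == 67  # id2square = list(range(4, 68))
assert stoi["Q"] == 68                       # ids above 67 are promotion pieces
```

With `upper_promotions=False` the last four entries become `q`, `r`, `b`, `n`; the rest of the layout is unchanged.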
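A usage sketch of the updated class, also not part of the commit: the import path `agents.uci_tokenizers` is an assumption, and the expected IDs follow from the vocabulary layout above and agree with the docstring example.

```
# Hypothetical usage of the updated tokenizer; module path is assumed.
from agents.uci_tokenizers import UciTileTokenizer

tok = UciTileTokenizer(upper_promotions=True)

# Each UCI move splits into start tile, end tile, and an optional promotion
# piece, so "e2e4 e7e8Q" should yield BOS plus five move tokens.
enc = tok("e2e4 e7e8Q", add_special_tokens=True)
print(enc["input_ids"])  # expected: [1, 16, 32, 56, 64, 68]

# Decoding joins tiles back into moves via _process_str_tokens.
print(tok.decode(enc["input_ids"]))

# Per-token player attribution, as documented in compute_players.
print(tok.compute_players(enc, according_to="output"))
```

Passing `upper_promotions=False` instead switches both the vocabulary and the pre-tokenizer split regex to lowercase promotion pieces, which matches standard UCI notation (e.g. `e7e8q`).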