import argparse
import json
import math
import os
import re
import warnings

import torch
import torch.nn as nn
from datasets import load_dataset
from torch.nn import functional as F
from torch.optim.lr_scheduler import _LRScheduler
from tqdm import tqdm
import wandb

from gpt_p.model import DecoderTransformer
torch.manual_seed(420)
base_name = 'gpt-p_CHARS_CHAT_'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
context_size = 256
batch_size = 128
max_iters = 50_000
learning_rate = 3e-5
eval_interval = 100
eval_iters = 20
n_embed = 384
n_layer = 6
n_head = 6
dropout = 0.2
mask_all_data = True
use_scheduler = False


def _clean_movetext(movetext):
    """Strip ``{...}`` comments and black move numbers (``12...``) from PGN movetext.

    Runs of spaces left behind by the removals are collapsed to one space.
    """
    movetext = re.sub(r'\{[^}]*\}', '', movetext)
    movetext = re.sub(r'[0-9]+\.\.\.', '', movetext)
    return re.sub(r' {2,}', ' ', movetext)


# 2014 games: keep only games whose movetext has no 'eval' annotations.
dataset = load_dataset('Lichess/standard-chess-games', '2014-08', split='train')
games_2014 = [game for game in dataset['movetext'] if 'eval' not in game]
og_samples = list(games_2014)

# 2024 games: annotated movetext, cleaned before being appended.
new_dataset = load_dataset(
    'Lichess/standard-chess-games', '2024-07', split='train',
    data_files=[f'data/year=2024/month=07/train-{str(i).zfill(5)}-of-00384.parquet' for i in range(10)],
)
new_dataset = [_clean_movetext(game) for game in new_dataset['movetext']]
og_samples += new_dataset

if mask_all_data:
    # NOTE(review): in this mode only the 2014 games form the training stream;
    # the cleaned 2024 games are used only when mask_all_data is False — confirm intended.
    content = '\n'.join(games_2014)
else:
    # One string per game.
    content = og_samples
print('Data loaded')
n_characters = len(content) if mask_all_data else sum(len(game) for game in content)
print('Training on ', n_characters, 'characters. Good luck!')
book = content
if mask_all_data:
    characters = sorted(set(book))
else:
    characters = sorted(set('\n'.join(book)))
vocab_size = len(characters)
class Tokenizer:
    """Character-level tokenizer whose token id is a symbol's position in ``vocab``."""

    def __init__(self, vocab):
        self.vocab = vocab
        # symbol -> id; on duplicate symbols the later position wins.
        self.stoi = {symbol: position for position, symbol in enumerate(vocab)}
        # id -> symbol
        self.itos = dict(enumerate(vocab))

    def encode(self, s):
        """Return the list of ids for the characters of ``s`` (KeyError on unknown characters)."""
        return list(map(self.stoi.__getitem__, s))

    def decode(self, i):
        """Return the string spelled by the sequence of ids ``i``."""
        return ''.join(map(self.itos.__getitem__, i))

    @classmethod
    def from_pretrained(cls, path):
        """Create a tokenizer from a JSON vocabulary list, as written by ``save_pretrained``."""
        with open(path, 'r') as handle:
            loaded_vocab = json.load(handle)
        return cls(loaded_vocab)

    def save_pretrained(self, path):
        """Write the vocabulary to ``path`` as JSON."""
        with open(path, 'w') as handle:
            json.dump(self.vocab, handle)
tokenizer = Tokenizer(characters)
encode = tokenizer.encode
decode = tokenizer.decode
if mask_all_data:
    # One long id stream; get_batch slices random contiguous windows from it.
    data = torch.tensor(encode(book), dtype=torch.long)
else:
    # One id tensor per game; windows are sampled per game.
    data = [torch.tensor(encode(s), dtype=torch.long) for s in book]
# Never use a context longer than the longest game.
max_len = max(len(x) for x in og_samples)
context_size = min(context_size, max_len)
# 80/20 train/validation split (characters in stream mode, games otherwise).
n = int(0.8 * len(data))
train_data = data[:n]
val_data = data[n:]
# Material value per piece letter (upper-case white, lower-case black).
# Not referenced elsewhere in this file.
piece_values = {
    'P': 1, 'N': 3, 'B': 3, 'R': 5, 'Q': 9, 'K': 0,
    'p': 1, 'n': 3, 'b': 3, 'r': 5, 'q': 9, 'k': 0,
}
def initialize_board():
    """Return a freshly allocated 8x8 board in the standard starting position.

    Row 0 is rank 8 (lower-case pieces), row 7 is rank 1 (upper-case pieces),
    and '.' marks an empty square. Every row is a distinct list.
    """
    back_rank = 'rnbqkbnr'
    middle = [['.'] * 8 for _ in range(4)]
    return [list(back_rank), ['p'] * 8, *middle, ['P'] * 8, list(back_rank.upper())]
def get_piece(board, position):
    """Return the content of square ``position`` (e.g. 'e4'): a piece letter or '.'."""
    file_index = ord(position[0]) - ord('a')
    rank_index = 8 - int(position[1])
    row = board[rank_index]
    return row[file_index]
def set_piece(board, position, piece):
    """Store ``piece`` on square ``position`` (e.g. 'e4'), mutating ``board`` in place."""
    file_index = ord(position[0]) - ord('a')
    rank_index = 8 - int(position[1])
    row = board[rank_index]
    row[file_index] = piece
def validate_pawn_move(board, start, end, is_white_turn):
    """Validate a pawn move: single/double advance, diagonal capture, en passant.

    Args:
        board: 8x8 list of rows; row 0 is rank 8, '.' is an empty square,
            upper-case letters are white pieces and lower-case are black.
        start, end: algebraic squares such as 'e2' and 'e4'.
        is_white_turn: True if the pawn belongs to white (moves toward rank 8).

    Returns:
        True if the move is legal for a pawn on ``board``, else False.

    En passant is accepted whenever a pawn on its fifth rank moves diagonally
    onto an empty square next to an enemy pawn; the previous move is not known
    here, so whether that pawn has just made a double step is not checked.
    The promotion piece is not examined here.
    """
    start_col, start_row = ord(start[0]) - ord('a'), 8 - int(start[1])
    end_col, end_row = ord(end[0]) - ord('a'), 8 - int(end[1])
    pawn_direction = -1 if is_white_turn else 1
    home_row = 6 if is_white_turn else 1
    en_passant_row = 3 if is_white_turn else 4
    enemy_pawn = 'p' if is_white_turn else 'P'
    target = board[end_row][end_col]

    if start_col == end_col and target == '.':
        if start_row + pawn_direction == end_row:
            return True
        # A double step also needs the square being passed over to be empty.
        if (start_row == home_row
                and start_row + 2 * pawn_direction == end_row
                and board[start_row + pawn_direction][start_col] == '.'):
            return True

    if abs(start_col - end_col) == 1 and start_row + pawn_direction == end_row:
        if (is_white_turn and target.islower()) or (not is_white_turn and target.isupper()):
            return True
        # En passant: the captured pawn stands beside the mover, not on the target square.
        if target == '.' and start_row == en_passant_row and board[start_row][end_col] == enemy_pawn:
            return True
    return False
def validate_knight_move(start, end):
    """Return True if ``start`` -> ``end`` is a knight's L-shaped jump.

    Only geometry is checked; the board contents are not consulted.
    """
    from_file, from_rank = ord(start[0]) - ord('a'), 8 - int(start[1])
    to_file, to_rank = ord(end[0]) - ord('a'), 8 - int(end[1])
    deltas = sorted((abs(from_file - to_file), abs(from_rank - to_rank)))
    return deltas == [1, 2]
def validate_rook_move(board, start, end):
    """Validate a rook move along one rank or file with no pieces in between.

    Args:
        board: 8x8 list of rows (row 0 is rank 8, '.' is empty).
        start, end: algebraic squares such as 'a1' and 'a8'.

    Returns:
        True if the squares share a rank or file, differ, and every square
        strictly between them is empty. The destination square itself is not
        inspected; rejecting a capture of an own piece is left to the caller.
    """
    start_col, start_row = ord(start[0]) - ord('a'), 8 - int(start[1])
    end_col, end_row = ord(end[0]) - ord('a'), 8 - int(end[1])
    if (start_col, start_row) == (end_col, end_row):
        return False  # not moving is not a move
    if start_col != end_col and start_row != end_row:
        return False
    # Unit step (-1, 0 or +1) along each axis.
    col_step = (end_col > start_col) - (end_col < start_col)
    row_step = (end_row > start_row) - (end_row < start_row)
    distance = max(abs(end_col - start_col), abs(end_row - start_row))
    return all(
        board[start_row + k * row_step][start_col + k * col_step] == '.'
        for k in range(1, distance)
    )
def validate_bishop_move(board, start, end):
    """Validate a bishop move along a diagonal with no pieces in between.

    Args:
        board: 8x8 list of rows (row 0 is rank 8, '.' is empty).
        start, end: algebraic squares such as 'c1' and 'h6'.

    Returns:
        True if the squares lie on a common diagonal, differ, and every square
        strictly between them is empty. The destination square itself is not
        inspected. A zero-length move returns False; without this check the
        scan would run off the board.
    """
    start_col, start_row = ord(start[0]) - ord('a'), 8 - int(start[1])
    end_col, end_row = ord(end[0]) - ord('a'), 8 - int(end[1])
    distance = abs(end_col - start_col)
    if distance == 0 or distance != abs(end_row - start_row):
        return False
    col_step = 1 if end_col > start_col else -1
    row_step = 1 if end_row > start_row else -1
    return all(
        board[start_row + k * row_step][start_col + k * col_step] == '.'
        for k in range(1, distance)
    )
# Standard Algebraic Notation move: optional piece letter, optional origin
# file/rank for disambiguation, optional capture marker, destination square,
# optional promotion, then optional check/mate and annotation suffixes.
_SAN_MOVE_RE = re.compile(
    r'(?P<piece>[PNBRQK])?(?P<from_file>[a-h])?(?P<from_rank>[1-8])?x?'
    r'(?P<dest>[a-h][1-8])(?:=(?P<promotion>[QRBN]))?[+#]?[!?]*'
)
_CASTLE_RE = re.compile(r'(O-O(?:-O)?)[+#]?[!?]*')


def _castling_plan(move, is_white_turn):
    """Return (king_from, king_to, rook_from, rook_to, must_be_empty) for a castling token, else None."""
    match = _CASTLE_RE.fullmatch(move)
    if match is None:
        return None
    rank = '1' if is_white_turn else '8'
    if match.group(1) == 'O-O-O':
        return 'e' + rank, 'c' + rank, 'a' + rank, 'd' + rank, [f + rank for f in 'bcd']
    return 'e' + rank, 'g' + rank, 'h' + rank, 'f' + rank, [f + rank for f in 'fg']


def _piece_can_reach(board, piece_type, start, end, is_white_turn):
    """Return True if a piece of upper-case ``piece_type`` on ``start`` can move to ``end``."""
    if piece_type == 'P':
        return validate_pawn_move(board, start, end, is_white_turn)
    if piece_type == 'N':
        return validate_knight_move(start, end)
    if piece_type == 'R':
        return validate_rook_move(board, start, end)
    if piece_type == 'B':
        return validate_bishop_move(board, start, end)
    if piece_type == 'Q':
        return validate_rook_move(board, start, end) or validate_bishop_move(board, start, end)
    # King: exactly one step in any direction.
    d_col = abs(ord(start[0]) - ord(end[0]))
    d_row = abs(int(start[1]) - int(end[1]))
    return max(d_col, d_row) == 1


def _resolve_move(board, move, is_white_turn):
    """Locate the origin square of a non-castling SAN move.

    Returns (piece, start, end, promotion), where ``piece`` is the mover's
    board letter and ``promotion`` is None or an upper-case piece letter.
    Returns None if the token is not SAN, if the destination holds one of the
    mover's own pieces, if promotion is given/missing inconsistently, or if
    no single piece of the named type can reach the destination.
    """
    match = _SAN_MOVE_RE.fullmatch(move)
    if match is None:
        return None
    piece_type = match.group('piece') or 'P'
    piece = piece_type if is_white_turn else piece_type.lower()
    end = match.group('dest')
    promotion = match.group('promotion')

    target = get_piece(board, end)
    if target != '.' and target.isupper() == is_white_turn:
        return None  # cannot capture one's own piece

    last_rank = '8' if is_white_turn else '1'
    if (promotion is not None) != (piece_type == 'P' and end[1] == last_rank):
        return None  # promotion is required exactly when a pawn reaches the last rank

    from_file = match.group('from_file')
    from_rank = match.group('from_rank')
    candidates = []
    for row, cells in enumerate(board):
        for col, cell in enumerate(cells):
            if cell != piece:
                continue
            start = chr(ord('a') + col) + str(8 - row)
            if from_file and start[0] != from_file:
                continue
            if from_rank and start[1] != from_rank:
                continue
            if _piece_can_reach(board, piece_type, start, end, is_white_turn):
                candidates.append(start)
    # Pins are not modelled, so an undisambiguated move whose second candidate
    # is pinned is (wrongly) rejected here as ambiguous.
    if len(candidates) != 1:
        return None
    return piece, candidates[0], end, promotion


def validate_move(board, move, is_white_turn):
    """Return True if SAN ``move`` can be played by the side to move on ``board``.

    Castling is accepted when king and rook stand on their home squares with
    empty squares between them. Checks, pins, castling rights and en-passant
    timing are not tracked, so some illegal moves are still accepted.
    """
    plan = _castling_plan(move, is_white_turn)
    if plan is not None:
        king_from, _, rook_from, _, must_be_empty = plan
        king, rook = ('K', 'R') if is_white_turn else ('k', 'r')
        return (get_piece(board, king_from) == king
                and get_piece(board, rook_from) == rook
                and all(get_piece(board, square) == '.' for square in must_be_empty))
    return _resolve_move(board, move, is_white_turn) is not None


def update_board(board, move, is_white_turn):
    """Apply SAN ``move`` to ``board`` in place and return the same board.

    Handles castling (king and rook), en-passant removal and promotion.

    Raises:
        ValueError: if the move cannot be resolved to a unique origin square
            (call validate_move first).
    """
    plan = _castling_plan(move, is_white_turn)
    if plan is not None:
        king_from, king_to, rook_from, rook_to, _ = plan
        set_piece(board, king_to, get_piece(board, king_from))
        set_piece(board, king_from, '.')
        set_piece(board, rook_to, get_piece(board, rook_from))
        set_piece(board, rook_from, '.')
        return board

    resolved = _resolve_move(board, move, is_white_turn)
    if resolved is None:
        raise ValueError(f'Illegal or unparsable move: {move!r}')
    piece, start, end, promotion = resolved
    if piece.upper() == 'P' and start[0] != end[0] and get_piece(board, end) == '.':
        # En passant: the captured pawn sits on the mover's rank, in the destination file.
        set_piece(board, end[0] + start[1], '.')
    if promotion:
        piece = promotion if is_white_turn else promotion.lower()
    set_piece(board, end, piece)
    set_piece(board, start, '.')
    return board
def validate_pgn(pgn_string):
    """
    Validate a PGN game: tag-pair syntax, presence of a result token, and the
    legality of every move when replayed from the initial position.

    Brace comments ({...}) are removed before move parsing, and both white
    ("12.") and black ("12...") move numbers are recognised.

    Returns True if every check passes, False otherwise. Exceptions raised by
    validate_move / update_board are not caught here.
    """
    move_pattern = re.compile(
        r'([PNBRQK]?[a-h]?[1-8]?x?[a-h][1-8](=[QRNB])?|O-O(-O)?)[+#]?[!?]*'
    )
    result_pattern = r'(1-0|0-1|1/2-1/2)'
    tag_pattern = re.compile(r'\[([A-Za-z0-9_]+)\s+"([^"]+)"\]')

    pgn_lines = pgn_string.strip().splitlines()
    for line in pgn_lines:
        if line.startswith('[') and not tag_pattern.match(line):
            return False

    moves_section = ' '.join(line for line in pgn_lines if not line.startswith('[')).strip()
    moves_section = re.sub(r'\{[^}]*\}', ' ', moves_section)
    if not re.search(result_pattern, moves_section):
        return False
    moves_section = re.sub(result_pattern, '', moves_section).strip()

    board = initialize_board()
    is_white_turn = True
    # `\d+\.+` consumes both "12." and "12..." so no stray dots become tokens.
    for token in re.split(r'\s+|\d+\.+', moves_section):
        if not token:
            continue
        # fullmatch: a token must be a whole move, not merely start with one.
        if not move_pattern.fullmatch(token):
            return False
        if not validate_move(board, token, is_white_turn):
            return False
        board = update_board(board, token, is_white_turn)
        is_white_turn = not is_white_turn
    return True
# Sample PGN text (tag pairs plus movetext ending in a result).
# Nothing in this file reads this module-level name; validate_pgn receives its
# input as a parameter.
# NOTE(review): move 39 "exf6" is played by a white pawn on e7 (after 38. e7),
# which cannot capture toward f6, and "34. Re2" looks ambiguous between the
# rooks on c2 and e1. Confirm before using this as a "valid game" fixture.
pgn_string = """
[Event "World Championship"]
[Site "Moscow URS"]
[Date "1985.11.09"]
[Round "16"]
[White "Kasparov, Garry"]
[Black "Karpov, Anatoly"]
[Result "1-0"]
1. e4 e5 2. Nf3 Nc6 3. Bb5 a6 4. Ba4 Nf6 5. O-O Be7 6. Re1 b5 7. Bb3 d6
8. c3 O-O 9. h3 Nb8 10. d4 Nbd7 11. c4 Bb7 12. Nbd2 c6 13. Bc2 Re8 14. b3 Bf8
15. Bb2 Qc7 16. Rc1 Rad8 17. a3 Qb8 18. Bd3 g6 19. Qc2 Nh5 20. g3 Ng7 21. Qb1
exd4 22. Nxd4 c5 23. N4f3 Ne6 24. Bf1 Ne5 25. Qa1 Nxf3+ 26. Nxf3 Qa8 27. b4
Rc8 28. Bd3 Bh6 29. Rc2 Bc6 30. h4 f5 31. exf5 Bxf3 32. fxe6 Bh1 33. Bf1 Qf3
34. Re2 Bg7 35. Kh2 Rc7 36. Bxg7 Rxg7 37. Qf6 bxc4 38. e7 Qxf6 39. exf6 1-0
"""
def get_batch_from_samples(split):
    """Sample a batch of (input, target) windows from individual games.

    ``batch_size`` games are drawn (with replacement) from the split. For each
    game, a window of ``context_size`` tokens starts at a random offset. Windows
    that run past the end of the game are right-padded with the space token,
    and targets are the inputs shifted by one position.

    Returns:
        (x, y): int64 tensors of shape (batch_size, context_size) on ``device``.
    """
    samples = train_data if split == 'train' else val_data
    sample_idx = torch.randint(len(samples), (batch_size,))
    pad = encode(' ')[0]
    inputs, targets = [], []
    for idx in sample_idx:
        sample = samples[idx]
        sample_size = len(sample)
        # Offset range as before, but floored at 1: torch.randint needs high > 0,
        # and games of <= 2 tokens previously crashed here.
        high = max(1, sample_size - 2, sample_size - context_size)
        start = int(torch.randint(high, (1,)))
        window = sample[start:start + context_size].tolist()
        shifted = sample[start + 1:start + context_size + 1].tolist()
        inputs.append(torch.tensor(window + [pad] * (context_size - len(window))))
        targets.append(torch.tensor(shifted + [pad] * (context_size - len(shifted))))
    x = torch.stack(inputs)
    y = torch.stack(targets)
    return x.to(device), y.to(device)
def get_batch(split):
    """Draw ``batch_size`` random contiguous windows from the token stream.

    Returns (inputs, targets) on ``device``: inputs are ``context_size``-long
    slices and targets are the same slices shifted one token to the right.
    """
    source = val_data if split != 'train' else train_data
    starts = torch.randint(len(source) - context_size, (batch_size,))
    inputs = torch.stack([source[s:s + context_size] for s in starts])
    targets = torch.stack([source[s + 1:s + 1 + context_size] for s in starts])
    return inputs.to(device), targets.to(device)
# Per-game data (a list of tensors) needs the padded per-sample batcher.
get_batch = get_batch if mask_all_data else get_batch_from_samples
def print_sample(input_value=None):
    """Generate up to 250 tokens from a prompt and print the decoded text.

    Args:
        input_value: prompt ids shaped (1, T) on ``device``; defaults to a
            single token with id 0.

    The printed text is cut right after the first '<E>' marker, if present.
    """
    prompt = input_value if input_value is not None else torch.zeros((1, 1), dtype=torch.long, device=device)
    print('Validation sample:')
    generated = model.generate(prompt, max_new_tokens=250, context_size=context_size)
    text = decode(generated[0].tolist())
    marker_at = text.find('<E>')
    if marker_at != -1:
        text = text[:marker_at + 3]
    print(text)
@torch.no_grad()
def estimate_loss():
    """Average the model loss over ``eval_iters`` batches for each split.

    The model is put in eval mode while measuring and returned to train mode
    afterwards. A sample continuation of a fixed opening prompt is printed.

    Returns:
        dict mapping 'train' and 'val' to a 0-d tensor with the mean loss.
    """
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            _, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    input_string = '1. e4 g6 2.'
    prompt = torch.tensor(encode(input_string), dtype=torch.long, device=device).view((1, len(input_string)))
    print_sample(prompt)
    model.train()
    return out
class CosineAnnealingScheduler(_LRScheduler):
    """Cosine-annealing learning-rate schedule.

    For step t the learning rate follows
        eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2,
    computed incrementally from each param group's current lr. The step
    just after each half-period boundary (where the ratio's denominator would
    be zero) is handled by a dedicated branch.

    Requires the module-level ``import warnings``; without it, calling
    ``get_lr()`` outside ``step()`` raised NameError instead of warning.
    """

    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
        """
        Args:
            optimizer (Optimizer): Wrapped optimizer.
            T_max (int): Maximum number of iterations.
            eta_min (float): Minimum learning rate. Default: 0.
            last_epoch (int): The index of last epoch. Default: -1.
        """
        self.T_max = T_max
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate for each param group at ``self.last_epoch``."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)
        if self.last_epoch == 0:
            return [group['lr'] for group in self.optimizer.param_groups]
        if self._step_count == 1 and self.last_epoch > 0:
            # Resuming from a non-zero epoch: use the closed form directly.
            return [self.eta_min + (base_lr - self.eta_min) *
                    (1 + math.cos(self.last_epoch * math.pi / self.T_max)) / 2
                    for base_lr in self.base_lrs]
        if (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
            # Just past a minimum: the ratio below would divide by zero here.
            return [group['lr'] + (base_lr - self.eta_min) *
                    (1 - math.cos(math.pi / self.T_max)) / 2
                    for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)]
        return [(1 + math.cos(math.pi * self.last_epoch / self.T_max)) /
                (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) *
                (group['lr'] - self.eta_min) + self.eta_min
                for group in self.optimizer.param_groups]
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--load', '-l', action='store_true', default=False, help='Load model state.')
    parser.add_argument('--inference', '-i', action='store_true', default=False, help='Run only inference')
    args = parser.parse_args()

    params = {'vocab_size': vocab_size, 'n_embed': n_embed, 'context_size': context_size,
              'n_layer': n_layer, 'n_head': n_head, 'dropout': dropout}
    model_path = f'./models/{base_name}'

    m = DecoderTransformer(vocab_size, n_embed, context_size, n_layer, n_head, dropout)
    if args.load:
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
        m.load_state_dict(torch.load(model_path, map_location=device))
    # Module-level name used by print_sample / estimate_loss.
    model = m.to(device)

    if args.inference:
        input_string = input('Enter a PGN string: ')
        print_sample(torch.tensor(encode(input_string), dtype=torch.long, device=device).view((1, len(input_string))))
        os.makedirs('./models', exist_ok=True)
        with open(f'./models/{base_name}_params.json', 'w') as f:
            json.dump(params, f)
        tokenizer.save_pretrained(f'./models/{base_name}_vocab.json')
        raise SystemExit

    wandb.init(project='chessPT')
    wandb.watch(model)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    if use_scheduler:
        # True division; floor division (`//`) would make eta_min 0.0 for any lr < 1e6.
        scheduler = CosineAnnealingScheduler(optimizer, max_iters, eta_min=learning_rate / 1e6)

    for step in tqdm(range(max_iters), total=max_iters, desc='Training'):
        if step % eval_interval == 0:
            losses = estimate_loss()
            message = f'step {step:4d}: train loss {losses["train"]:.4f}, val loss: {losses["val"]:.4f}'
            if use_scheduler:
                message += f', lr: {scheduler.get_last_lr()[0]}'
            print(message)
            wandb.log({'train_loss': losses['train'], 'val_loss': losses['val']})

        xb, yb = get_batch('train')
        _, loss = model(xb, yb)

        metrics = {'running_train_loss': loss.item()}
        if use_scheduler:
            metrics['lr'] = scheduler.get_last_lr()[0]
        wandb.log(metrics)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        if use_scheduler:
            scheduler.step()

    print()
    print('Loss:')
    print(loss.item())
    # Create the output directory so saving does not fail at the end of a long run.
    os.makedirs('./models', exist_ok=True)
    torch.save(model.state_dict(), model_path)
    with open(f'./models/{base_name}_params.json', 'w') as f:
        json.dump(params, f)
    with open('train.log', 'a') as f:
        f.write(f'{max_iters},{learning_rate}\n')