import argparse
import json
import math
import re
import warnings  # used by CosineAnnealingScheduler.get_lr below

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.optim.lr_scheduler import _LRScheduler
from tqdm import tqdm
from datasets import load_dataset
import wandb

from gpt_p.model import DecoderTransformer


torch.manual_seed(420)

base_name = 'gpt-p_CHARS_CHAT_'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

context_size = 256
batch_size = 128
max_iters = 50_000
learning_rate = 3e-5
eval_interval = 100
eval_iters = 20

n_embed = 384
n_layer = 6
n_head = 6
dropout = 0.2
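# Rough size check (assumes the usual GPT-style block with a 4x MLP expansion,
# which may differ from gpt_p.model.DecoderTransformer): each layer carries
# about 12 * n_embed^2 ~= 1.77M weights, so 6 layers come to roughly 10.6M
# parameters, plus ~0.1M position embeddings and a small character-level
# token embedding.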

mask_all_data = True
use_scheduler = False

dataset = load_dataset('Lichess/standard-chess-games', '2014-08', split='train')
og_samples = list(filter(lambda x: 'eval' not in x, dataset['movetext']))

new_dataset = load_dataset('Lichess/standard-chess-games', '2024-07', split='train', data_files=[f'data/year=2024/month=07/train-{str(i).zfill(5)}-of-00384.parquet' for i in range(10)])

# Strip {...} annotations and "N..." continuation numbers from the 2024 games,
# then collapse the double spaces left behind.
new_dataset = [re.sub(r'[0-9]+\.\.\.', '', re.sub(r'\{[^}]*\}', '', foo)).replace('  ', ' ').replace('  ', ' ') for foo in new_dataset['movetext']]

og_samples += new_dataset

if mask_all_data:
    # Join the 2014 + 2024 samples into one training string.
    content = '\n'.join(og_samples)
else:
    content = og_samples

print('Data loaded')
unit = 'characters' if mask_all_data else 'games'
print(f'Training on {len(content)} {unit}. Good luck!')

book = content
if mask_all_data:
    characters = sorted(list(set(book)))
else:
    characters = sorted(list(set('\n'.join(book))))
vocab_size = len(characters)

class Tokenizer:
    def __init__(self, vocab):
        self.vocab = vocab
        self.stoi = {ch: idx for idx, ch in enumerate(vocab)}
        self.itos = {idx: ch for idx, ch in enumerate(vocab)}

    def encode(self, s):
        return [self.stoi[c] for c in s]

    def decode(self, i):
        return ''.join([self.itos[x] for x in i])

    @classmethod
    def from_pretrained(cls, path):
        with open(path, 'r') as f:
            vocab = json.load(f)
        return cls(vocab)

    def save_pretrained(self, path):
        with open(path, 'w') as f:
            json.dump(self.vocab, f)


tokenizer = Tokenizer(characters)
encode = tokenizer.encode
decode = tokenizer.decode
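# Quick sanity check of the character-level tokenizer (illustrative only):
#
#   ids = encode('1. e4 e5')      # one int per character
#   assert decode(ids) == '1. e4 e5'
#
# Any character outside `characters` would raise a KeyError in encode().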

if mask_all_data:
    data = torch.tensor(encode(book), dtype=torch.long)
else:
    data = [torch.tensor(encode(s), dtype=torch.long) for s in book]
    max_len = max(len(x) for x in og_samples)
    context_size = min(context_size, max_len)


n = int(0.8 * len(data))
train_data = data[:n]
val_data = data[n:]


PIECE_VALUES = {
    'P': 1, 'N': 3, 'B': 3, 'R': 5, 'Q': 9, 'K': 0,
    'p': 1, 'n': 3, 'b': 3, 'r': 5, 'q': 9, 'k': 0
}
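# PIECE_VALUES is not used elsewhere in this script. A minimal sketch of how it
# could score material from a board produced by initialize_board() below
# (illustrative only, not part of training):
#
#   def material_balance(board):
#       """White material minus black material, in pawns."""
#       score = 0
#       for row in board:
#           for piece in row:
#               if piece == '.':
#                   continue
#               value = PIECE_VALUES[piece]
#               score += value if piece.isupper() else -value
#       return score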

def initialize_board():
    """Initializes the standard chessboard setup."""
    return [
        ['r', 'n', 'b', 'q', 'k', 'b', 'n', 'r'],
        ['p', 'p', 'p', 'p', 'p', 'p', 'p', 'p'],
        ['.', '.', '.', '.', '.', '.', '.', '.'],
        ['.', '.', '.', '.', '.', '.', '.', '.'],
        ['.', '.', '.', '.', '.', '.', '.', '.'],
        ['.', '.', '.', '.', '.', '.', '.', '.'],
        ['P', 'P', 'P', 'P', 'P', 'P', 'P', 'P'],
        ['R', 'N', 'B', 'Q', 'K', 'B', 'N', 'R']
    ]


def get_piece(board, position):
    """Returns the piece at a given board position (e.g., e4 -> 'P' or '.')."""
    col = ord(position[0]) - ord('a')
    row = 8 - int(position[1])
    return board[row][col]


def set_piece(board, position, piece):
    """Sets a piece on the board at a given position."""
    col = ord(position[0]) - ord('a')
    row = 8 - int(position[1])
    board[row][col] = piece
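# Coordinate convention used above: files a-h map to columns 0-7 and ranks are
# stored top-down, so square 'e2' lands at board[6][4]. For example, on the
# initial board get_piece(initialize_board(), 'e2') returns 'P' and 'a8'
# returns 'r'.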

def validate_pawn_move(board, start, end, is_white_turn):
    """Validates pawn advances and captures (en passant and promotion are not modelled)."""
    start_col, start_row = ord(start[0]) - ord('a'), 8 - int(start[1])
    end_col, end_row = ord(end[0]) - ord('a'), 8 - int(end[1])

    pawn_direction = -1 if is_white_turn else 1

    # Straight advances: one square, or two from the starting rank if the
    # jumped-over square is also empty.
    if start_col == end_col and board[end_row][end_col] == '.':
        if start_row + pawn_direction == end_row:
            return True
        if (is_white_turn and start_row == 6 or not is_white_turn and start_row == 1) and start_row + 2 * pawn_direction == end_row and board[start_row + pawn_direction][start_col] == '.':
            return True

    # Diagonal captures.
    if abs(start_col - end_col) == 1 and start_row + pawn_direction == end_row:
        target_piece = board[end_row][end_col]
        if (is_white_turn and target_piece.islower()) or (not is_white_turn and target_piece.isupper()):
            return True

    return False


def validate_knight_move(start, end):
    """Validates knight movement (L-shape)."""
    start_col, start_row = ord(start[0]) - ord('a'), 8 - int(start[1])
    end_col, end_row = ord(end[0]) - ord('a'), 8 - int(end[1])

    col_diff = abs(start_col - end_col)
    row_diff = abs(start_row - end_row)

    return (col_diff == 2 and row_diff == 1) or (col_diff == 1 and row_diff == 2)

def validate_rook_move(board, start, end):
    """Validates rook movement (straight lines along rank or file)."""
    start_col, start_row = ord(start[0]) - ord('a'), 8 - int(start[1])
    end_col, end_row = ord(end[0]) - ord('a'), 8 - int(end[1])

    if start_col != end_col and start_row != end_row:
        return False

    if start_col == end_col:
        step = 1 if end_row > start_row else -1
        for row in range(start_row + step, end_row, step):
            if board[row][start_col] != '.':
                return False
    else:
        step = 1 if end_col > start_col else -1
        for col in range(start_col + step, end_col, step):
            if board[start_row][col] != '.':
                return False

    return True


def validate_bishop_move(board, start, end):
    """Validates bishop movement (diagonals)."""
    start_col, start_row = ord(start[0]) - ord('a'), 8 - int(start[1])
    end_col, end_row = ord(end[0]) - ord('a'), 8 - int(end[1])

    if abs(start_col - end_col) != abs(start_row - end_row):
        return False

    col_step = 1 if end_col > start_col else -1
    row_step = 1 if end_row > start_row else -1
    col, row = start_col + col_step, start_row + row_step
    while col != end_col and row != end_row:
        if board[row][col] != '.':
            return False
        col += col_step
        row += row_step

    return True

def validate_move(board, move, is_white_turn):
    """Validates a move based on the current board state.

    Note: SAN moves only name the destination square, so `start` and `end`
    below are both set to the destination. The per-piece geometry checks are
    therefore only a rough plausibility filter, not full legality checking.
    """
    if move == "O-O" or move == "O-O-O":
        return True

    piece_type = 'P' if move[0].islower() else move[0]
    # Strip promotion and check/mate suffixes before reading the destination.
    destination = re.sub(r'(=[QRNB])?[+#]?$', '', move)[-2:]
    start = destination
    end = destination

    if piece_type == 'P':
        return validate_pawn_move(board, start, end, is_white_turn)
    elif piece_type == 'N':
        return validate_knight_move(start, end)
    elif piece_type == 'R':
        return validate_rook_move(board, start, end)
    elif piece_type == 'B':
        return validate_bishop_move(board, start, end)

    # Queen and king moves (and anything unrecognised) are accepted as-is.
    return True


def update_board(board, move, is_white_turn):
    """Updates the board according to the move.

    With only the destination square known (see validate_move), this cannot
    actually relocate the moving piece; it is kept as a lightweight placeholder
    for a proper SAN-aware update.
    """
    if move in ("O-O", "O-O-O"):
        # Castling is not tracked on this simplified board.
        return board

    destination = re.sub(r'(=[QRNB])?[+#]?$', '', move)[-2:]
    start = destination
    end = destination
    piece = get_piece(board, start)

    set_piece(board, end, piece)
    set_piece(board, start, '.')

    return board

def validate_pgn(pgn_string):
    """
    Validates the PGN string format and (approximately) chess move legality.
    """
    move_pattern = r'([PNBRQK]?[a-h]?[1-8]?[x]?[a-h][1-8](=[QRNB])?|O-O(-O)?)[+#]?'
    result_pattern = r'(1-0|0-1|1/2-1/2)'
    tag_pattern = r'\[([A-Za-z0-9_]+)\s+"([^"]+)"\]'

    pgn_lines = pgn_string.strip().splitlines()

    # Every tag pair line ('[Key "Value"]') must be well formed.
    tags = [line for line in pgn_lines if line.startswith('[')]
    for tag in tags:
        if not re.match(tag_pattern, tag):
            return False

    moves_section = ' '.join([line for line in pgn_lines if not line.startswith('[')]).strip()

    # The game must end with a result marker.
    if not re.search(result_pattern, moves_section):
        return False

    moves_section = re.sub(result_pattern, '', moves_section).strip()

    board = initialize_board()
    is_white_turn = True

    # Drop move numbers ("1.", "2.", ...) and split the remaining SAN tokens.
    move_tokens = re.split(r'\s|\d+\.', moves_section)
    for token in move_tokens:
        if token:
            # fullmatch so that tokens with trailing garbage are rejected.
            if not re.fullmatch(move_pattern, token):
                return False

            if not validate_move(board, token, is_white_turn):
                return False

            board = update_board(board, token, is_white_turn)
            is_white_turn = not is_white_turn

    return True

pgn_string = """
[Event "World Championship"]
[Site "Moscow URS"]
[Date "1985.11.09"]
[Round "16"]
[White "Kasparov, Garry"]
[Black "Karpov, Anatoly"]
[Result "1-0"]

1. e4 e5 2. Nf3 Nc6 3. Bb5 a6 4. Ba4 Nf6 5. O-O Be7 6. Re1 b5 7. Bb3 d6
8. c3 O-O 9. h3 Nb8 10. d4 Nbd7 11. c4 Bb7 12. Nbd2 c6 13. Bc2 Re8 14. b3 Bf8
15. Bb2 Qc7 16. Rc1 Rad8 17. a3 Qb8 18. Bd3 g6 19. Qc2 Nh5 20. g3 Ng7 21. Qb1
exd4 22. Nxd4 c5 23. N4f3 Ne6 24. Bf1 Ne5 25. Qa1 Nxf3+ 26. Nxf3 Qa8 27. b4
Rc8 28. Bd3 Bh6 29. Rc2 Bc6 30. h4 f5 31. exf5 Bxf3 32. fxe6 Bh1 33. Bf1 Qf3
34. Re2 Bg7 35. Kh2 Rc7 36. Bxg7 Rxg7 37. Qf6 bxc4 38. e7 Qxf6 39. exf6 1-0
"""

def get_batch_from_samples(split):
    """Builds an (x, y) batch from individual games, padding short windows with spaces."""
    data = train_data if split == 'train' else val_data
    sample_idx = torch.randint(len(data), (batch_size,))
    inputs = []
    outputs = []
    space = encode(' ')[0]
    for idx in sample_idx:
        sample_size = len(data[idx])
        # Random start offset; max(1, ...) guards very short games, since
        # torch.randint needs a positive upper bound.
        start = torch.randint(max(1, sample_size - 2), (1,)).item()
        end = start + context_size
        i1 = data[idx][start:end].tolist()
        i2 = [space] * (context_size - len(i1))
        input_sample = torch.tensor(i1 + i2)
        o1 = data[idx][start + 1:end + 1].tolist()
        o2 = [space] * (context_size - len(o1))
        output_sample = torch.tensor(o1 + o2)

        inputs.append(input_sample)
        outputs.append(output_sample)

    x = torch.stack(inputs)
    y = torch.stack(outputs)
    return x.to(device), y.to(device)

def get_batch(split):
    data = train_data if split == 'train' else val_data
    idx = torch.randint(len(data) - context_size, (batch_size,))
    x = torch.stack([data[i:i+context_size] for i in idx])
    y = torch.stack([data[i+1:i+context_size+1] for i in idx])
    return x.to(device), y.to(device)


if not mask_all_data:
    get_batch = get_batch_from_samples
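# Either batching path returns two LongTensors of shape (batch_size,
# context_size) on `device`; y is x shifted one character to the left, e.g.:
#
#   xb, yb = get_batch('train')
#   assert xb.shape == (batch_size, context_size)
#   assert torch.equal(xb[0, 1:], yb[0, :-1])   # holds for the contiguous path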

def print_sample(input_value=None):
    if input_value is None:
        input_value = torch.zeros((1, 1), dtype=torch.long, device=device)
    print('Validation sample:')
    sample = decode(model.generate(input_value, max_new_tokens=250, context_size=context_size)[0].tolist())
    if '<E>' in sample:
        sample = sample[:sample.find('<E>') + 3]
    print(sample)

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            # Disabled experiment: penalise generations that fail validate_pgn.
            # (See the matching commented-out block in the training loop.)
            # input_string = X[0].tolist()
            # gen = model.generate(X[0].view(1, -1), max_new_tokens=5, context_size=context_size)
            # o = tokenizer.decode(gen[0].tolist())
            # try:
            #     valid = int(not validate_pgn(o))
            # except Exception:
            #     valid = 2
            losses[k] = loss.item()
        out[split] = losses.mean()

    input_string = '1. e4 g6 2.'
    print_sample(torch.tensor(encode(input_string), dtype=torch.long, device=device).view((1, len(input_string))))
    model.train()
    return out

class CosineAnnealingScheduler(_LRScheduler):
    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
        """
        Args:
            optimizer (Optimizer): Wrapped optimizer.
            T_max (int): Maximum number of iterations.
            eta_min (float): Minimum learning rate. Default: 0.
            last_epoch (int): The index of last epoch. Default: -1.
        """
        self.T_max = T_max
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0:
            return [group['lr'] for group in self.optimizer.param_groups]
        elif self._step_count == 1 and self.last_epoch > 0:
            return [self.eta_min + (base_lr - self.eta_min) *
                    (1 + math.cos(self.last_epoch * math.pi / self.T_max)) / 2
                    for base_lr in self.base_lrs]
        elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
            return [group['lr'] + (base_lr - self.eta_min) *
                    (1 - math.cos(math.pi / self.T_max)) / 2
                    for base_lr, group in
                    zip(self.base_lrs, self.optimizer.param_groups)]
        return [(1 + math.cos(math.pi * self.last_epoch / self.T_max)) /
                (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) *
                (group['lr'] - self.eta_min) + self.eta_min
                for group in self.optimizer.param_groups]
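# This mirrors torch.optim.lr_scheduler.CosineAnnealingLR, stepped once per
# training iteration. A minimal usage sketch (illustrative; the real wiring
# happens in the training loop below, and eta_min here is a made-up value):
#
#   optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
#   scheduler = CosineAnnealingScheduler(optimizer, T_max=max_iters, eta_min=1e-7)
#   for _ in range(max_iters):
#       ...  # forward / backward / optimizer.step()
#       scheduler.step()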


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--load', '-l', action='store_true', default=False, help='Load model state.')
    parser.add_argument('--inference', '-i', action='store_true', default=False, help='Run only inference')

    args = parser.parse_args()

    params = {'vocab_size': vocab_size, 'n_embed': n_embed, 'context_size': context_size, 'n_layer': n_layer, 'n_head': n_head, 'dropout': dropout}
    m = DecoderTransformer(vocab_size, n_embed, context_size, n_layer, n_head, dropout)
    if args.load:
        m.load_state_dict(torch.load(f'./models/{base_name}'))
    model = m.to(device)

    if args.inference:
        input_string = input('Enter a PGN string: ')
        print_sample(torch.tensor(encode(input_string), dtype=torch.long, device=device).view((1, len(input_string))))
        with open(f'./models/{base_name}_params.json', 'w') as f:
            json.dump(params, f)

        tokenizer.save_pretrained(f'./models/{base_name}_vocab.json')
        exit()

    wandb.init(project='chessPT')
    wandb.watch(model)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    if use_scheduler:
        # True division here: `//` would floor the tiny target LR down to 0.
        scheduler = CosineAnnealingScheduler(optimizer, max_iters, eta_min=learning_rate / 1e6)

    for step in tqdm(range(max_iters), total=max_iters, desc='Training'):
        if step % eval_interval == 0:
            losses = estimate_loss()
            if use_scheduler:
                print(f'step {step:4d}: train loss {losses["train"]:.4f}, val loss: {losses["val"]:.4f}, lr: {scheduler.get_last_lr()[0]}')
            else:
                print(f'step {step:4d}: train loss {losses["train"]:.4f}, val loss: {losses["val"]:.4f}')
            wandb.log({'train_loss': losses['train'], 'val_loss': losses['val']})

        xb, yb = get_batch('train')

        logits, loss = model(xb, yb)
        # Disabled experiment: add a penalty when the continuation fails validate_pgn.
        # input_string = xb[0].tolist()
        # gen = model.generate(xb[0].view(1, -1), max_new_tokens=5, context_size=context_size)
        # out = tokenizer.decode(gen[0].tolist())
        # try:
        #     valid = int(not validate_pgn(out))
        # except Exception:
        #     valid = 2
        # loss += valid

        if use_scheduler:
            wandb.log({'running_train_loss': loss.item(), 'lr': scheduler.get_last_lr()[0]})
        else:
            wandb.log({'running_train_loss': loss.item()})

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        if use_scheduler:
            scheduler.step()

    print()
    print('Loss:')
    print(loss.item())

    torch.save(model.state_dict(), f'./models/{base_name}')
    with open(f'./models/{base_name}_params.json', 'w') as f:
        json.dump(params, f)
    with open('train.log', 'a') as f:
        f.write(f'{max_iters},{learning_rate}\n')