# chessPT / train.py
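# Trains a small character-level decoder-only transformer (gpt_p.model.DecoderTransformer)
# on Lichess standard-game movetext and periodically prints a sampled continuation
# during evaluation. Flags: --load/-l resumes from a saved checkpoint,
# --inference/-i builds the model and exits without training.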
import argparse
import torch
import torch.nn as nn
from torch.nn import functional as F
from gpt_p.model import DecoderTransformer
from datasets import load_dataset
torch.manual_seed(420) # 1337
base_name = 'gpt-p_CHARS_CHAT_'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
context_size = 256  # how many tokens of context the model sees when predicting the next one
batch_size = 128  # how many independent sequences are processed in parallel
max_iters = 30_000
learning_rate = 3e-5
eval_interval = 100
eval_iters = 20  # number of evaluation iterations
n_embed = 384 # embedding size
n_layer = 6 # number of transformer layers
n_head = 6
dropout = 0.2 # dropout factor
dataset = load_dataset('Lichess/standard-chess-games', split='train')
# Keep only games whose movetext carries no embedded engine evaluation annotations,
# then join them into one long training string.
content = '\n'.join(list(filter(lambda x: 'eval' not in x, dataset['movetext'])))
## BUILD DATA SET ##
book = content
characters = sorted(list(set(book)))
vocab_size = len(characters)
# character <-> integer id lookup tables for the character-level vocabulary
stoi = {ch: idx for idx, ch in enumerate(characters)}
itos = {idx: ch for idx, ch in enumerate(characters)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda i: ''.join([itos[x] for x in i])
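# Illustrative sanity check (not required for training): the character-level codec
# round-trips exactly for any string whose characters occur in the corpus, which
# holds for standard PGN movetext such as the probe below.
assert decode(encode('1. e4 e5')) == '1. e4 e5'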
data = torch.tensor(encode(book), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]
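# The first 90% of the concatenated corpus is used for training; the remaining 10%
# is held out for validation loss estimates.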
def get_batch(split):
    data = train_data if split == 'train' else val_data
    idx = torch.randint(len(data) - context_size, (batch_size,))
    x = torch.stack([data[i:i+context_size] for i in idx])
    y = torch.stack([data[i+1:i+context_size+1] for i in idx])
    return x.to(device), y.to(device)
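# get_batch returns two LongTensors of shape (batch_size, context_size) on `device`;
# y is x shifted by one position, so the target at index t is the character that
# follows x[:, t] in the corpus.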
## END BUILD DATA SET ##
## MODEL DEFINITION ##
def print_sample(input_value=None):
    if input_value is None:
        input_value = torch.zeros((1, 1), dtype=torch.long, device=device)
    print('Validation sample:')
    sample = decode(model.generate(input_value, max_new_tokens=250, context_size=context_size)[0].tolist())
    if '<E>' in sample:
        sample = sample[:sample.find('<E>') + 3]
    print(sample)
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    input_string = '1. e4 g6'
    print_sample(torch.tensor(encode(input_string), dtype=torch.long, device=device).view((1, len(input_string))))
    model.train()
    return out
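# estimate_loss averages the loss over eval_iters random batches per split with the
# model in eval mode (dropout disabled), then prints a sampled continuation of
# '1. e4 g6' as a qualitative check before switching back to training mode.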
if __name__ == "__main__":
    args = argparse.ArgumentParser()
    args.add_argument('--load', '-l', action='store_true', default=False, help='Load model state.')
    args.add_argument('--inference', '-i', action='store_true', default=False, help='Run only inference.')
    args = args.parse_args()

    params = {'vocab_size': vocab_size, 'n_embed': n_embed, 'context_size': context_size, 'n_layer': n_layer, 'n_head': n_head, 'dropout': dropout}

    if args.load:
        m = DecoderTransformer(vocab_size, n_embed, context_size, n_layer, n_head, dropout)
        # map_location keeps checkpoints saved on GPU loadable on CPU-only machines
        m.load_state_dict(torch.load(f'./models/{base_name}' + ''.join(f'{key}={v}' for key, v in params.items()), map_location=device))
    else:
        m = DecoderTransformer(vocab_size, n_embed, context_size, n_layer, n_head, dropout)
    model = m.to(device)

    if args.inference:
        exit()
    ## END MODEL ##

    ## START TRAINING ##
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    for step in range(max_iters):
        if step % eval_interval == 0:
            losses = estimate_loss()
            print(f'step {step:4d}: train loss {losses["train"]:.4f}, val loss: {losses["val"]:.4f}')
        xb, yb = get_batch('train')
        logits, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    print()
    print('Loss:')
    print(loss.item())
    ## END TRAINING ##

    ## START VALIDATION ##
    ## END VALIDATION ##

    # save model weights
    torch.save(model.state_dict(), f'./models/{base_name}' + ''.join([f'{key}={v}' for key, v in params.items()]))

    with open('train.log', 'a') as f:
        f.write(f'{max_iters},{learning_rate}\n')