import argparse
import os

import torch
import torch.nn as nn
from torch.nn import functional as F

from gpt_p.model import DecoderTransformer
from datasets import load_dataset

torch.manual_seed(420)

base_name = 'gpt-p_CHARS_CHAT_'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
context_size = 256        # maximum number of characters the model attends over
batch_size = 128          # sequences per training batch
max_iters = 30_000        # total optimization steps
learning_rate = 3e-5
eval_interval = 100       # evaluate every N steps
eval_iters = 20           # batches averaged per evaluation
n_embed = 384             # embedding width
n_layer = 6               # transformer blocks
n_head = 6                # attention heads per block
dropout = 0.2

dataset = load_dataset('Lichess/standard-chess-games', split='train')
# Keep only games without embedded engine evaluations and join them into one text blob.
content = '\n'.join(filter(lambda x: 'eval' not in x, dataset['movetext']))
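
# Each 'movetext' entry is a PGN-style move sequence such as
#   1. e4 e5 2. Nf3 Nc6 3. Bb5 a6 ...
# Lichess exports can annotate moves with engine evaluations (e.g. '{ [%eval 0.17] }');
# the 'eval' filter above is assumed to be there to screen those games out.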

# Character-level vocabulary built from every distinct character in the corpus.
book = content
characters = sorted(list(set(book)))
vocab_size = len(characters)

stoi = {ch: idx for idx, ch in enumerate(characters)}
itos = {idx: ch for idx, ch in enumerate(characters)}

encode = lambda s: [stoi[c] for c in s]
decode = lambda i: ''.join([itos[x] for x in i])
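
# Round-trip sketch for the character codec (the integer values are illustrative only;
# the real indices depend on which characters occur in the corpus):
#
#   encode('1. e4')          -> e.g. [3, 7, 0, 42, 25]
#   decode(encode('1. e4'))  -> '1. e4'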

# Encode the whole corpus and hold out the last 10% for validation.
data = torch.tensor(encode(book), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]


def get_batch(split):
    data = train_data if split == 'train' else val_data
    idx = torch.randint(len(data) - context_size, (batch_size,))
    x = torch.stack([data[i:i + context_size] for i in idx])
    y = torch.stack([data[i + 1:i + context_size + 1] for i in idx])
    return x.to(device), y.to(device)
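
# Shape sketch under the settings above: x and y are both (batch_size, context_size)
# = (128, 256) LongTensors, with y shifted one character to the right of x, so
# y[b, t] is the target character after the model has seen x[b, :t + 1].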


def print_sample(input_value=None):
    if input_value is None:
        input_value = torch.zeros((1, 1), dtype=torch.long, device=device)
    print('Validation sample:')
    sample = decode(model.generate(input_value, max_new_tokens=250, context_size=context_size)[0].tolist())
    if '<E>' in sample:
        sample = sample[:sample.find('<E>') + 3]
    print(sample)
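
# Note: model.generate is assumed (from gpt_p.model's DecoderTransformer) to take the
# prompt indices plus max_new_tokens and context_size and to return the prompt together
# with the sampled characters; '<E>' is treated here as an end-of-sample marker, so
# anything generated past it is cut off.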


@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()

    input_string = '1. e4 g6'
    print_sample(torch.tensor(encode(input_string), dtype=torch.long, device=device).view((1, len(input_string))))
    model.train()
    return out


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--load', '-l', action='store_true', default=False, help='Load model state.')
    parser.add_argument('--inference', '-i', action='store_true', default=False, help='Run only inference.')

    args = parser.parse_args()
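
    # Typical invocations (script name assumed for illustration):
    #   python train.py              # train from scratch
    #   python train.py --load       # resume training from a saved checkpoint
    #   python train.py --load -i    # load a checkpoint, then exit before training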

    params = {'vocab_size': vocab_size, 'n_embed': n_embed, 'context_size': context_size, 'n_layer': n_layer, 'n_head': n_head, 'dropout': dropout}
    model_path = f'./models/{base_name}' + ''.join(f'{key}={v}' for key, v in params.items())

    m = DecoderTransformer(vocab_size, n_embed, context_size, n_layer, n_head, dropout)
    if args.load:
        m.load_state_dict(torch.load(model_path, map_location=device))
    model = m.to(device)

    if args.inference:
        exit()

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    for step in range(max_iters):
        if step % eval_interval == 0:
            losses = estimate_loss()
            print(f'step {step:4d}: train loss {losses["train"]:.4f}, val loss: {losses["val"]:.4f}')

        xb, yb = get_batch('train')

        logits, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    print()
    print('Loss:')
    print(loss.item())
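
    # The checkpoint filename concatenates every hyperparameter, e.g. (assuming the
    # corpus yields vocab_size=32):
    #   ./models/gpt-p_CHARS_CHAT_vocab_size=32n_embed=384context_size=256n_layer=6n_head=6dropout=0.2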
    os.makedirs('./models', exist_ok=True)  # make sure the checkpoint directory exists
    torch.save(model.state_dict(), model_path)
    with open('train.log', 'a') as f:
        f.write(f'{max_iters},{learning_rate}\n')