import argparse
import os

import torch
import torch.nn as nn
from torch.nn import functional as F
from datasets import load_dataset

from gpt_p.model import DecoderTransformer

torch.manual_seed(420)  # 1337

base_name = 'gpt-p_CHARS_CHAT_'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Hyperparameters
context_size = 256    # maximum number of tokens considered when predicting the next one
batch_size = 128      # number of independent sequences processed in parallel
max_iters = 30_000
learning_rate = 3e-5
eval_interval = 100
eval_iters = 20       # number of evaluation iterations
n_embed = 384         # embedding size
n_layer = 6           # number of transformer layers
n_head = 6            # number of attention heads
dropout = 0.2         # dropout factor

# Load chess games and drop move texts that contain engine evaluations.
dataset = load_dataset('Lichess/standard-chess-games', split='train')
content = '\n'.join(list(filter(lambda x: 'eval' not in x, dataset['movetext'])))

## BUILD DATA SET ##
book = content

# Character-level vocabulary and the mappings between characters and token ids.
characters = sorted(list(set(book)))
vocab_size = len(characters)
stoi = {ch: idx for idx, ch in enumerate(characters)}
itos = {idx: ch for idx, ch in enumerate(characters)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda i: ''.join([itos[x] for x in i])

# 90/10 train/validation split.
data = torch.tensor(encode(book), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]


def get_batch(split):
    """Sample a batch of contiguous sequences and their next-token targets."""
    data = train_data if split == 'train' else val_data
    idx = torch.randint(len(data) - context_size, (batch_size,))
    x = torch.stack([data[i:i + context_size] for i in idx])
    y = torch.stack([data[i + 1:i + context_size + 1] for i in idx])
    return x.to(device), y.to(device)
## END BUILD DATA SET ##


## MODEL DEFINITION ##
def print_sample(input_value=None):
    """Generate and print a short sample continuation from the model."""
    if input_value is None:
        input_value = torch.zeros((1, 1), dtype=torch.long, device=device)
    print('Validation sample:')
    sample = decode(model.generate(input_value, max_new_tokens=250, context_size=context_size)[0].tolist())
    # Truncate at the end-of-sample marker. The marker string appears to have been
    # lost in the original source (the check was against ''); set it here if known.
    end_marker = ''
    if end_marker and end_marker in sample:
        sample = sample[:sample.find(end_marker) + len(end_marker)]
    print(sample)


@torch.no_grad()
def estimate_loss():
    """Estimate the mean train/val loss over eval_iters batches and print a sample."""
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    input_string = '1. e4 g6'
    print_sample(torch.tensor(encode(input_string), dtype=torch.long, device=device).view((1, len(input_string))))
    model.train()
    return out


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--load', '-l', action='store_true', default=False, help='Load model state.')
    parser.add_argument('--inference', '-i', action='store_true', default=False, help='Run only inference.')
    args = parser.parse_args()

    params = {'vocab_size': vocab_size, 'n_embed': n_embed, 'context_size': context_size,
              'n_layer': n_layer, 'n_head': n_head, 'dropout': dropout}
    model_path = f'./models/{base_name}' + ''.join(f'{key}={v}' for key, v in params.items())

    m = DecoderTransformer(vocab_size, n_embed, context_size, n_layer, n_head, dropout)
    if args.load:
        m.load_state_dict(torch.load(model_path))
    model = m.to(device)

    if args.inference:
        # Inference-only mode: generate a sample from the loaded model, then exit.
        print_sample()
        exit()
    ## END MODEL ##

    ## START TRAINING ##
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    for step in range(max_iters):
        if step % eval_interval == 0:
            losses = estimate_loss()
            print(f'step {step:4d}: train loss {losses["train"]:.4f}, val loss: {losses["val"]:.4f}')

        xb, yb = get_batch('train')

        logits, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    print()
    print('Loss:')
    print(loss.item())
    ## END TRAINING ##

    ## START VALIDATION ##
    ## END VALIDATION ##

    # Save model weights and append the run's hyperparameters to the training log.
    os.makedirs('./models', exist_ok=True)
    torch.save(model.state_dict(), model_path)

    with open('train.log', 'a') as f:
        f.write(f'{max_iters},{learning_rate}\n')
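
# Example invocations (assuming this script is saved as train.py; the filename is
# illustrative and not given in the source):
#   python train.py              # train from scratch and save weights to ./models/
#   python train.py --load       # resume training from previously saved weights
#   python train.py --load -i    # load weights and only print a generated sample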