import argparse
import torch
import torch.nn as nn
from torch.nn import functional as F
from gpt_p.model import DecoderTransformer
from datasets import load_dataset
torch.manual_seed(420)  # fixed seed for reproducible runs
base_name = 'gpt-p_CHARS_CHAT_'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
context_size = 256   # how many tokens to consider while generating the next
batch_size = 128     # how many independent sequences are processed in parallel
max_iters = 30_000   # total number of optimizer steps
learning_rate = 3e-5
eval_interval = 100  # evaluate every eval_interval steps
eval_iters = 20      # number of batches averaged per loss estimate
n_embed = 384        # embedding size
n_layer = 6          # number of transformer layers
n_head = 6           # number of attention heads per layer
dropout = 0.2        # dropout factor
# movetext is a PGN move sequence, e.g. '1. e4 e5 2. Nf3 ...'; games with engine
# evaluations embed '[%eval ...]' comments, so any game mentioning 'eval' is dropped
dataset = load_dataset('Lichess/standard-chess-games', split='train')
content = '\n'.join(filter(lambda x: 'eval' not in x, dataset['movetext']))
## BUILD DATA SET ##
book = content
characters = sorted(set(book))  # character vocabulary of the corpus
vocab_size = len(characters)
# character-level tokenizer: map each character to an integer index and back
stoi = {ch: idx for idx, ch in enumerate(characters)}
itos = {idx: ch for idx, ch in enumerate(characters)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda i: ''.join(itos[x] for x in i)
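# round-trip sanity check (assumes every prompt character occurs in the corpus):
# decode(encode('1. e4 e5 2. Nf3')) == '1. e4 e5 2. Nf3'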
data = torch.tensor(encode(book), dtype=torch.long)
# 90/10 train/validation split over the encoded corpus
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]
def get_batch(split):
    """Sample a random batch of inputs and next-character targets."""
    data = train_data if split == 'train' else val_data
    idx = torch.randint(len(data) - context_size, (batch_size,))
    x = torch.stack([data[i:i + context_size] for i in idx])
    y = torch.stack([data[i + 1:i + context_size + 1] for i in idx])  # inputs shifted by one
    return x.to(device), y.to(device)
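# e.g. x, y = get_batch('train') yields x.shape == y.shape == (batch_size, context_size),
# where y[b, t] is the character that follows x[b, t] in the corpus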
## END BUILD DATA SET ##
## MODEL DEFINITION ##
def print_sample(input_value=None):
    """Generate a continuation from the model and print it."""
    if input_value is None:
        # default prompt: a single zero token
        input_value = torch.zeros((1, 1), dtype=torch.long, device=device)
    print('Validation sample:')
    sample = decode(model.generate(input_value, max_new_tokens=250, context_size=context_size)[0].tolist())
    if '<E>' in sample:
        # cut off after the end-of-game marker ('<E>' is 3 characters long)
        sample = sample[:sample.find('<E>') + 3]
    print(sample)
@torch.no_grad()
def estimate_loss():
    """Average the loss over eval_iters batches for both splits."""
    out = {}
    model.eval()  # disable dropout during evaluation
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    # print a continuation of a short opening to eyeball training progress
    input_string = '1. e4 g6'
    print_sample(torch.tensor(encode(input_string), dtype=torch.long, device=device).view((1, len(input_string))))
    model.train()
    return out
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--load', '-l', action='store_true', default=False, help='Load model state.')
    parser.add_argument('--inference', '-i', action='store_true', default=False, help='Run only inference')
    args = parser.parse_args()
    params = {'vocab_size': vocab_size, 'n_embed': n_embed, 'context_size': context_size, 'n_layer': n_layer, 'n_head': n_head, 'dropout': dropout}
    m = DecoderTransformer(vocab_size, n_embed, context_size, n_layer, n_head, dropout)
    if args.load:
        # checkpoint filename encodes the hyperparameters
        m.load_state_dict(torch.load(f'./models/{base_name}' + ''.join(f'{key}={v}' for key, v in params.items())))
    model = m.to(device)
    if args.inference:
        print_sample()  # sample once, then skip training entirely
        exit()
    ## END MODEL ##

    ## START TRAINING ##
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    for step in range(max_iters):
        # periodically report the smoothed train/val loss
        if step % eval_interval == 0:
            losses = estimate_loss()
            print(f'step {step:4d}: train loss {losses["train"]:.4f}, val loss: {losses["val"]:.4f}')
        # sample a batch, evaluate the loss, and take one optimizer step
        xb, yb = get_batch('train')
        logits, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
    print()
    print('Loss:')
    print(loss.item())
    ## END TRAINING ##
    ## START VALIDATION ##
    ## END VALIDATION ##
    # save model weights; the filename encodes the hyperparameters
    torch.save(model.state_dict(), f'./models/{base_name}' + ''.join(f'{key}={v}' for key, v in params.items()))
    with open('train.log', 'a') as f:
        f.write(f'{max_iters},{learning_rate}\n')
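# Usage (assumed invocation; adjust the script name if it differs):
#   python train.py              # train from scratch and save weights
#   python train.py --load       # resume training from the saved checkpoint
#   python train.py --load -i    # load the checkpoint, print a sample, and exit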