from dataclasses import dataclass

import math
import os
import signal
import sys

import matplotlib.pyplot as plt
import tiktoken
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader, DistributedSampler, Subset

from model import GPT, GPTConfig


def signal_handler(sig, frame):
    # Ctrl-C: shut the (possibly distributed) run down cleanly instead of leaving
    # a dangling process group behind.
    print('Gracefully stopping the training process')
    if dist.is_initialized():
        destroy_process_group()
    sys.exit(0)


signal.signal(signal.SIGINT, signal_handler)

torch.manual_seed(1337)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1337)


device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = torch.device("mps")

# Keep the bare device *type* ("cuda", "mps" or "cpu") around separately: under DDP,
# `device` is replaced below by a rank-specific string like "cuda:1", and comparing
# that against torch.device("cuda") would silently be False.
device_type = device.type

print("Using device:", device)

enc = tiktoken.get_encoding('gpt2')

lossi = []
val_lossi = []

with open("tinyshakespeare.txt", "r") as f:
    text = f.read()
tokens = enc.encode(text)
print(f"Number of tokens: {len(tokens):,}")


# Distributed data parallel (DDP) setup. torchrun sets RANK/LOCAL_RANK/WORLD_SIZE;
# a plain `python` launch leaves them unset and we fall through to single-process mode.
ddp = int(os.environ.get('RANK', -1)) != -1
if ddp:
    assert torch.cuda.is_available(), "DDP currently requires CUDA"
    init_process_group(backend='nccl')
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'cuda:{ddp_local_rank}'
    device_type = 'cuda'
    torch.cuda.set_device(device)

    # Only rank 0 does logging, plotting and checkpointing.
    master_process = ddp_rank == 0
else:
    ddp_rank = 0
    ddp_local_rank = 0
    ddp_world_size = 1
    master_process = True

if master_process:
    print(f"ddp: {ddp}, rank: {ddp_rank}, local_rank: {ddp_local_rank}, world_size: {ddp_world_size}, master_process: {master_process}")


gpt = GPT(GPTConfig(vocab_size=50304), master_process).to(device)
if device_type == "cuda":
    # Module.compile() compiles the forward pass in place; only worthwhile on CUDA here.
    gpt.compile()
if ddp:
    gpt = DDP(gpt, device_ids=[ddp_local_rank])

# The DDP wrapper only forwards the forward pass; custom methods such as
# configure_optimizers (and a clean state_dict) live on the underlying module.
raw_gpt = gpt.module if ddp else gpt
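
# model.py is not shown here; the API this script assumes (inferred from how it is
# called below) is:
#   GPT(config, master_process)                                    -- construction
#   gpt(x, y) -> (logits, loss), gpt(x) -> (logits, None)          -- forward pass
#   gpt.configure_optimizers(weight_decay, learning_rate, device)  -- returns an optimizer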


class ShakespeareDataset(Dataset):

    def __init__(self, tokens, seq_len):
        self.tokens = tokens
        self.seq_len = seq_len

    def __len__(self):
        return len(self.tokens) - self.seq_len - 1

    def __getitem__(self, idx):
        x = torch.tensor(self.tokens[idx:idx + self.seq_len], dtype=torch.long)
        y = torch.tensor(self.tokens[idx + 1:idx + self.seq_len + 1], dtype=torch.long)
        return x, y
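
# Example: with seq_len = 8, dataset[0] returns
#   x = tokens[0:8]   (input context)
#   y = tokens[1:9]   (the same window shifted by one -- the next-token targets)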


def split_dataset(dataset, val_ratio=0.0005):
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split = int(val_ratio * dataset_size)

    train_indices, val_indices = indices[split:], indices[:split]
    train_dataset = Subset(dataset, train_indices)
    val_dataset = Subset(dataset, val_indices)

    return train_dataset, val_dataset
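
# Note: this is a contiguous split -- the first val_ratio fraction of windows is held
# out for validation and the rest is used for training. Adjacent windows overlap by
# seq_len - 1 tokens, so a few tokens near the boundary appear in both splits.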


T = 8                    # sequence length (tokens per sample)
batch_size = 4           # micro-batch size per forward/backward pass
total_batch_size = 2**8  # desired tokens per optimizer step, summed over all ranks
assert total_batch_size % (T * batch_size * ddp_world_size) == 0, \
    "total_batch_size must be divisible by T * batch_size * world_size"
grad_accum_steps = total_batch_size // (T * batch_size * ddp_world_size)

if master_process:
    print("Total desired batch size: {:,}".format(total_batch_size))
    print("gradient accumulation steps: {:,}".format(grad_accum_steps))


dataset = ShakespeareDataset(tokens, T)
train_dataset, val_dataset = split_dataset(dataset)

if ddp:
    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler)
else:
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

if master_process:
    print(f"The training dataloader has {len(train_dataloader):,} individual batches")
    print(f"The validation dataloader has {len(val_dataloader):,} individual batches")


def generate_text(seed_text, model, enc, max_len=100, print_while_generating=True):
    model.eval()
    with torch.no_grad():
        tokens = enc.encode(seed_text)
        for _ in range(max_len):
            x = torch.tensor(tokens[-T:], dtype=torch.long,
                             device=device).unsqueeze(0)
            logits, _ = model(x)
            next_token = torch.argmax(logits[:, -1, :])
            tokens.append(int(next_token))

            if print_while_generating:
                print(enc.decode([int(next_token)]), end="")
        print()

    return enc.decode(tokens)
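
# Usage example -- decoding is greedy (argmax over the last position), so the output is
# deterministic for a given model and seed text:
#   print(generate_text("The king said", raw_gpt, enc, max_len=50, print_while_generating=False))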


# raw_gpt is the plain GPT module in both the DDP and single-process cases.
optimizer = raw_gpt.configure_optimizers(
    weight_decay=0.1, learning_rate=6e-4, device=device)
torch.set_float32_matmul_precision('high')


max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 10
max_steps = 20000


def get_lr(step):
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    if step > max_steps:
        return min_lr
    decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (max_lr - min_lr)
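
# Schedule shape, for these constants: get_lr(0) = 6e-5 (first warmup step),
# get_lr(9) = 6e-4 (warmup done), get_lr(10005) = 3.3e-4 (cosine midpoint),
# and the final steps decay to roughly min_lr = 6e-5.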


supports_bfloat16 = False
if device_type == "cuda":
    capability = torch.cuda.get_device_capability()
    if capability[0] >= 8:  # Ampere (SM 8.0) or newer
        supports_bfloat16 = True
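
# bfloat16 autocast needs compute capability >= 8.0 (Ampere and newer).
# torch.cuda.is_bf16_supported() is an alternative, slightly more general check.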


generate_every = 50
validate_every = 5

# A persistent iterator over the training data: a next(iter(dataloader)) call on every
# micro-step would rebuild (and reshuffle) the loader each time.
train_iter = iter(train_dataloader)
epoch = 0


def next_batch():
    global train_iter, epoch
    try:
        return next(train_iter)
    except StopIteration:
        epoch += 1
        if ddp:
            train_sampler.set_epoch(epoch)  # reshuffle the shards for the new pass
        train_iter = iter(train_dataloader)
        return next(train_iter)


for step in range(max_steps):
    gpt.zero_grad()
    loss_accum = 0.0
    for minibatchstep in range(grad_accum_steps):
        x, y = next_batch()
        x, y = x.to(device), y.to(device)

        if supports_bfloat16:
            with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                logits, loss = gpt(x, y)
        else:
            logits, loss = gpt(x, y)

        # Scale so the accumulated gradient matches one large batch.
        loss = loss / grad_accum_steps
        loss_accum += loss.detach()
        if ddp:
            # Only sync gradients across ranks on the last micro-step.
            gpt.require_backward_grad_sync = (minibatchstep == grad_accum_steps - 1)
        loss.backward()

    if ddp:
        dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
    lossi.append(loss_accum.item())
    norm = torch.nn.utils.clip_grad_norm_(gpt.parameters(), 1.0)
    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    optimizer.step()

    if master_process:
        print(f'Step {step}, Loss: {loss_accum.item():.4f}, Norm: {norm.item():.4f}')

    if step % generate_every == 0 and master_process:
        print(generate_text("The king said", gpt, enc, max_len=25, print_while_generating=False))
        gpt.train()  # generate_text leaves the model in eval mode

    if step % validate_every == 0:
        if master_process:
            print("Validating...")
        gpt.eval()
        val_loss_accum = 0.0
        with torch.no_grad():
            for val_x, val_y in val_dataloader:
                val_x, val_y = val_x.to(device), val_y.to(device)
                if supports_bfloat16:
                    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                        val_logits, val_loss = gpt(val_x, val_y)
                else:
                    val_logits, val_loss = gpt(val_x, val_y)

                val_loss_accum += val_loss.detach()

        if ddp:
            dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG)
        val_loss_avg = val_loss_accum / len(val_dataloader)
        val_lossi.append(val_loss_avg.item())
        if master_process:
            print(f'Validation Loss: {val_loss_avg.item():.4f}')
        gpt.train()
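
# Training curve: lossi holds one (rank-averaged) loss per optimizer step; val_lossi holds
# one averaged validation loss per validation pass and could be overlaid on the same axes.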


if master_process:
    plt.plot(lossi)
    plt.show()


if master_process:
    generate_text("The king said", gpt, enc, max_len=25)


if master_process:
    # Save the unwrapped module so the checkpoint keys carry no DDP 'module.' prefix.
    torch.save(raw_gpt.state_dict(), "gpt2_shakespeare.pth")
    torch.save(torch.tensor(lossi), "lossi.pth")
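
# To reload the weights later (same GPT/GPTConfig construction as above, outside DDP):
#   model = GPT(GPTConfig(vocab_size=50304), True)
#   model.load_state_dict(torch.load("gpt2_shakespeare.pth", map_location="cpu"))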


if ddp:
    destroy_process_group()

sys.exit(0)