# training_shakespeare.py
# Train a small GPT-2-style model on Tiny Shakespeare, optionally with DDP.
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import GPT, GPTConfig
import tiktoken
from torch.utils.data import Dataset, DataLoader, DistributedSampler, Subset
import math
import matplotlib.pyplot as plt
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import os
import signal
import sys
def signal_handler(sig, frame):
    print('Gracefully stopping the training process')
    # Only tear down the process group if one was initialized (DDP runs).
    if dist.is_available() and dist.is_initialized():
        destroy_process_group()
    sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)
torch.manual_seed(1337)
if torch.cuda.is_available():
torch.cuda.manual_seed(1337)
# ***************************#
# Device Configuration
# ***************************#
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
device = torch.device("mps")
print("Using device:", device)
# ***************************#
# Tokenizer Setup
# ***************************#
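# GPT-2 byte-pair encoder from tiktoken (50,257-token vocabulary)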
enc = tiktoken.get_encoding('gpt2')
lossi = []      # per-step training losses
val_lossi = []  # per-validation-pass average losses
# ***************************#
# Load Text Data
# ***************************#
with open("tinyshakespeare.txt", "r") as f:
text = f.read()
tokens = enc.encode(text)
print(f"Number of tokens: {len(tokens):,}")
# ***************************#
# Set up DDP
# ***************************#
# torchrun command sets the env variables RANK, LOCAL_RANK, and WORLD_SIZE
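# Example launch (single node, 8 GPUs):
#   torchrun --standalone --nproc_per_node=8 training_shakespeare.py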
ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
if ddp:
# use of DDP atm demands CUDA, we set the device appropriately according to rank
    assert torch.cuda.is_available(), "DDP training currently requires CUDA"
init_process_group(backend='nccl')
ddp_rank = int(os.environ['RANK'])
ddp_local_rank = int(os.environ['LOCAL_RANK'])
ddp_world_size = int(os.environ['WORLD_SIZE'])
device = f'cuda:{ddp_local_rank}'
torch.cuda.set_device(device)
# this process will do logging, checkpointing etc.
master_process = ddp_rank == 0
else:
# vanilla, non-DDP run
ddp_rank = 0
ddp_local_rank = 0
ddp_world_size = 1
master_process = True
if master_process:
print(f"ddp: {ddp}, rank: {ddp_rank}, local_rank: {ddp_local_rank}, world_size: {ddp_world_size}, master_process: {master_process}")
# ***************************#
# Model Configuration
# ***************************#
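# vocab_size is padded from GPT-2's 50,257 up to 50,304 (a multiple of 128)
# so the embedding and lm-head matmuls land on friendlier kernel tile sizes.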
gpt = GPT(GPTConfig(vocab_size=50304), master_process).to(device)
# Note: under DDP, `device` is the string 'cuda:N', not a torch.device,
# so compare on the string form.
if str(device).startswith("cuda"):
    gpt.compile()  # compile the module in place (PyTorch 2.x)
if ddp:
gpt = DDP(gpt, device_ids=[ddp_local_rank])
raw_gpt = gpt.module if ddp else gpt  # unwrapped module, for optimizer config and checkpointing
# ***************************#
# Dataset and Dataloader
# ***************************#
class ShakespeareDataset(Dataset):
def __init__(self, tokens, seq_len):
self.tokens = tokens
self.seq_len = seq_len
    def __len__(self):
        # Each sample needs seq_len input tokens plus one shifted target token.
        return len(self.tokens) - self.seq_len
def __getitem__(self, idx):
x = torch.tensor(self.tokens[idx:idx + self.seq_len], dtype=torch.long)
y = torch.tensor(self.tokens[idx + 1:idx + self.seq_len + 1], dtype=torch.long)
return x, y
# Split the dataset into training and validation sets.
# Samples are overlapping windows; the first `split` windows become the
# validation set, so neighbouring train/val windows share some tokens.
def split_dataset(dataset, val_ratio=0.0005):
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(val_ratio * dataset_size)
train_indices, val_indices = indices[split:], indices[:split]
train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)
return train_dataset, val_dataset
T = 8             # sequence length (tokens per sample)
batch_size = 4    # micro-batch size per process
total_batch_size = 2**8  # desired batch size in tokens (2**8 = 256; a full GPT-2 run would use 2**19 = 524,288)
assert total_batch_size % (T * batch_size * ddp_world_size) == 0, \
    "total_batch_size must be divisible by T * batch_size * world_size"
grad_accum_steps = total_batch_size // (T * batch_size * ddp_world_size)
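# e.g. with T=8, batch_size=4, world_size=1: 256 // (8*4*1) = 8 micro-steps per optimizer step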
if master_process:
print("Total desired batch size: {:,}".format(total_batch_size))
print("gradient accumulation steps: {:,}".format(grad_accum_steps))
dataset = ShakespeareDataset(tokens, T)
train_dataset, val_dataset = split_dataset(dataset)
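# Under DDP, DistributedSampler hands each rank a disjoint shard of the samples.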
if ddp:
train_sampler = DistributedSampler(train_dataset)
val_sampler = DistributedSampler(val_dataset)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler)
else:
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
if master_process:
print(f"The training dataloader has {len(train_dataloader):,} individual batches")
print(f"The validation dataloader has {len(val_dataloader):,} individual batches")
# ***************************#
# Text Generation Function
# ***************************#
def generate_text(seed_text, model, enc, max_len=100, print_while_generating=True):
    was_training = model.training
    model.eval()
    with torch.no_grad():
        tokens = enc.encode(seed_text)
        for _ in range(max_len):
            # Condition on at most the last T tokens (the model's context length).
            x = torch.tensor(tokens[-T:], dtype=torch.long,
                             device=device).unsqueeze(0)
            logits, _ = model(x)
            # Greedy decoding: always pick the most likely next token.
            next_token = torch.argmax(logits[:, -1, :])
            tokens.append(int(next_token))
            if print_while_generating:
                print(enc.decode([int(next_token)]), end="")
        print()
    # Restore training mode if we were called mid-training.
    if was_training:
        model.train()
    return enc.decode(tokens)
# ***************************#
# Optimizer Configuration
# ***************************#
# configure_optimizers lives on the raw module (the DDP wrapper doesn't expose it);
# raw_gpt already aliases the plain module in the non-DDP case.
optimizer = raw_gpt.configure_optimizers(
    weight_decay=0.1, learning_rate=6e-4, device=device)
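# Allow TF32 matmuls: faster fp32 GEMMs on Ampere+ GPUs at slightly reduced precision.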
torch.set_float32_matmul_precision('high')
# ***************************#
# Learning Rate Scheduler
# ***************************#
max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 10
max_steps = 20000
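# Schedule: linear warmup for warmup_steps, then cosine decay from max_lr to min_lr.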
def get_lr(step):
if step < warmup_steps:
return max_lr * (step+1) / warmup_steps
if step > max_steps:
return min_lr
decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
assert 0 <= decay_ratio <= 1
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
return min_lr + coeff * (max_lr - min_lr)
# Check whether the device supports bfloat16 autocast
# (CUDA compute capability >= 8.0, i.e. Ampere and newer).
supports_bfloat16 = False
if torch.cuda.is_available() and str(device).startswith("cuda"):
    if torch.cuda.get_device_capability()[0] >= 8:
        supports_bfloat16 = True
# ***************************#
# Training Loop
# ***************************#
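# Cadence (in optimizer steps) for sampling text and for running validation.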
generate_every = 50
validate_every = 5
train_iter = iter(train_dataloader)
epoch = 0
for step in range(max_steps):
    optimizer.zero_grad(set_to_none=True)
    loss_accum = 0.0
    for minibatchstep in range(grad_accum_steps):
        # Advance a persistent iterator; a fresh iter() on every micro-step
        # would re-read the first batch forever. Restart the loader (and
        # re-shuffle the DistributedSampler) when an epoch ends.
        try:
            x, y = next(train_iter)
        except StopIteration:
            epoch += 1
            if ddp:
                train_sampler.set_epoch(epoch)
            train_iter = iter(train_dataloader)
            x, y = next(train_iter)
        x, y = x.to(device), y.to(device)
        if supports_bfloat16:
            with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                logits, loss = gpt(x, y)
        else:
            logits, loss = gpt(x, y)
        # Scale so the accumulated gradient equals the mean over the full batch.
        loss = loss / grad_accum_steps
        loss_accum += loss.detach()
        if ddp:
            # Only sync grads across ranks on the final micro-step; earlier
            # micro-steps accumulate locally without communication.
            gpt.require_backward_grad_sync = (minibatchstep == grad_accum_steps - 1)
        loss.backward()
    if ddp:
        # Average the logged loss across ranks (for reporting only).
        dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
    lossi.append(loss_accum.item())
    norm = torch.nn.utils.clip_grad_norm_(gpt.parameters(), 1.0)
    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    optimizer.step()
    if master_process:
        print(f'Step {step}, Loss: {loss_accum.item():.4f}, Norm: {norm:.4f}, LR: {lr:.2e}')
if step % generate_every == 0 and master_process:
print(generate_text("The king said", gpt, enc, max_len=25, print_while_generating=False))
    # Validation step
    if step % validate_every == 0:
        if master_process:
            print("Validating...")
        gpt.eval()
        val_loss_accum = 0.0
        with torch.no_grad():
            for val_x, val_y in val_dataloader:
                val_x, val_y = val_x.to(device), val_y.to(device)
                if supports_bfloat16:
                    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                        val_logits, val_loss = gpt(val_x, val_y)
                else:
                    val_logits, val_loss = gpt(val_x, val_y)
                val_loss_accum += val_loss.detach()
        if ddp:
            # Average the summed validation loss across ranks.
            dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG)
        # Record the per-batch average, after the cross-rank reduce.
        val_loss_avg = val_loss_accum / len(val_dataloader)
        val_lossi.append(val_loss_avg.item())
        if master_process:
            print(f'Validation Loss: {val_loss_avg.item():.4f}')
        gpt.train()
# ***************************#
# Plot Loss
# ***************************#
if master_process:
plt.plot(lossi)
plt.show()
# Generate Final Text
if master_process:
generate_text("The king said", gpt, enc, max_len=25)
# ***************************#
# Save Model and Loss
# ***************************#
if master_process:
    # Save the unwrapped module's weights so keys don't carry the DDP "module." prefix.
    torch.save(raw_gpt.state_dict(), "gpt2_shakespeare.pth")
    torch.save(torch.tensor(lossi), "lossi.pth")
# ***************************#
# Cleanup
# ***************************#
if ddp:
    destroy_process_group()
sys.exit(0)