# LightGPT instruction-tune.py
# Author: Andrew DalPino
# Instruction-tune a pretrained LightGPT checkpoint on the SmolTalk dataset.
import random
from argparse import ArgumentParser
import torch
from torch.utils.data import DataLoader
from torch.optim import Adafactor
from torch.amp import autocast
from torch.cuda import is_available as cuda_is_available, is_bf16_supported
from torch.utils.data import random_split
from torch.utils.tensorboard import SummaryWriter
from torchmetrics.text import Perplexity
from model import LightGPT, LightGPTInstruct
from data import SmolTalk
import tiktoken
from tqdm import tqdm


def main():
parser = ArgumentParser(description="Instruction-tune the GPT.")
parser.add_argument(
"--base_model_path", default="./checkpoints/checkpoint.pt", type=str
)
parser.add_argument("--max_tokens_per_sample", default=1048, type=int)
parser.add_argument("--mask_input", action="store_true")
parser.add_argument("--batch_size", default=1, type=int)
parser.add_argument("--gradient_accumulation_steps", default=64, type=int)
parser.add_argument("--learning_rate", default=5e-4, type=float)
parser.add_argument("--rms_decay", default=-0.8, type=float)
parser.add_argument("--optimizer_low_memory", action="store_true")
parser.add_argument("--num_epochs", default=4, type=int)
parser.add_argument("--rank", default=8, type=int)
parser.add_argument("--alpha", default=1.0, type=float)
parser.add_argument("--dropout", default=0.05, type=float)
parser.add_argument("--activation_checkpointing", action="store_true")
parser.add_argument("--eval_interval", default=1, type=int)
parser.add_argument("--checkpoint_interval", default=1, type=int)
parser.add_argument(
"--checkpoint_path", default="./checkpoints/lora_instruction.pt", type=str
)
parser.add_argument("--resume", action="store_true")
parser.add_argument("--run_dir_path", default="./runs/instruction-tune", type=str)
parser.add_argument("--device", default="cuda", type=str)
parser.add_argument("--seed", default=None, type=int)
args = parser.parse_args()
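
    # Fail fast if a CUDA device was requested but CUDA is not available.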
if "cuda" in args.device and not cuda_is_available():
        raise RuntimeError("CUDA is not available.")
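
    # Allow TF32 matmul kernels and prefer bfloat16 autocast where the hardware
    # supports it, otherwise fall back to full float32.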
torch.set_float32_matmul_precision("high")
dtype = (
torch.bfloat16
if "cuda" in args.device and is_bf16_supported()
else torch.float32
)
amp_context = autocast(device_type=args.device, dtype=dtype)
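
    # Seed the RNGs for reproducible runs when a seed is provided.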
    if args.seed is not None:
torch.manual_seed(args.seed)
random.seed(args.seed)
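
    # TensorBoard writer for training and evaluation metrics.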
logger = SummaryWriter(args.run_dir_path)
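
    # Load the base model checkpoint; it stores the model hyperparameters and the
    # name of the tiktoken encoding to use.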
checkpoint = torch.load(
args.base_model_path, map_location=args.device, weights_only=True
)
model_args = checkpoint["model_args"]
tokenizer = tiktoken.get_encoding(checkpoint["token_encoding"])
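
    # The SmolTalk chat dataset, tokenized with the base model's tokenizer and
    # capped at max_tokens_per_sample tokens per sample, with 10% held out for evaluation.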
dataset = SmolTalk(
tokenizer,
subset="all",
max_tokens_per_sample=args.max_tokens_per_sample,
)
training, testing = random_split(dataset, (0.9, 0.1))
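
    # Batch samples with the dataset's collate function; pin host memory when
    # training on a GPU.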
train_loader = DataLoader(
training,
collate_fn=dataset.collate,
batch_size=args.batch_size,
pin_memory="cpu" not in args.device,
shuffle=True,
)
test_loader = DataLoader(
testing,
collate_fn=dataset.collate,
batch_size=args.batch_size,
pin_memory="cpu" not in args.device,
shuffle=False,
)
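
    # Rebuild the base model and compile it before loading the weights; the base
    # checkpoint is presumably saved from a compiled model, so the compiled
    # parameter names match.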
model = LightGPT(**model_args)
if args.activation_checkpointing:
model.enable_activation_checkpointing()
model = torch.compile(model)
model.load_state_dict(checkpoint["model"])
print("Model checkpoint loaded")
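
    # LoRA adapter hyperparameters used to wrap the base model for instruction-tuning.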
lora_args = {
"rank": args.rank,
"alpha": args.alpha,
"dropout": args.dropout,
}
model = LightGPTInstruct(model, **lora_args).to(args.device)
print("Compiling model")
model.compile()
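
    # Adafactor keeps optimizer state small; foreach mode is faster but uses more
    # memory, so it is disabled by --optimizer_low_memory.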
optimizer = Adafactor(
model.parameters(),
lr=args.learning_rate,
beta2_decay=args.rms_decay,
foreach=not args.optimizer_low_memory,
)
starting_epoch = 1
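
    # Optionally resume the LoRA weights, optimizer state, and epoch counter from
    # a previous instruction-tuning checkpoint.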
if args.resume:
checkpoint = torch.load(
args.checkpoint_path, map_location=args.device, weights_only=True
)
model.load_state_dict(checkpoint["lora"], strict=False)
optimizer.load_state_dict(checkpoint["optimizer"])
starting_epoch += checkpoint["epoch"]
print("Previous checkpoint resumed successfully")
model.train()
print(f"Model has {model.num_trainable_params:,} trainable parameters")
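
    # Perplexity on the held-out split, ignoring padded positions.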
perplexity_metric = Perplexity(ignore_index=dataset.PADDING_INDEX).to(args.device)
print("Instruction-tuning ...")
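
    # Outer loop over epochs; the average per-batch cross entropy is logged at the
    # end of each epoch.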
for epoch in range(starting_epoch, args.num_epochs + 1):
total_cross_entropy, total_batches = 0.0, 0
for step, (x, y) in enumerate(
tqdm(train_loader, desc=f"Epoch {epoch}", leave=False), start=1
):
x = x.to(args.device, non_blocking=True)
y = y.to(args.device, non_blocking=True)
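
            # Forward and backward pass under mixed precision; the loss is scaled so
            # the accumulated gradients average over the accumulation window.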
with amp_context:
_, loss = model(x, y)
scaled_loss = loss / args.gradient_accumulation_steps
scaled_loss.backward()
total_cross_entropy += loss.item()
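
            # Step the optimizer and reset the gradients once a full accumulation
            # window has been processed.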
if step % args.gradient_accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad(set_to_none=True)
total_batches += 1
average_cross_entropy = total_cross_entropy / total_batches
logger.add_scalar("cross entropy", average_cross_entropy, epoch)
print(
f"Epoch {epoch}: Cross Entropy: {average_cross_entropy:.5f}",
)
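
        # Periodically evaluate perplexity on the held-out split.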
if epoch % args.eval_interval == 0:
model.eval()
for x, y in tqdm(test_loader, desc="Testing", leave=False):
x = x.to(args.device, non_blocking=True)
y = y.to(args.device, non_blocking=True)
with torch.no_grad():
y_pred, _ = model(x)
perplexity_metric.update(y_pred, y)
perplexity = perplexity_metric.compute()
logger.add_scalar("perplexity", perplexity, epoch)
print(f"Perplexity: {perplexity:.3f}")
perplexity_metric.reset()
model.train()
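
        # Periodically save the LoRA weights and optimizer state so tuning can be
        # resumed later.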
if epoch % args.checkpoint_interval == 0:
checkpoint = {
"epoch": epoch,
"lora_args": lora_args,
"lora": model.state_dict(),
"optimizer": optimizer.state_dict(),
}
torch.save(checkpoint, args.checkpoint_path)
print("Checkpoint saved")
print("Done!")


if __name__ == "__main__":
main()