# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/B2. Training (Lightning).ipynb.

# %% auto 0
__all__ = []

# %% ../nbs/B2. Training (Lightning).ipynb 2
import io
import time
import random
from pathlib import Path

from fastprogress import progress_bar, master_bar
import fastprogress
import wandb

import numpy as np
import pylab as plt

import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from torch.profiler import record_function

# %% ../nbs/B2. Training (Lightning).ipynb 3
import lightning.pytorch as pl
import math


class TrainingTask(pl.LightningModule):
    """LightningModule wrapper that trains an arbitrary ``model``.

    The wrapped model's ``forward(*batch)`` is unpacked as ``(logits, loss)``
    in every step below. ``model_hparams`` is a dict read by
    ``configure_optimizers`` (``lr0``, ``weight_decay``, ``epochs``,
    ``warmup_steps``).
    """

    def __init__(self, model, model_hparams=None):
        super().__init__()
        self.model = model
        self.model_hparams = model_hparams

    def on_fit_start(self):
        """Call the model's optional ``setup(device)`` hook when fitting starts."""
        # A default is required here: getattr() without one raises
        # AttributeError for models that do not define ``setup``.
        setup = getattr(self.model, 'setup', None)
        if setup:
            setup(self.device)

    def configure_optimizers(self):
        """Initialize the AdamW optimizer and a per-step OneCycleLR schedule.

        Sub-modules that define a ``no_weight_decay`` attribute get weight decay
        0, and sub-modules that define ``lr_scale`` get ``lr0 * lr_scale``. Their
        parameters are grouped by ``(weight_decay, lr)``. All remaining
        parameters go into a single "other" group with the default ``lr0`` /
        ``weight_decay``.

        Side effect: stores the derived ``pct_start`` in ``self.model_hparams``.

        Raises:
            ValueError: if ``epochs * steps_per_epoch`` is not positive.
        """
        lr = self.model_hparams['lr0']
        weight_decay = self.model_hparams['weight_decay']

        # Use the wrapped model rather than a module-level global ``model``.
        all_params = set(self.model.parameters())
        customized_params = set()
        groups = []
        group_map = {}
        for name, m in self.model.named_modules():
            if hasattr(m, 'no_weight_decay') or hasattr(m, 'lr_scale'):
                # NOTE(review): m.parameters() is recursive. Nested customized
                # modules would put the same tensor in two groups, which AdamW
                # rejects.
                customized_params |= set(m.parameters())
                m_wd = 0 if hasattr(m, 'no_weight_decay') else weight_decay
                m_lr = lr * getattr(m, 'lr_scale', 1)
                group = group_map.get((m_wd, m_lr))
                if group is None:
                    group = {"params": [], "names": [], "weight_decay": m_wd, "lr": m_lr}
                    groups.append(group)
                    group_map[(m_wd, m_lr)] = group
                group['params'] += m.parameters()
                group['names'].append(name)

        other_params = all_params - customized_params

        param_groups = groups + [
            {"names": ["other"], "params": list(other_params), "weight_decay": weight_decay},
        ]

        optimizer = torch.optim.AdamW(lr=lr, betas=(0.9, 0.95), params=param_groups)

        # modified from https://github.com/Lightning-AI/lightning/issues/5449#issuecomment-1501597319
        def num_steps_per_epoch() -> int:
            """Return the number of optimizer steps per epoch."""
            # Accessing _data_source is flaky and might break
            train_dl = self.trainer.fit_loop._data_source.dataloader()
            # math.ceil so always overestimate (underestimating throws exceptions)
            return math.ceil(len(train_dl) / self.trainer.accumulate_grad_batches)

        # Compute once: every call re-fetches the dataloader via a private attribute.
        steps_per_epoch = num_steps_per_epoch()
        total_steps = self.model_hparams['epochs'] * steps_per_epoch
        if total_steps <= 0:
            raise ValueError(
                f"OneCycleLR needs a positive number of steps, got "
                f"{self.model_hparams['epochs']} epochs x {steps_per_epoch} steps")
        self.model_hparams['pct_start'] = min(0.3, self.model_hparams['warmup_steps'] / total_steps)
        print(f"{self.model_hparams['epochs']=} epochs x num_steps_per_epoch()={steps_per_epoch} steps")

        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            pct_start=self.model_hparams['pct_start'],
            max_lr=[pg.get('lr', lr) for pg in param_groups],
            steps_per_epoch=steps_per_epoch,
            epochs=int(self.model_hparams['epochs']),
            final_div_factor=25
        )

        return [optimizer], [{'scheduler': lr_scheduler, 'interval': 'step'}]

    def training_step(self, train_batch, batch_idx):
        """Run the model on one training batch and log ``train_loss``."""
        train_logits, train_loss = self.model.forward(*train_batch)

        self.log("train_loss", train_loss, sync_dist=True)
        return train_loss

    def validation_step(self, val_batch, batch_idx):
        """Run the model on one validation batch and log ``val_loss``."""
        val_logits, val_loss = self.model.forward(*val_batch)

        self.log("val_loss", val_loss, sync_dist=True)
        return val_loss

    def on_validation_epoch_end(self):
        """Log the model's optional ``get_metrics()`` dict under ``metrics/``."""
        if hasattr(self.model, 'get_metrics'):
            self.log_dict({'metrics/' + k: v for k, v in self.model.get_metrics().items()}, sync_dist=True)

    def test_step(self, test_batch, batch_idx):
        """Run the model on one test batch and log ``test_loss``."""
        test_logits, test_loss = self.model.forward(*test_batch)

        self.log("test_loss", test_loss, sync_dist=True)
        return test_loss
# %% ../nbs/B2. Training (Lightning).ipynb 4
from fastcore.script import anno_parser
import shlex


# watch out: we can only pass Python values as keyword arguments (not positional)
# everything else has to be a string
def parse_and_call(name, fun, args, kwargs=None, log_to_wandb=True):
    """Parse CLI-style tokens against ``fun``'s annotations and call ``fun``.

    Args:
        name: key under which the call arguments are stored in the W&B run config.
        fun: callable whose annotated signature builds the parser (fastcore
            ``anno_parser``).
        args: list of string tokens (e.g. produced by ``shlex.split``).
        kwargs: extra Python-valued keyword arguments, merged over the parsed ones.
        log_to_wandb: when true and the global ``wandb_logger`` exposes a real
            ``wandb`` Config, record the arguments (except ``dataset`` and
            ``tunables``) there.

    Returns:
        Whatever ``fun`` returns.
    """
    kwargs = {} if kwargs is None else kwargs
    call_args = vars(anno_parser(fun).parse_args(args))
    # anno_parser adds these helper flags; they are not parameters of ``fun``.
    call_args.pop('xtra', None)
    call_args.pop('pdb', None)
    call_args.update(kwargs)
    # NOTE: relies on the module-level ``wandb_logger`` existing before the first call.
    if log_to_wandb and isinstance(wandb_logger.experiment.config, wandb.sdk.wandb_config.Config):
        wandb_logger.experiment.config[name] = {
            k: v for k, v in call_args.items() if k not in ('dataset', 'tunables')
        }
    return fun(**call_args)

# %% ../nbs/B2. Training (Lightning).ipynb 8
import argparse

parser = argparse.ArgumentParser()
# --task is mandatory: shlex.split(None) raises (Python >= 3.12) or reads stdin.
parser.add_argument('--task', type=str, required=True, help='Task to train')
parser.add_argument('--seed', type=int, default=0, help='Global training seed')
parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
parser.add_argument('--input-dir', type=str, default='', help='input data path') # fixed in the model for now
parser.add_argument("--checkpoint-dir", type=str, default="./checkpoints/", help="directory to save the checkpoints")
parser.add_argument('--epochs', type=int, default=10, help='total training epochs')
parser.add_argument('--validate-every-n-steps', type=int, default=500, help='how many training steps to run between validations')
parser.add_argument('--weight-decay', type=float, default=1e-2, help='optimizer weight decay')
parser.add_argument('--lr0', type=float, default=1e-4, help='optimizer initial learning rate')
parser.add_argument('--clip-gradient-norm', type=float, default=None, help='enable gradient norm clipping')
parser.add_argument('--accumulate-grad-batches', type=int, default=1, help='perform the optimizer step only after going through several batches of samples')
parser.add_argument('--precision', type=str, default="16-mixed", help="floating point precision")
parser.add_argument('--warmup-steps', type=int, default=10000, help='total number of steps during which the learning rate rises (defaults to 10k updates)')
parser.add_argument('--tunables', type=str, default="", help='tunable hyperparameters')
parser.add_argument('--resume-from', type=Path, default=None, help='resume training from the given checkpoint')
parser.add_argument('--strategy', type=str, default='ddp', help='distributed training strategy')
parser.add_argument('--wandb-suffix', type=str, default=None, help='W&B project name suffix')
parser.add_argument('--wandb-task-name', type=str, default=None, help='Task name for the W&B project name')

args = parser.parse_args().__dict__

task_args: list = shlex.split(args.pop("task"))
if not task_args:
    parser.error("--task must name a task module")
task_name, task_args = task_args[0], task_args[1:]
input_args: list = shlex.split(args.pop("input_dir"))
checkpoint_dir: str = args.pop("checkpoint_dir")
num_workers: int = args.pop("workers")
batch_size: int = args.pop("batch_size")
epochs: int = args.pop("epochs")
tunables_args: list = shlex.split(args.pop("tunables"))

# Hyperparameters consumed by TrainingTask / pl.Trainer and logged to W&B.
hyp_params = {
    'batch_size': batch_size,
    'warmup_steps': args['warmup_steps'],
    'weight_decay': args['weight_decay'],
    'clip_gradient_norm': args['clip_gradient_norm'],
    'accumulate_grad_batches': args['accumulate_grad_batches'],
    'precision': args['precision'],
    'lr0': args['lr0'],
    'epochs': epochs,
    'strategy': args['strategy'],
}
# %% ../nbs/B2. Training (Lightning).ipynb 9
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import LearningRateMonitor
import datetime
import webdataset as wds
import importlib

torch.set_float32_matmul_precision('medium')
# Honour --seed, which was previously parsed but never applied.
pl.seed_everything(args['seed'])

project = f"WhisperSpeech-{args['wandb_task_name'] or task_name}"
if args['wandb_suffix']:
    project += "-" + args['wandb_suffix']

wandb_logger = WandbLogger(project=project)
# Only write to the run config when it is a real wandb Config object
# (NOTE(review): it may be a placeholder, e.g. on non-zero DDP ranks — confirm).
log_config = isinstance(wandb_logger.experiment.config, wandb.sdk.wandb_config.Config)

ckpt_callback = pl.callbacks.ModelCheckpoint(
    # Save under --checkpoint-dir, which was previously parsed but ignored.
    dirpath=Path(checkpoint_dir) / f'{task_name}-{epochs}e',
    filename=task_name + "-{epoch}-{step}-{val_loss:.2f}",
    monitor="val_loss",
    save_top_k=4,
    train_time_interval=datetime.timedelta(minutes=5),
)

lr_monitor_callback = LearningRateMonitor(logging_interval='step')

from torch.utils.data import DataLoader

task_module = importlib.import_module("whisperspeech." + task_name)

train_ds, val_ds = parse_and_call('dataset', task_module.load_datasets, input_args)

tunables = None
if hasattr(task_module, "Tunables"):
    import dataclasses
    tunables = parse_and_call('tunables', task_module.Tunables, tunables_args, log_to_wandb=False)
    if log_config:
        wandb_logger.experiment.config['tunables'] = dataclasses.asdict(tunables)

    # Tunables override the matching command-line hyperparameters when set.
    for name in ["lr0", "clip_gradient_norm", "weight_decay", "warmup_steps"]:
        val = getattr(tunables, name, None)
        if val is not None:
            hyp_params[name] = val

if isinstance(train_ds, torch.utils.data.IterableDataset):
    dl_batch_size, dl_shuffle = None, False
    pin_memory = False
else:
    dl_batch_size, dl_shuffle = batch_size, True
    pin_memory = True

val_loader = wds.WebLoader(val_ds,
    batch_size=dl_batch_size,
    num_workers=num_workers,
    drop_last=False,
    pin_memory=pin_memory).unbatched().shuffle(1024).batched(batch_size).with_length(val_ds.total_samples // batch_size)

train_loader = wds.WebLoader(train_ds,
    batch_size=dl_batch_size,
    num_workers=num_workers,
    drop_last=False,
    shuffle=dl_shuffle,
    pin_memory=pin_memory).unbatched().shuffle(1024).batched(batch_size).with_length(train_ds.total_samples // batch_size)

model_kwargs = dict(dataset=train_ds)
if tunables is not None:
    model_kwargs['tunables'] = tunables
model = parse_and_call('model', task_module.make_model, task_args, model_kwargs)

task = TrainingTask(model, model_hparams=hyp_params)

trainer = pl.Trainer(strategy=hyp_params['strategy'],
                  max_epochs=hyp_params['epochs'],
                  accelerator="gpu",
                  profiler="simple",
                  precision=hyp_params['precision'],
                  gradient_clip_val=hyp_params['clip_gradient_norm'],
                  accumulate_grad_batches=hyp_params['accumulate_grad_batches'],
                  val_check_interval=args.pop("validate_every_n_steps"),
                  enable_checkpointing=True,
                  logger=wandb_logger,
                  callbacks=[ckpt_callback, lr_monitor_callback])

if log_config:
    wandb_logger.experiment.config.update(hyp_params)

# argparse always sets 'resume_from' (default None), so test the value, not key presence.
fit_kwargs = {}
if args['resume_from'] is not None:
    fit_kwargs['ckpt_path'] = args['resume_from']

trainer.fit(model=task, train_dataloaders=train_loader, val_dataloaders=val_loader, **fit_kwargs)