# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/B1. Training.ipynb.

# %% auto 0
__all__ = ['SimpleVisual', 'validate', 'train']

# %% ../nbs/B1. Training.ipynb 2
import io
import time
import random
from pathlib import Path

from fastprogress import progress_bar, master_bar
import fastprogress
import numpy as np
import pylab as plt
import math

import IPython
from IPython.display import display

import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from torch.profiler import record_function

import webdataset as wds

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision('medium')

# %% ../nbs/B1. Training.ipynb 3
class SimpleVisual:
    def __init__(self, model, masterbar, total_steps):
        self.model = model
        self.masterbar = masterbar
        self.total_steps = total_steps
        self.epochs = total_steps // masterbar.main_bar.total

        gs = plt.GridSpec(2, 1, height_ratios=[3,1])
        graph_fig = plt.figure(figsize=(10,6))
        self.graph_fig = graph_fig
        self.loss_p = graph_fig.add_subplot(gs[0])
        self.lr_p = graph_fig.add_subplot(gs[1], sharex=self.loss_p)
        self.lr_p.tick_params('x', labelbottom=False)
        self.graph_out = None

        self.its = []
        self.train_losses = []
        self.val_losses = []
        self.lr_history = []

    def show(self):
        self.start_t = time.time()
        self.masterbar.write(["samples", "train", "val", "time"], table=True)
        self.graph_out = display(self.graph_fig, display_id=True, clear=True)

    def hide(self):
        if self.graph_out is not None:
            self.graph_out.update(IPython.display.HTML(''))

    def plot(self):
        loss_p, lr_p = self.loss_p, self.lr_p
        loss_p.clear()
        loss_p.plot(self.its, self.train_losses)
        loss_p.plot(self.its, self.val_losses)
        loss_p.set_xlim(0, self.total_steps)
        loss_p.set_yscale('log')
        lr_p.clear()
        lrs = np.array(self.lr_history)
        lr_p.plot(self.its, lrs)
        self.graph_out.update(self.graph_fig)

    def add_data(self, it, lr, train_loss, val_loss):
        self.its.append(it)
        self.train_losses.append(train_loss)
        self.val_losses.append(val_loss)
        self.lr_history.append(lr)
        self.plot()

    def add_table_row(self, it, avg_train_loss, val_loss):
        elapsed_t = time.time() - self.start_t
        self.masterbar.write([it, f"{avg_train_loss:.5f}", f"{val_loss:.5f}",
                              fastprogress.core.format_time(elapsed_t)], table=True)

    def on_iter(self, bar, it, avg_train_loss, val_loss):
        epoch = math.ceil(it / self.total_steps * self.epochs)
        bar.comment = f"#{epoch}/{self.epochs} loss: {avg_train_loss:.3f} / {val_loss:.3f}"

# %% ../nbs/B1. Training.ipynb 4
# FIXME: we need to keep this synchronised with the validation code below...
def validate(model, val, half=True, bs=16, drop_last=False, dl_workers=8, device="cuda"):
    if isinstance(val, torch.utils.data.IterableDataset):
        val_loader = wds.WebLoader(val, batch_size=None, num_workers=dl_workers, drop_last=drop_last) \
            .unbatched().shuffle(1024).batched(bs)
    else:
        val_loader = DataLoader(val, batch_size=bs, num_workers=dl_workers, pin_memory=True, drop_last=drop_last)

    with torch.no_grad():
        val_loss = 0
        val_samples = 0
        for args in val_loader:
            args = [x.to(device, non_blocking=True) for x in args]
            with torch.autocast(device_type=device, dtype=torch.float16 if half else torch.float32, enabled=device!='cpu'):
                ps, loss = model(*args)
                # weight each batch by its size so a partial last batch doesn't skew the average
                N = args[0].shape[0]
                val_loss += loss.mean().item() * N
                val_samples += N
        val_loss = val_loss / val_samples
    return val_loss
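# Example usage (a minimal sketch, not part of the original notebook): `validate`
# only assumes that iterating the dataset yields tuples of tensors and that the
# model's forward returns a `(predictions, loss)` pair. `MyModel` and
# `val_dataset` below are hypothetical placeholders:
#
#     model = MyModel().to('cuda')
#     val_loss = validate(model, val_dataset, half=True, bs=16, device='cuda')
#     print(f"validation loss: {val_loss:.5f}")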
# %% ../nbs/B1. Training.ipynb 5
def train(checkpoint_path, model, train, val, half=True, bs=16, lr=1e-4, drop_last=False,
          weight_decay=0.1, warmup_steps=10000, epochs=10, clip_gradient_norm=None,
          dl_workers=8, visual_class=SimpleVisual, profiler=None,
          run_valid_every_iters=8000, table_row_every_iters=80000, chkpt_every_iters=None,
          device="cuda", trainable_params=None):
    if chkpt_every_iters is None:
        chkpt_every_iters = table_row_every_iters

    mb = master_bar(range(epochs))
    if isinstance(train, torch.utils.data.IterableDataset):
        # iterable (webdataset) datasets expose their sample count via `total_samples`
        pct_start = min(0.3, warmup_steps / (epochs * (train.total_samples//bs)))
        visual = visual_class(model, mb, epochs * train.total_samples)
#         pct_start = min(0.3, warmup_steps / (epochs * len(train)))
#         visual = visual_class(model, mb, epochs*len(train)*bs)
    else:
        pct_start = min(0.3, warmup_steps / (epochs * len(train) / bs))
        visual = visual_class(model, mb, epochs*len(train))
    model.visual = visual

    Path(checkpoint_path).mkdir(exist_ok=True)

    if isinstance(train, torch.utils.data.IterableDataset):
#         train_loader = DataLoader(train, batch_size=None, num_workers=dl_workers, pin_memory=True, drop_last=False, shuffle=False)
#         val_loader = DataLoader(val, batch_size=None, num_workers=dl_workers, pin_memory=True, drop_last=False)
        train_loader = wds.WebLoader(train, batch_size=None, num_workers=dl_workers, drop_last=drop_last) \
            .unbatched().shuffle(1024).batched(bs, partial=False)
        val_loader = wds.WebLoader(val, batch_size=None, num_workers=dl_workers, drop_last=drop_last) \
            .unbatched().shuffle(1024).batched(bs)
    else:
        train_loader = DataLoader(train, batch_size=bs, num_workers=dl_workers, pin_memory=True, drop_last=drop_last, shuffle=True)
        val_loader = DataLoader(val, batch_size=bs, num_workers=dl_workers, pin_memory=True, drop_last=drop_last)

    val_loss = torch.nan
    avg_train_loss = torch.nan

    if hasattr(model, 'setup'):
        model.setup(device)

    try:
        scheduler = None

        # build optimizer parameter groups, honoring per-module `no_weight_decay`
        # and `lr_scale` attributes
        if trainable_params is None: trainable_params = model.parameters()
        all_params = set(trainable_params)
        customized_params = set()
        groups = []
        group_map = {}
        for name,m in model.named_modules():
            if hasattr(m, 'no_weight_decay') or hasattr(m, 'lr_scale'):
                m_trainable = [x for x in m.parameters() if x in all_params]
                if not m_trainable: continue
                customized_params |= set(m_trainable)
                m_wd = 0 if hasattr(m, 'no_weight_decay') else weight_decay
                m_lr = lr * getattr(m, 'lr_scale', 1)
                group = group_map.get((m_wd, m_lr), None)
                if not group:
                    group = {"params": [], "names": [], "weight_decay": m_wd, "lr": m_lr}
                    groups.append(group)
                    group_map[(m_wd, m_lr)] = group
                group['params'] += m_trainable
                group['names'].append(name)

        other_params = all_params - customized_params
        if other_params:
            groups = groups + [
                {"names": ["other"], "params": list(other_params), "weight_decay": weight_decay},
            ]

        optimizer = torch.optim.AdamW(lr=lr, betas=(0.9, 0.95), fused=device!='cpu', params=groups)
        model._optimizer = optimizer

        scaler = torch.cuda.amp.GradScaler(enabled=half)

        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, pct_start=pct_start, steps_per_epoch=math.ceil(train.total_samples/bs), epochs=epochs,
            max_lr=[pg.get('lr', lr) for pg in groups], final_div_factor=25)

        it = 0
        next_val_it = it + 50
        next_chkpt_it = chkpt_every_iters
        next_table_it = table_row_every_iters

        visual.show()

        running_loss = [0]
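        # Each iteration below follows the standard mixed-precision recipe:
        # forward under autocast, backward through the GradScaler, optionally
        # unscale and clip gradients, then step the optimizer, the scaler and
        # the OneCycleLR scheduler. Note that `it` counts *samples* (it grows
        # by `bs` per batch), so the `*_every_iters` thresholds are expressed
        # in samples as well.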
        for epoch in mb:
            bar = progress_bar(train_loader, total=train.total_samples//bs, parent=mb)
            for args in bar:
                with record_function("forward"):
                    args = [x.to(device, non_blocking=True) for x in args]

                    # zero the parameter gradients
                    optimizer.zero_grad(set_to_none=True)

                    with torch.autocast(device_type=device, dtype=torch.float16 if half else torch.float32, enabled=device!='cpu'):
                        ps, loss = model(*args)
                        loss = loss.mean()

                with record_function("backward"):
                    scaler.scale(loss).backward()

                    if clip_gradient_norm:
                        scaler.unscale_(optimizer)
                        # since the gradients of the optimizer's assigned params are now unscaled, clip as usual
                        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_gradient_norm)

                    scaler.step(optimizer)
                    scaler.update()
                    scheduler.step()

                    if profiler is not None: profiler.step()

                with record_function("running_loss"):
                    # smooth the reported training loss over the last 5 batches
                    running_loss.append(loss.item())
                    running_loss = running_loss[-5:]
                    avg_train_loss = sum(running_loss)/len(running_loss)

                if it >= next_chkpt_it:
                    with record_function("checkpoint"):
                        next_chkpt_it += chkpt_every_iters
                        torch.save(model.state_dict(), f'{checkpoint_path}/{it:08d}.pt')

                if it >= next_val_it:
                    next_val_it += run_valid_every_iters
                    with record_function("validation"):
                        with record_function("model.eval"):
                            model.eval()
                        with torch.no_grad():
                            val_loss = 0
                            val_samples = 0
                            for args in val_loader:
                                args = [x.to(device, non_blocking=True) for x in args]
                                with torch.autocast(device_type=device, dtype=torch.float16 if half else torch.float32, enabled=device!='cpu'):
                                    ps, loss = model(*args)
                                    N = args[0].shape[0]
                                    val_loss += loss.mean().item() * N
                                    val_samples += N
                            val_loss = val_loss / val_samples
                        with record_function("model.train"):
                            model.train()
                    with record_function("plotting"):
                        visual.add_data(it, scheduler.get_last_lr(), avg_train_loss, val_loss)

                if it >= next_table_it:
                    visual.add_table_row(it, avg_train_loss, val_loss)
                    next_table_it += table_row_every_iters

                # `it` counts samples, not optimizer steps
                it += bs
                visual.on_iter(bar, it, avg_train_loss, val_loss)
    except KeyboardInterrupt:
        mb.write("interrupted")
        mb.show()
    finally:
        visual.add_table_row(it, avg_train_loss, val_loss)
        mb.show()
        visual.hide()
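# Example usage (a minimal sketch, not part of the original notebook): training
# a model on a webdataset-style IterableDataset. `MyModel` and `make_dataset`
# are hypothetical placeholders; the requirements visible above are that the
# dataset exposes a `total_samples` attribute (used for the scheduler and the
# progress bar) and that the model's forward returns a `(predictions, loss)` pair:
#
#     model = MyModel()
#     train_ds, val_ds = make_dataset('train'), make_dataset('val')  # both with .total_samples
#     train('checkpoints', model, train_ds, val_ds,
#           half=True, bs=16, lr=1e-4, warmup_steps=10000, epochs=10,
#           clip_gradient_norm=1.0, device='cuda')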