import os

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback


class IntervalModelCheckpoint(Callback):
""" | |
Save a checkpoint every N steps, instead of Lightning's default that checkpoints | |
based on validation loss. | |
""" | |

    def __init__(
        self,
        dirpath,
        save_intervals,
    ):
""" | |
Args: | |
save_step_frequency: how often to save in steps | |
prefix: add a prefix to the name, only used if | |
use_modelcheckpoint_filename=False | |
use_modelcheckpoint_filename: just use the ModelCheckpoint callback's | |
default filename, don't use ours. | |
""" | |
        self.dirpath = dirpath
        self.save_intervals = save_intervals
        # Lowest validation loss seen so far; used to maintain best.ckpt.
        self.best_val_loss = float('inf')

    def on_batch_end(self, trainer: pl.Trainer, _):
        """Check whether we should save a checkpoint after every train batch."""
        global_step = trainer.global_step
        if (global_step + 1) in self.save_intervals:
            # run_evaluation() is Lightning's internal validation entry point
            # (available in the versions this was written against); it runs
            # the validation loop and populates trainer.callback_metrics.
            trainer.run_evaluation()
            # callback_metrics values are scalar tensors; .item() gives a float.
            val_loss = trainer.callback_metrics['val_loss'].item()
            filename = f"steps={global_step + 1:05d}-val_loss={val_loss:0.8f}.ckpt"
            ckpt_path = os.path.join(self.dirpath, filename)
            trainer.save_checkpoint(ckpt_path)
            # Also keep a rolling copy of the best checkpoint seen so far.
            if val_loss < self.best_val_loss:
                best_ckpt_path = os.path.join(self.dirpath, 'best.ckpt')
                trainer.save_checkpoint(best_ckpt_path)
                self.best_val_loss = val_loss
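
To wire the callback into training, pass it to the Trainer's callbacks list. Below is a minimal usage sketch: MyLightningModule, train_loader, and val_loader are hypothetical placeholders, the interval values are just examples, and the LightningModule is assumed to log a 'val_loss' metric during validation so that callback_metrics contains it.

model = MyLightningModule()
checkpoint_callback = IntervalModelCheckpoint(
    dirpath='checkpoints/',
    save_intervals={1000, 5000, 10000},  # global steps to evaluate and save at
)
trainer = pl.Trainer(
    max_steps=10_000,
    callbacks=[checkpoint_callback],
)
trainer.fit(model, train_loader, val_loader)

Using a set for save_intervals keeps the per-batch membership test O(1), which matters because on_batch_end fires after every training batch. Note that on newer Lightning releases, where run_evaluation() is no longer exposed, a similar effect can be had with the built-in ModelCheckpoint(every_n_train_steps=N) combined with Trainer(val_check_interval=N), at the cost of fixed-frequency rather than arbitrary save steps.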