# conex/espnet/utils/training/train_utils.py
# Repository page header: tobiasc, "Initial commit" (ad16788)
import chainer
import logging
def check_early_stop(trainer, epochs):
    """Warn the user when training ended before the epoch limit

    The epoch counter of the updater's "main" iterator is compared with
    ``epochs - 1``; a warning is logged through the root logger only when
    the counter is strictly smaller than that value.

    :param trainer: The trainer used for training
    :param epochs: The maximum number of epochs
    """
    main_iterator = trainer.updater.get_iterator("main")
    reached_epoch = main_iterator.epoch
    threshold = epochs - 1

    # Nothing to report unless the run stopped short of the threshold.
    if not (reached_epoch < threshold):
        return

    notice_parts = (
        "Hit early stop at epoch ",
        str(reached_epoch),
        "\nYou can change the patience or set it to 0 to run all epochs",
    )
    logging.warning("".join(notice_parts))
def set_early_stop(trainer, args, is_lm=False):
    """Install an early stopping trigger on the trainer from the arguments

    The monitored criterion is maximised when its name contains ``"acc"``
    and minimised otherwise. The trainer's ``stop_trigger`` is replaced only
    when ``args.patience`` is greater than 0; otherwise the trainer is left
    untouched.

    :param trainer: The trainer used for training
    :param args: The program arguments (reads ``patience``,
        ``early_stop_criterion`` and ``epoch``/``epochs``)
    :param is_lm: If the trainer is for a LM (epoch instead of epochs)
    """
    patience = args.patience
    monitored = args.early_stop_criterion

    # LM training exposes the epoch budget as ``epoch`` rather than ``epochs``.
    if is_lm:
        max_epochs = args.epoch
    else:
        max_epochs = args.epochs

    if "acc" in monitored:
        direction = "max"
    else:
        direction = "min"

    # A non-positive patience leaves the existing stop trigger in place.
    if not (patience > 0):
        return

    trigger_kwargs = {
        "monitor": monitored,
        "mode": direction,
        # ``patients`` is the keyword spelling this file passes to chainer.
        "patients": patience,
        "max_trigger": (max_epochs, "epoch"),
    }
    trainer.stop_trigger = chainer.training.triggers.EarlyStoppingTrigger(
        **trigger_kwargs
    )