|
import chainer |
|
import logging |
|
|
|
|
|
def check_early_stop(trainer, epochs):
    """Checks an early stopping trigger and warns the user if it's the case

    Reads the epoch counter of the trainer's ``"main"`` iterator once training
    has returned, and logs a warning when it is below ``epochs``. That is, when
    training ended before the configured maximum number of epochs.

    :param trainer: The trainer used for training
    :param epochs: The maximum number of epochs
    """
    end_epoch = trainer.updater.get_iterator("main").epoch
    # Chainer iterators increment ``epoch`` after finishing a pass over the
    # data, so a run that completed every epoch ends with end_epoch == epochs.
    # Any smaller value means the stop trigger fired early. The previous bound
    # (``epochs - 1``) silently missed an early stop on the last-but-one epoch.
    if end_epoch < epochs:
        logging.warning(
            "Hit early stop at epoch %s"
            "\nYou can change the patience or set it to 0 to run all epochs",
            end_epoch,
        )
|
|
|
|
|
def set_early_stop(trainer, args, is_lm=False):
    """Install an early-stopping trigger on ``trainer`` from the program arguments.

    Nothing is changed unless ``args.patience`` is greater than zero.

    :param trainer: The trainer used for training; its ``stop_trigger`` is
        replaced when early stopping is enabled
    :param args: The program arguments (reads ``patience``,
        ``early_stop_criterion`` and ``epoch`` or ``epochs``)
    :param is_lm: If the trainer is for a LM (epoch instead of epochs)
    """
    patience = args.patience
    monitored = args.early_stop_criterion
    if is_lm:
        max_epochs = args.epoch
    else:
        max_epochs = args.epochs

    # Criteria whose name contains "acc" are maximised; every other criterion
    # (e.g. a loss) is minimised.
    direction = "min"
    if "acc" in monitored:
        direction = "max"

    # Written as ``not (x > 0)`` rather than ``x <= 0`` so that values which
    # compare False both ways (e.g. NaN) still leave the trainer untouched.
    if not (patience > 0):
        return

    # NOTE(review): ``patients`` is the keyword spelling this call has always
    # used; newer Chainer releases name it ``patience`` and keep ``patients``
    # as an alias. Confirm against the installed Chainer version before renaming.
    trainer.stop_trigger = chainer.training.triggers.EarlyStoppingTrigger(
        monitor=monitored,
        mode=direction,
        patients=patience,
        max_trigger=(max_epochs, "epoch"),
    )
|
|