from chainer.training.extension import Extension


class TensorboardLogger(Extension):
    """A tensorboard logger extension.

    On every call it writes the (optionally filtered) entries of
    ``trainer.observation`` to the logger as scalars.  Whenever the epoch of
    the ``"main"`` iterator advances past the last seen epoch, it also asks
    the optional attention and CTC reporters to log to the same logger.
    """

    default_name = "espnet_tensorboard_logger"

    def __init__(
        self, logger, att_reporter=None, ctc_reporter=None, entries=None, epoch=0
    ):
        """Init the extension

        :param SummaryWriter logger: The logger to use (must provide
            ``add_scalar(tag, value, step)``)
        :param PlotAttentionReporter att_reporter: The (optional) reporter whose
            ``log_attentions(logger, iteration)`` is called once per new epoch
        :param ctc_reporter: The (optional) reporter whose
            ``log_ctc_probs(logger, iteration)`` is called once per new epoch
        :param entries: The observation keys to log; ``None`` logs every key
        :param int epoch: The starting epoch; reporters are only invoked once
            the ``"main"`` iterator's epoch exceeds this value
        """
        self._entries = entries
        self._att_reporter = att_reporter
        self._ctc_reporter = ctc_reporter
        self._logger = logger
        # Last epoch for which the reporters were invoked.
        self._epoch = epoch

    def __call__(self, trainer):
        """Updates the events file with the new values

        :param trainer: The trainer
        """
        iteration = trainer.updater.iteration

        for k, v in trainer.observation.items():
            if (self._entries is not None) and (k not in self._entries):
                continue
            if k is None or v is None:
                continue
            # Presumably GPU (cupy) arrays; ``.get()`` copies them to the host
            # so the logger receives a host-side value.
            if "cupy" in str(type(v)):
                v = v.get()
            if "cupy" in str(type(k)):
                k = k.get()
            self._logger.add_scalar(k, v, iteration)

        if self._att_reporter is None and self._ctc_reporter is None:
            return

        epoch = trainer.updater.get_iterator("main").epoch
        if epoch <= self._epoch:
            return
        # Advance the shared epoch watermark only once, after deciding to log,
        # so that BOTH reporters run on a new epoch.  (Previously the attention
        # branch advanced it first, which suppressed the CTC reporter.)
        self._epoch = epoch
        if self._att_reporter is not None:
            self._att_reporter.log_attentions(self._logger, iteration)
        if self._ctc_reporter is not None:
            self._ctc_reporter.log_ctc_probs(self._logger, iteration)