|
from chainer.training.extension import Extension |
|
|
|
|
|
class TensorboardLogger(Extension):
    """A tensorboard logger extension.

    On every call, writes the (optionally filtered) entries of
    ``trainer.observation`` to the logger with ``add_scalar``. Each time the
    ``"main"`` iterator reaches a new epoch, it also asks the optional
    attention and CTC reporters to log their outputs.
    """

    default_name = "espnet_tensorboard_logger"

    def __init__(
        self, logger, att_reporter=None, ctc_reporter=None, entries=None, epoch=0
    ):
        """Init the extension

        :param SummaryWriter logger: The logger to use
        :param PlotAttentionReporter att_reporter: The (optional) PlotAttentionReporter
        :param ctc_reporter: The (optional) CTC reporter; it is called as
            ``ctc_reporter.log_ctc_probs(logger, iteration)``
        :param entries: The entries to watch (``None`` logs every observation key)
        :param int epoch: The starting epoch
        """
        self._entries = entries
        self._att_reporter = att_reporter
        self._ctc_reporter = ctc_reporter
        self._logger = logger
        # Last epoch for which the reporters were run; shared by both reporters.
        self._epoch = epoch

    @staticmethod
    def _to_host(x):
        """Return ``x.get()`` if ``x`` is a cupy object, otherwise ``x`` unchanged."""
        # Checking the type name avoids importing cupy, which may be absent.
        if "cupy" in str(type(x)):
            return x.get()
        return x

    def __call__(self, trainer):
        """Updates the events file with the new values

        :param trainer: The trainer
        """
        iteration = trainer.updater.iteration
        for k, v in trainer.observation.items():
            if self._entries is not None and k not in self._entries:
                continue
            if k is None or v is None:
                continue
            self._logger.add_scalar(self._to_host(k), self._to_host(v), iteration)

        # Only touch the "main" iterator when a reporter actually needs it.
        if self._att_reporter is None and self._ctc_reporter is None:
            return
        epoch = trainer.updater.get_iterator("main").epoch
        if epoch <= self._epoch:
            return
        # Update the shared marker once, then run *both* reporters. Previously
        # the attention branch advanced the marker first, so the CTC reporter
        # never fired when both reporters were configured.
        self._epoch = epoch
        if self._att_reporter is not None:
            self._att_reporter.log_attentions(self._logger, iteration)
        if self._ctc_reporter is not None:
            self._ctc_reporter.log_ctc_probs(self._logger, iteration)
|
|