Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Optional | |
import torch | |
from mmengine.hooks import Hook | |
from mmengine.runner import Runner | |
from mmdet.registry import HOOKS | |
class CheckInvalidLossHook(Hook):
    """Check invalid loss hook.

    This hook will regularly check whether the loss is valid
    during training.

    Args:
        interval (int): Checking interval (every k iterations).
            Default: 50.
    """

    def __init__(self, interval: int = 50) -> None:
        self.interval = interval

    def after_train_iter(self,
                         runner: Runner,
                         batch_idx: int,
                         data_batch: Optional[dict] = None,
                         outputs: Optional[dict] = None) -> None:
        """Regularly check whether the loss is valid every n iterations.

        Args:
            runner (:obj:`Runner`): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict, Optional): Data from dataloader.
                Defaults to None.
            outputs (dict, Optional): Outputs from model. Defaults to None.
                On a checking iteration ``outputs['loss']`` is read, so it
                must be present then.

        Raises:
            AssertionError: If ``outputs['loss']`` is not finite on a
                checking iteration.
        """
        # ``every_n_train_iters`` is inherited from ``Hook``; it decides
        # whether the current iteration is one on which to check.
        if not self.every_n_train_iters(runner, self.interval):
            return

        if not torch.isfinite(outputs['loss']):
            message = 'loss become infinite or NaN!'
            runner.logger.info(message)
            # Raise explicitly instead of using ``assert``: an ``assert`` is
            # removed under ``python -O``, and ``logger.info`` returns None,
            # so using its result as the assert message left the error blank.
            raise AssertionError(message)