Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved | |
from mmengine.hooks import Hook | |
from mmengine.utils import is_seq_of | |
from mmcls.registry import HOOKS | |
class ClassNumCheckHook(Hook): | |
"""Class Number Check HOOK.""" | |
def _check_head(self, runner, dataset): | |
"""Check whether the `num_classes` in head matches the length of | |
`CLASSES` in `dataset`. | |
Args: | |
runner (obj:`Runner`): runner object. | |
dataset (obj: `BaseDataset`): the dataset to check. | |
""" | |
model = runner.model | |
if dataset.CLASSES is None: | |
runner.logger.warning( | |
f'Please set class information in `metainfo` ' | |
f'in the {dataset.__class__.__name__} and' | |
f'check if it is consistent with the `num_classes` ' | |
f'of head') | |
else: | |
assert is_seq_of(dataset.CLASSES, str), \ | |
(f'Class information in `metainfo` in ' | |
f'{dataset.__class__.__name__} should be a tuple of str.') | |
for _, module in model.named_modules(): | |
if hasattr(module, 'num_classes'): | |
assert module.num_classes == len(dataset.CLASSES), \ | |
(f'The `num_classes` ({module.num_classes}) in ' | |
f'{module.__class__.__name__} of ' | |
f'{model.__class__.__name__} does not matches ' | |
f'the length of class information in `metainfo` ' | |
f'{len(dataset.CLASSES)}) in ' | |
f'{dataset.__class__.__name__}') | |
def before_train(self, runner): | |
"""Check whether the training dataset is compatible with head. | |
Args: | |
runner (obj: `IterBasedRunner`): Iter based Runner. | |
""" | |
self._check_head(runner, runner.train_dataloader.dataset) | |
def before_val(self, runner): | |
"""Check whether the validation dataset is compatible with head. | |
Args: | |
runner (obj:`IterBasedRunner`): Iter based Runner. | |
""" | |
self._check_head(runner, runner.val_dataloader.dataset) | |
def before_test(self, runner): | |
"""Check whether the test dataset is compatible with head. | |
Args: | |
runner (obj:`IterBasedRunner`): Iter based Runner. | |
""" | |
self._check_head(runner, runner.test_dataloader.dataset) | |