Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import VGG | |
from mmengine.hooks import Hook | |
from mmengine.runner import Runner | |
from mmdet.registry import HOOKS | |
class NumClassCheckHook(Hook):
    """Check whether the `num_classes` in head matches the length of `classes`
    in `dataset.metainfo`.

    The check runs before every training epoch and before every validation
    epoch, against the dataset of the corresponding dataloader.
    """

    def _check_head(self, runner: Runner, mode: str) -> None:
        """Check whether the `num_classes` in head matches the length of
        `classes` in `dataset.metainfo`.

        Every sub-module of ``runner.model`` that exposes a ``num_classes``
        attribute is compared against ``len(classes)``. Modules whose
        name ends with ``rpn_head`` and instances of ``VGG`` or
        ``FusedSemanticHead`` are skipped.

        Args:
            runner (:obj:`Runner`): The runner of the training or evaluation
                process.
            mode (str): Which dataloader's dataset to check, either
                ``'train'`` or ``'val'``.

        Raises:
            AssertionError: If ``mode`` is not ``'train'`` or ``'val'``, if
                ``classes`` is a plain string, or if a checked module's
                ``num_classes`` differs from ``len(classes)``.
        """
        assert mode in ['train', 'val']
        model = runner.model
        dataset = runner.train_dataloader.dataset if mode == 'train' else \
            runner.val_dataloader.dataset
        dataset_name = dataset.__class__.__name__

        classes = dataset.metainfo.get('classes', None)
        if classes is None:
            # Nothing to compare against; only remind the user.
            runner.logger.warning(
                f'Please set `classes` '
                f'in the {dataset_name} `metainfo` and '
                f'check if it is consistent with the `num_classes` '
                f'of head')
            return

        # A bare string would make ``len(classes)`` count characters
        # instead of class names, so reject it explicitly.
        assert not isinstance(classes, str), \
            (f'`classes` in {dataset_name} '
             f'should be a tuple of str. '
             f'Add comma if number of classes is 1 as '
             f'classes = ({classes},)')

        # Imported here rather than at module level, as in the original.
        from mmdet.models.roi_heads.mask_heads import FusedSemanticHead

        num_dataset_classes = len(classes)
        for name, module in model.named_modules():
            if not hasattr(module, 'num_classes'):
                continue
            if name.endswith('rpn_head'):
                continue
            if isinstance(module, (VGG, FusedSemanticHead)):
                continue
            assert module.num_classes == num_dataset_classes, \
                (f'The `num_classes` ({module.num_classes}) in '
                 f'{module.__class__.__name__} of '
                 f'{model.__class__.__name__} does not match '
                 f'the length of `classes` '
                 f'({num_dataset_classes}) in '
                 f'{dataset_name}')

    def before_train_epoch(self, runner: Runner) -> None:
        """Check whether the training dataset is compatible with head.

        Args:
            runner (:obj:`Runner`): The runner of the training or evaluation
                process.
        """
        self._check_head(runner, 'train')

    def before_val_epoch(self, runner: Runner) -> None:
        """Check whether the dataset in val epoch is compatible with head.

        Args:
            runner (:obj:`Runner`): The runner of the training or evaluation
                process.
        """
        self._check_head(runner, 'val')