# Source: ai-photo-gallery/mmdet/engine/hooks/num_class_check_hook.py
# (uploaded by KyanChen, commit f549064 "init", 2.81 kB)
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import VGG
from mmengine.hooks import Hook
from mmengine.runner import Runner
from mmdet.registry import HOOKS
@HOOKS.register_module()
class NumClassCheckHook(Hook):
    """Check whether the `num_classes` in head matches the length of `classes`
    in `dataset.metainfo`.

    The check runs against the training dataset before every training epoch
    and against the validation dataset before every validation epoch.
    """

    def _check_head(self, runner: Runner, mode: str) -> None:
        """Check whether the `num_classes` in head matches the length of
        `classes` in `dataset.metainfo`.

        If the dataset's ``metainfo`` has no ``classes`` entry, only a
        warning is logged and no consistency check is performed.

        Args:
            runner (:obj:`Runner`): The runner of the training or evaluation
                process.
            mode (str): Which dataloader's dataset to check; must be either
                ``'train'`` or ``'val'``.

        Raises:
            AssertionError: If ``mode`` is not ``'train'``/``'val'``, if
                ``classes`` is a bare string instead of a tuple of str, or
                if a checked module's ``num_classes`` differs from
                ``len(classes)``.
        """
        assert mode in ['train', 'val']
        model = runner.model
        dataset = runner.train_dataloader.dataset if mode == 'train' else \
            runner.val_dataloader.dataset
        dataset_name = dataset.__class__.__name__

        classes = dataset.metainfo.get('classes', None)
        if classes is None:
            runner.logger.warning(
                f'Please set `classes` '
                f'in the {dataset_name} `metainfo` and '
                f'check if it is consistent with the `num_classes` '
                f'of head')
            return

        # A single class written as ``classes = ('person')`` is a plain str,
        # so ``len`` would count characters rather than classes. Raise
        # explicitly (not ``assert``) so the check survives ``python -O``.
        if isinstance(classes, str):
            raise AssertionError(
                f'`classes` in {dataset_name} '
                f'should be a tuple of str. '
                f'Add comma if number of classes is 1 as '
                f'classes = ({classes},)')

        # Imported inside the method rather than at module level —
        # presumably to avoid a circular import between mmdet.engine and
        # mmdet.models (TODO confirm).
        from mmdet.models.roi_heads.mask_heads import FusedSemanticHead

        num_classes = len(classes)
        for name, module in model.named_modules():
            # Only modules exposing ``num_classes`` are checked. RPN heads,
            # VGG and FusedSemanticHead are deliberately excluded (their
            # ``num_classes`` presumably does not follow the dataset's class
            # count).
            if not hasattr(module, 'num_classes') or \
                    name.endswith('rpn_head') or \
                    isinstance(module, (VGG, FusedSemanticHead)):
                continue
            if module.num_classes != num_classes:
                raise AssertionError(
                    f'The `num_classes` ({module.num_classes}) in '
                    f'{module.__class__.__name__} of '
                    f'{model.__class__.__name__} does not match '
                    f'the length of `classes` '
                    f'({num_classes}) in '
                    f'{dataset_name}')

    def before_train_epoch(self, runner: Runner) -> None:
        """Check whether the training dataset is compatible with head.

        Args:
            runner (:obj:`Runner`): The runner of the training or evaluation
                process.
        """
        self._check_head(runner, 'train')

    def before_val_epoch(self, runner: Runner) -> None:
        """Check whether the dataset in val epoch is compatible with head.

        Args:
            runner (:obj:`Runner`): The runner of the training or evaluation
                process.
        """
        self._check_head(runner, 'val')