|
|
|
from mmcv.cnn import VGG |
|
from mmengine.hooks import Hook |
|
from mmengine.runner import Runner |
|
|
|
from mmdet.registry import HOOKS |
|
|
|
|
|
@HOOKS.register_module()
class NumClassCheckHook(Hook):
    """Check whether the `num_classes` in head matches the length of `classes`
    in `dataset.metainfo`.

    The check runs before every training epoch (against the train dataset)
    and before every validation epoch (against the val dataset).
    """

    def _check_head(self, runner: Runner, mode: str) -> None:
        """Check whether the `num_classes` in head matches the length of
        `classes` in `dataset.metainfo`.

        If the dataset's ``metainfo`` has no ``classes`` entry, only a warning
        is logged and no module is checked.

        Args:
            runner (:obj:`Runner`): The runner of the training or evaluation
                process.
            mode (str): Which dataloader's dataset to check against; must be
                ``'train'`` or ``'val'``.

        Raises:
            AssertionError: If ``mode`` is invalid, if ``classes`` is a bare
                string, or if any checked module's ``num_classes`` differs
                from ``len(classes)``.
        """
        assert mode in ['train', 'val']
        model = runner.model
        dataset = runner.train_dataloader.dataset if mode == 'train' else \
            runner.val_dataloader.dataset

        if dataset.metainfo.get('classes', None) is None:
            runner.logger.warning(
                f'Please set `classes` '
                f'in the {dataset.__class__.__name__} `metainfo` and '
                f'check if it is consistent with the `num_classes` '
                f'of head')
            return

        classes = dataset.metainfo['classes']
        # A bare string (e.g. ``classes = ('person')`` without a trailing
        # comma) would make ``len`` count characters, not class names.
        assert not isinstance(classes, str), \
            (f'`classes` in {dataset.__class__.__name__} '
             f'should be a tuple of str. '
             f'Add comma if number of classes is 1 as '
             f'classes = ({classes},)')
        num_classes = len(classes)

        # Imported locally rather than at module level; presumably to avoid
        # a circular import with mmdet.models -- NOTE(review): confirm.
        from mmdet.models.roi_heads.mask_heads import FusedSemanticHead

        for name, module in model.named_modules():
            # Modules named '*rpn_head', VGG backbones and FusedSemanticHead
            # are excluded from the check even if they define `num_classes`.
            if (hasattr(module, 'num_classes')
                    and not name.endswith('rpn_head')
                    and not isinstance(module, (VGG, FusedSemanticHead))):
                assert module.num_classes == num_classes, \
                    (f'The `num_classes` ({module.num_classes}) in '
                     f'{module.__class__.__name__} of '
                     f'{model.__class__.__name__} does not match '
                     f'the length of `classes` '
                     f'({num_classes}) in '
                     f'{dataset.__class__.__name__}')

    def before_train_epoch(self, runner: Runner) -> None:
        """Check whether the training dataset is compatible with head.

        Args:
            runner (:obj:`Runner`): The runner of the training or evaluation
                process.
        """
        self._check_head(runner, 'train')

    def before_val_epoch(self, runner: Runner) -> None:
        """Check whether the dataset in val epoch is compatible with head.

        Args:
            runner (:obj:`Runner`): The runner of the training or evaluation
                process.
        """
        self._check_head(runner, 'val')
|
|