|
|
|
from typing import Sequence |
|
|
|
from mmengine.hooks import Hook |
|
from mmengine.model import is_model_wrapper |
|
|
|
from mmdet.registry import HOOKS |
|
|
|
|
|
@HOOKS.register_module()
class YOLOXModeSwitchHook(Hook):
    """Switch the mode of YOLOX during training.

    This hook turns off the mosaic and mixup data augmentation and switches
    to use L1 loss in bbox_head.

    The switch is applied once, in ``before_train_epoch`` of the first epoch
    for which ``epoch + 1 >= max_epochs - num_last_epochs``. Using ``>=``
    rather than an exact equality test means the switch is still applied
    when training is resumed from a checkpoint saved after the switch point.

    Args:
        num_last_epochs (int): The number of latter epochs in the end of the
            training to close the data augmentation and switch to L1 loss.
            Defaults to 15.
        skip_type_keys (Sequence[str], optional): Type names of the pipeline
            transforms to skip once the switch happens.
            Defaults to ('Mosaic', 'RandomAffine', 'MixUp').
    """

    def __init__(
        self,
        num_last_epochs: int = 15,
        skip_type_keys: Sequence[str] = ('Mosaic', 'RandomAffine', 'MixUp')
    ) -> None:
        self.num_last_epochs = num_last_epochs
        self.skip_type_keys = skip_type_keys
        # True while the dataloader's private init flag has been cleared to
        # force an iterator rebuild; it is restored on the next epoch.
        self._restart_dataloader = False
        # Ensures the mode switch is performed only once, including after a
        # resume that starts past the switch epoch.
        self._has_switched = False

    def before_train_epoch(self, runner) -> None:
        """Close mosaic and mixup augmentation and switches to use L1 loss.

        Args:
            runner: The training runner. Its ``epoch``, ``max_epochs``,
                ``train_dataloader``, ``model`` and ``logger`` attributes are
                read.
        """
        epoch = runner.epoch
        train_loader = runner.train_dataloader
        model = runner.model
        # Unwrap model wrappers so ``bbox_head`` is reachable.
        if is_model_wrapper(model):
            model = model.module

        epoch_to_be_switched = (
            (epoch + 1) >= runner.max_epochs - self.num_last_epochs)
        if epoch_to_be_switched and not self._has_switched:
            runner.logger.info('No mosaic and mixup aug now!')
            train_loader.dataset.update_skip_type_keys(self.skip_type_keys)
            # With persistent workers the existing iterator (and its worker
            # processes) would be reused, so the updated pipeline might not
            # take effect. Dropping ``_iterator`` forces a rebuild. This
            # relies on private DataLoader attributes and is a hack.
            if getattr(train_loader, 'persistent_workers', False) is True:
                train_loader._DataLoader__initialized = False
                train_loader._iterator = None
                self._restart_dataloader = True
            runner.logger.info('Add additional L1 loss now!')
            model.bbox_head.use_l1 = True
            self._has_switched = True
        elif self._restart_dataloader:
            # The restart was triggered in a previous epoch; restore the
            # initialization flag that was cleared above.
            train_loader._DataLoader__initialized = True
            self._restart_dataloader = False
|
|