|
|
|
from mmcv.transforms import Compose |
|
from mmengine.hooks import Hook |
|
|
|
from mmdet.registry import HOOKS |
|
|
|
|
|
@HOOKS.register_module()
class PipelineSwitchHook(Hook):
    """Switch the training data pipeline once ``switch_epoch`` is reached.

    The switch is applied at most once per hook instance, at the first
    training epoch whose index is greater than or equal to ``switch_epoch``.
    Using ``>=`` (rather than ``==``) means a run whose first epoch is
    already past ``switch_epoch`` -- for example one resumed from a later
    checkpoint -- still gets the new pipeline.

    Args:
        switch_epoch (int): Epoch index (as reported by ``runner.epoch``)
            from which the new pipeline is used.
        switch_pipeline (list[dict]): Transform configs passed to
            ``Compose`` to build the pipeline to switch to.
    """

    def __init__(self, switch_epoch, switch_pipeline):
        self.switch_epoch = switch_epoch
        self.switch_pipeline = switch_pipeline
        # True between the epoch where the dataloader init flag was cleared
        # and the following epoch where it is restored.
        self._restart_dataloader = False
        # Guards the switch so it happens only once per hook instance.
        self._has_switched = False

    def before_train_epoch(self, runner):
        """Swap the dataset pipeline, or finish a pending dataloader restart.

        Args:
            runner: Object providing ``epoch``, ``train_dataloader`` and
                ``logger``; the dataloader must expose ``dataset.pipeline``.
        """
        epoch = runner.epoch
        train_loader = runner.train_dataloader

        if epoch >= self.switch_epoch and not self._has_switched:
            runner.logger.info('Switch pipeline now!')
            train_loader.dataset.pipeline = Compose(self.switch_pipeline)

            # NOTE(review): with persistent workers the cached iterator is
            # dropped and the loader's private init flag is cleared --
            # presumably so the worker processes are re-created and see the
            # new pipeline. This relies on DataLoader internals; confirm
            # against the installed PyTorch version.
            if getattr(train_loader, 'persistent_workers', False) is True:
                train_loader._DataLoader__initialized = False
                train_loader._iterator = None
                self._restart_dataloader = True

            self._has_switched = True
        elif self._restart_dataloader:
            # A later epoch after the forced restart: put the private init
            # flag back, then clear our marker so this is done only once.
            train_loader._DataLoader__initialized = True
            self._restart_dataloader = False
|
|