# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from copy import deepcopy

from mmcv.transforms import Compose
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper

from mmpretrain.models.utils import RandomBatchAugment
from mmpretrain.registry import HOOKS, MODEL_WRAPPERS, MODELS

@HOOKS.register_module()
class SwitchRecipeHook(Hook):
    """Switch the training recipe during the training loop, currently
    including the train pipeline, batch augmentations and loss.

    Args:
        schedule (list): Every item of the schedule list should be a dict,
            and the dict should have ``action_epoch`` and some of
            ``train_pipeline``, ``batch_augments`` and ``loss`` keys:

            - ``action_epoch`` (int): switch the training recipe at which
              epoch.
            - ``train_pipeline`` (list, optional): The new data pipeline of
              the train dataset. If not specified, keep the original
              settings.
            - ``batch_augments`` (dict | None, optional): The new batch
              augmentations during training. See :mod:`Batch Augmentations
              <mmpretrain.models.utils.batch_augments>` for more details.
              If None, disable batch augmentations. If not specified, keep
              the original settings.
            - ``loss`` (dict, optional): The new loss module config. If not
              specified, keep the original settings.

    Example:
        To use this hook in config files:

        .. code:: python

            custom_hooks = [
                dict(
                    type='SwitchRecipeHook',
                    schedule=[
                        dict(
                            action_epoch=30,
                            train_pipeline=pipeline_after_30e,
                            batch_augments=batch_augments_after_30e,
                            loss=loss_after_30e,
                        ),
                        dict(
                            action_epoch=60,
                            # Disable batch augmentations after 60e
                            # and keep other settings.
                            batch_augments=None,
                        ),
                    ]
                )
            ]
    """
    priority = 'NORMAL'

    def __init__(self, schedule):
        recipes = {}
        for recipe in schedule:
            assert 'action_epoch' in recipe, \
                'Please set `action_epoch` in every item ' \
                'of the `schedule` in the SwitchRecipeHook.'
            recipe = deepcopy(recipe)
            if 'train_pipeline' in recipe:
                recipe['train_pipeline'] = Compose(recipe['train_pipeline'])
            if 'batch_augments' in recipe:
                batch_augments = recipe['batch_augments']
                if isinstance(batch_augments, dict):
                    batch_augments = RandomBatchAugment(**batch_augments)
                recipe['batch_augments'] = batch_augments
            if 'loss' in recipe:
                loss = recipe['loss']
                if isinstance(loss, dict):
                    loss = MODELS.build(loss)
                recipe['loss'] = loss

            action_epoch = recipe.pop('action_epoch')
            assert action_epoch not in recipes, \
                f'The `action_epoch` {action_epoch} is repeated ' \
                'in the SwitchRecipeHook.'
            recipes[action_epoch] = recipe
        self.schedule = OrderedDict(sorted(recipes.items()))
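        # A hypothetical illustration: schedule=[
        #     dict(action_epoch=60, batch_augments=None),
        #     dict(action_epoch=30, loss=dict(type='CrossEntropyLoss')),
        # ] builds ``self.schedule`` as an OrderedDict
        # {30: {'loss': CrossEntropyLoss(...)}, 60: {'batch_augments': None}},
        # sorted by epoch so ``before_train`` can replay past switches in
        # order when resuming.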

    def before_train(self, runner) -> None:
        """Before-run setting. If resuming from a checkpoint, apply all the
        switches scheduled before the current epoch.

        Args:
            runner (Runner): The runner of the training, validation or
                testing process.
        """
        if runner._resume:
            for action_epoch, recipe in self.schedule.items():
                if action_epoch >= runner.epoch + 1:
                    break
                self._do_switch(runner, recipe,
                                f' (resume recipe of epoch {action_epoch})')

    def before_train_epoch(self, runner):
        """Switch the recipe if the upcoming epoch matches an
        ``action_epoch``."""
        recipe = self.schedule.get(runner.epoch + 1, None)
        if recipe is not None:
            self._do_switch(runner, recipe, f' at epoch {runner.epoch + 1}')

    def _do_switch(self, runner, recipe, extra_info=''):
        """Apply all the switches in the given recipe."""
        if 'batch_augments' in recipe:
            self._switch_batch_augments(runner, recipe['batch_augments'])
            runner.logger.info(f'Switch batch augments{extra_info}.')

        if 'train_pipeline' in recipe:
            self._switch_train_pipeline(runner, recipe['train_pipeline'])
            runner.logger.info(f'Switch train pipeline{extra_info}.')

        if 'loss' in recipe:
            self._switch_loss(runner, recipe['loss'])
            runner.logger.info(f'Switch loss{extra_info}.')

    @staticmethod
    def _switch_batch_augments(runner, batch_augments):
        """Switch the batch augmentations of the data preprocessor."""
        model = runner.model
        if is_model_wrapper(model):
            model = model.module
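        # Assigning ``None`` here clears batch-level augmentations (e.g.
        # Mixup/CutMix) entirely, as in the class docstring example.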
        model.data_preprocessor.batch_augments = batch_augments

    @staticmethod
    def _switch_train_pipeline(runner, train_pipeline):
        """Switch the pipeline of the dataset behind the train dataloader."""

        def switch_pipeline(dataset, pipeline):
            if hasattr(dataset, 'pipeline'):
                # for usual dataset
                dataset.pipeline = pipeline
            elif hasattr(dataset, 'datasets'):
                # for concat dataset wrapper
                for ds in dataset.datasets:
                    switch_pipeline(ds, pipeline)
            elif hasattr(dataset, 'dataset'):
                # for other dataset wrappers
                switch_pipeline(dataset.dataset, pipeline)
            else:
                raise RuntimeError(
                    'Cannot access the `pipeline` of the dataset.')
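        # E.g., a hypothetical ``RepeatDataset(ConcatDataset([ds_a, ds_b]))``
        # is unwrapped through ``.dataset``, then fanned out over
        # ``.datasets``, so both ``ds_a.pipeline`` and ``ds_b.pipeline``
        # receive the new pipeline.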

        train_loader = runner.train_loop.dataloader
        switch_pipeline(train_loader.dataset, train_pipeline)

        # To restart the iterator of dataloader when `persistent_workers=True`
        train_loader._iterator = None

    @staticmethod
    def _switch_loss(runner, loss_module):
        """Switch the loss module of the model (or of its head)."""
        model = runner.model
        if is_model_wrapper(model, MODEL_WRAPPERS):
            model = model.module

        if hasattr(model, 'loss_module'):
            model.loss_module = loss_module
        elif hasattr(model, 'head') and hasattr(model.head, 'loss_module'):
            model.head.loss_module = loss_module
        else:
            raise RuntimeError(
                'Cannot access the `loss_module` of the model.')
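
# A minimal sketch of the switch mechanics in isolation (the loss config is a
# hypothetical placeholder; any loss registered in ``MODELS`` works):
#
#     hook = SwitchRecipeHook(schedule=[
#         dict(action_epoch=2, loss=dict(type='CrossEntropyLoss')),
#     ])
#     # At the start of epoch 2, ``before_train_epoch`` finds the recipe and
#     # swaps ``model.head.loss_module`` (or ``model.loss_module``) in place;
#     # no other training state is touched.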