# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.optim.scheduler.lr_scheduler import LRSchedulerMixin
from mmengine.optim.scheduler.momentum_scheduler import MomentumSchedulerMixin
from mmengine.optim.scheduler.param_scheduler import INF, _ParamScheduler
from torch.optim import Optimizer

from mmdet.registry import PARAM_SCHEDULERS


@PARAM_SCHEDULERS.register_module()
class QuadraticWarmupParamScheduler(_ParamScheduler):
r"""Warm up the parameter value of each parameter group by quadratic
formula:
.. math::
X_{t} = X_{t-1} + \frac{2t+1}{{(end-begin)}^{2}} \times X_{base}
Args:
optimizer (Optimizer): Wrapped optimizer.
param_name (str): Name of the parameter to be adjusted, such as
``lr``, ``momentum``.
begin (int): Step at which to start updating the parameters.
Defaults to 0.
end (int): Step at which to stop updating the parameters.
Defaults to INF.
last_step (int): The index of last step. Used for resume without
state dict. Defaults to -1.
by_epoch (bool): Whether the scheduled parameters are updated by
epochs. Defaults to True.
verbose (bool): Whether to print the value for each update.
Defaults to False.
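
    Examples:
        A minimal, illustrative sketch (not part of the original file),
        assuming a plain ``torch.optim.SGD`` optimizer:

        >>> import torch.nn as nn
        >>> from torch.optim import SGD
        >>> model = nn.Linear(2, 2)
        >>> optimizer = SGD(model.parameters(), lr=0.01)
        >>> # ramp ``lr`` quadratically from near 0 up to its base value
        >>> # (0.01) over the first 500 iterations
        >>> scheduler = QuadraticWarmupParamScheduler(
        ...     optimizer, param_name='lr', begin=0, end=500, by_epoch=False)
        >>> for _ in range(500):
        ...     optimizer.step()
        ...     scheduler.step()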
"""
def __init__(self,
optimizer: Optimizer,
param_name: str,
begin: int = 0,
end: int = INF,
last_step: int = -1,
by_epoch: bool = True,
verbose: bool = False):
        if end >= INF:
            raise ValueError('``end`` must be less than infinity. Please set '
                             'the ``end`` parameter of '
                             '``QuadraticWarmupParamScheduler`` to the last '
                             'step of warmup.')
self.total_iters = end - begin
super().__init__(
optimizer=optimizer,
param_name=param_name,
begin=begin,
end=end,
last_step=last_step,
by_epoch=by_epoch,
            verbose=verbose)

    @classmethod
def build_iter_from_epoch(cls,
*args,
begin=0,
end=INF,
by_epoch=True,
epoch_length=None,
**kwargs):
"""Build an iter-based instance of this scheduler from an epoch-based
config."""
assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
'be converted to iter-based.'
assert epoch_length is not None and epoch_length > 0, \
f'`epoch_length` must be a positive integer, ' \
f'but got {epoch_length}.'
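        # ``begin`` and ``end`` are given in epochs; scale them by the number
        # of iterations per epoch to obtain iteration indices.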
by_epoch = False
begin = begin * epoch_length
if end != INF:
end = end * epoch_length
        return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs)

    def _get_value(self):
"""Compute value using chainable form of the scheduler."""
if self.last_step == 0:
return [
base_value * (2 * self.last_step + 1) / self.total_iters**2
for base_value in self.base_values
]
return [
group[self.param_name] + base_value *
(2 * self.last_step + 1) / self.total_iters**2
for base_value, group in zip(self.base_values,
self.optimizer.param_groups)
        ]


@PARAM_SCHEDULERS.register_module()
class QuadraticWarmupLR(LRSchedulerMixin, QuadraticWarmupParamScheduler):
"""Warm up the learning rate of each parameter group by quadratic formula.
Args:
optimizer (Optimizer): Wrapped optimizer.
begin (int): Step at which to start updating the parameters.
Defaults to 0.
end (int): Step at which to stop updating the parameters.
Defaults to INF.
last_step (int): The index of last step. Used for resume without
state dict. Defaults to -1.
by_epoch (bool): Whether the scheduled parameters are updated by
epochs. Defaults to True.
verbose (bool): Whether to print the value for each update.
Defaults to False.
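
    Examples:
        A configuration sketch (illustrative only, mirroring the YOLOX-style
        warmup settings): warm the learning rate up over the first 5 epochs,
        updating it every iteration:

        >>> param_scheduler = [
        ...     dict(
        ...         type='QuadraticWarmupLR',
        ...         by_epoch=True,
        ...         begin=0,
        ...         end=5,
        ...         convert_to_iter_based=True)
        ... ]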
"""
@PARAM_SCHEDULERS.register_module()
class QuadraticWarmupMomentum(MomentumSchedulerMixin,
QuadraticWarmupParamScheduler):
"""Warm up the momentum value of each parameter group by quadratic formula.
Args:
optimizer (Optimizer): Wrapped optimizer.
begin (int): Step at which to start updating the parameters.
Defaults to 0.
end (int): Step at which to stop updating the parameters.
Defaults to INF.
last_step (int): The index of last step. Used for resume without
state dict. Defaults to -1.
by_epoch (bool): Whether the scheduled parameters are updated by
epochs. Defaults to True.
verbose (bool): Whether to print the value for each update.
Defaults to False.
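
    Examples:
        An illustrative sketch (not part of the original file), assuming an
        optimizer whose param groups carry a ``momentum`` entry:

        >>> import torch.nn as nn
        >>> from torch.optim import SGD
        >>> model = nn.Linear(2, 2)
        >>> optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
        >>> # warm the momentum up to its base value (0.9) over the first
        >>> # 500 iterations
        >>> scheduler = QuadraticWarmupMomentum(
        ...     optimizer, begin=0, end=500, by_epoch=False)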
"""