TTP / mmpretrain /engine /schedulers /weight_decay_scheduler.py
KyanChen's picture
Upload 1861 files
3b96cb1
# Copyright (c) OpenMMLab. All rights reserved.
import math
from mmengine.optim.scheduler import CosineAnnealingParamScheduler
from mmpretrain.registry import PARAM_SCHEDULERS
class WeightDecaySchedulerMixin:
"""A mixin class for learning rate schedulers."""
def __init__(self, optimizer, *args, **kwargs):
super().__init__(optimizer, 'weight_decay', *args, **kwargs)
@PARAM_SCHEDULERS.register_module()
class CosineAnnealingWeightDecay(WeightDecaySchedulerMixin,
CosineAnnealingParamScheduler):
"""Set the weight decay value of each parameter group using a cosine
annealing schedule.
If the weight decay was set to be 0 initially, the weight decay value will
be 0 constantly during the training.
"""
def _get_value(self) -> list:
"""Compute value using chainable form of the scheduler."""
def _get_eta_min(base_value):
if self.eta_min_ratio is None:
return self.eta_min
return base_value * self.eta_min_ratio
if self.last_step == 0:
return [
group[self.param_name] for group in self.optimizer.param_groups
]
elif (self.last_step - 1 - self.T_max) % (2 * self.T_max) == 0:
weight_decay_value_list = []
for base_value, group in zip(self.base_values,
self.optimizer.param_groups):
if base_value == 0:
group_value = 0
else:
group_value = group[self.param_name] + (
base_value - _get_eta_min(base_value)) * (
1 - math.cos(math.pi / self.T_max)) / 2
weight_decay_value_list.append(group_value)
return weight_decay_value_list
weight_decay_value_list = []
for base_value, group in zip(self.base_values,
self.optimizer.param_groups):
if base_value == 0:
group_value = 0
else:
group_value = (
1 + math.cos(math.pi * self.last_step / self.T_max)) / (
1 + math.cos(math.pi *
(self.last_step - 1) / self.T_max)
) * (group[self.param_name] -
_get_eta_min(base_value)) + _get_eta_min(base_value)
weight_decay_value_list.append(group_value)
return weight_decay_value_list