Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import math | |
from mmengine.optim.scheduler import CosineAnnealingParamScheduler | |
from mmpretrain.registry import PARAM_SCHEDULERS | |
class WeightDecaySchedulerMixin:
    """A mixin class for weight decay schedulers.

    Listed before a generic parameter scheduler in the bases of a subclass, it
    fixes the scheduled parameter name to ``'weight_decay'``. Subclasses then
    only take the remaining scheduler arguments.
    """

    def __init__(self, optimizer, *args, **kwargs):
        # Inject 'weight_decay' as the second positional argument (the name of
        # the scheduled parameter) of the next ``__init__`` in the MRO.
        super().__init__(optimizer, 'weight_decay', *args, **kwargs)
@PARAM_SCHEDULERS.register_module()
class CosineAnnealingWeightDecay(WeightDecaySchedulerMixin,
                                 CosineAnnealingParamScheduler):
    """Set the weight decay value of each parameter group using a cosine
    annealing schedule.

    If the weight decay was set to be 0 initially, the weight decay value will
    be 0 constantly during the training.

    The class is registered in ``PARAM_SCHEDULERS`` so it can be looked up by
    its class name from that registry.
    """

    def _get_value(self) -> list:
        """Compute value using chainable form of the scheduler.

        Returns:
            list: The new weight decay value for each parameter group, in the
            order of ``self.optimizer.param_groups``. Groups whose base value
            is 0 always get 0.
        """

        def _get_eta_min(base_value):
            # Absolute floor, unless a floor relative to the base value is set.
            if self.eta_min_ratio is None:
                return self.eta_min
            return base_value * self.eta_min_ratio

        if self.last_step == 0:
            # First step: keep whatever values the optimizer currently holds.
            return [
                group[self.param_name] for group in self.optimizer.param_groups
            ]

        # When ``last_step - 1`` is an odd multiple of ``T_max``, the cosine
        # term for the previous step equals -1. The chainable ratio below would
        # then divide by ``1 + (-1) == 0``, so an additive update is used
        # instead.
        after_minimum = (
            (self.last_step - 1 - self.T_max) % (2 * self.T_max) == 0)

        weight_decay_value_list = []
        for base_value, group in zip(self.base_values,
                                     self.optimizer.param_groups):
            if base_value == 0:
                # A zero base weight decay stays 0 for the whole schedule.
                weight_decay_value_list.append(0)
                continue
            current = group[self.param_name]
            eta_min = _get_eta_min(base_value)
            if after_minimum:
                value = current + (base_value - eta_min) * (
                    1 - math.cos(math.pi / self.T_max)) / 2
            else:
                # Rescale the distance to eta_min by the ratio of consecutive
                # cosine factors.
                value = (1 + math.cos(math.pi * self.last_step / self.T_max)
                         ) / (1 + math.cos(math.pi *
                                           (self.last_step - 1) / self.T_max)
                              ) * (current - eta_min) + eta_min
            weight_decay_value_list.append(value)
        return weight_decay_value_list