Spaces:
Running
Running
File size: 2,362 Bytes
560b597 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
"""
Author: Luigi Piccinelli
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
"""
import numpy as np
class CosineScheduler(object):
    """Schedule one optimizer hyper-parameter: quadratic warmup, then cosine decay.

    For every param group a per-iteration table of values is precomputed:

    * warmup (``warmup_iters`` entries): ``init_value`` -> ``base_value``
      following ``t ** 2`` with ``t`` evenly spaced in ``[0, 1]``;
    * main phase (``total_iters - warmup_iters`` entries): half-cosine
      starting at ``base_value`` and decaying towards ``final_value``.

    Each call to :meth:`step` advances the internal counter by one and writes
    the value for that iteration into ``group[key]`` of every param group.

    Args:
        optimizer: object exposing ``param_groups``, a sequence of dicts.
        warmup_iters: number of warmup iterations.
        total_iters: total number of scheduled iterations; lookups past the
            end are clamped to the last entry.
        key: name of the param-group entry to drive (e.g. ``"lr"``).
        overwrite: if true, ``final_value`` given here takes precedence over
            a per-group ``"<key>_final"`` entry.
        init_value: default value at the start of warmup.
        base_value: default value at the end of warmup / start of decay.
        final_value: default value the cosine decays towards.
        step_init: initial counter; the first :meth:`step` uses
            ``step_init + 1``.

    A param group may override the defaults through the entries
    ``"<key>_init"``, ``"<key>_base"`` and ``"<key>_final"``.

    Raises:
        TypeError: if any of the three values is available neither from the
            constructor nor from the param group.
    """

    def __init__(
        self,
        optimizer,
        warmup_iters,
        total_iters,
        key,
        overwrite=False,
        init_value=None,
        base_value=None,
        final_value=None,
        step_init=-1,
    ):
        super().__init__()
        self.iter = step_init
        self.overwrite = overwrite
        self.optimizer = optimizer
        self.base_value = base_value
        self.init_value = init_value
        self.final_value = final_value
        self.total_iters = total_iters
        self.warmup_iters = warmup_iters
        self.key = key
        # One precomputed value table per param group, in param_groups order.
        self.schedulers = [
            self.get_schedulers(group) for group in optimizer.param_groups
        ]

    def get_schedulers(self, group):
        """Return the 1-D NumPy array of scheduled values for ``group``."""
        init_value = group.get(self.key + "_init", self.init_value)
        base_value = group.get(self.key + "_base", self.base_value)
        final_value = group.get(self.key + "_final", self.final_value)
        warmup_iters = self.warmup_iters
        total_iters = self.total_iters

        if self.overwrite:
            final_value = self.final_value

        # Fail early with an actionable message instead of the opaque
        # "unsupported operand type(s) ... NoneType" raised by the arithmetic
        # below. TypeError is kept so the exception type is unchanged.
        resolved = (
            ("init_value", "_init", init_value),
            ("base_value", "_base", base_value),
            ("final_value", "_final", final_value),
        )
        for arg_name, suffix, value in resolved:
            if value is None:
                raise TypeError(
                    f"CosineScheduler: {arg_name} is not set for key "
                    f"{self.key!r}; pass {arg_name}=... or define "
                    f"'{self.key}{suffix}' in the param group"
                )

        # normalize in 0,1, then apply function (power) and denormalize
        normalized_schedule = np.linspace(0, 1, warmup_iters, endpoint=True)
        normalized_schedule = np.power(normalized_schedule, 2)
        warmup_schedule = (base_value - init_value) * normalized_schedule + init_value

        # main scheduling: the phase is i / n for i in [0, n), so the first
        # entry equals base_value and the last one stops just short of
        # final_value.
        iters = np.arange(total_iters - warmup_iters)
        schedule = final_value + 0.5 * (base_value - final_value) * (
            1 + np.cos(np.pi * iters / len(iters))
        )
        return np.concatenate((warmup_schedule, schedule))

    def step(self):
        """Advance one iteration and write the new values into the param groups.

        When ``group[key]`` is a tuple or list, only its first element is
        replaced and the result is stored as a tuple.
        """
        self.iter = self.iter + 1
        vals = self[self.iter]
        for group, val in zip(self.optimizer.param_groups, vals):
            # Store a plain Python float rather than a NumPy scalar: the value
            # is identical, but NumPy scalars inside the optimizer state get
            # pickled with checkpoints and are rejected by restricted
            # unpicklers (e.g. torch.load with weights_only=True).
            val = float(val)
            if isinstance(group[self.key], (tuple, list)):
                val = (val, *group[self.key][1:])
            group[self.key] = val

    def __getitem__(self, it):
        """Return the scheduled value of every param group at iteration ``it``.

        Indices at or beyond ``total_iters`` are clamped to the last entry.
        """
        it = min(it, self.total_iters - 1)
        return [scheduler[it] for scheduler in self.schedulers]

    def get(self):
        """Return the current ``group[key]`` of every param group."""
        return [group[self.key] for group in self.optimizer.param_groups]
|