from typing import Tuple, Optional, Callable

import torch
from torch.optim.optimizer import Optimizer

# functions

def exists(val):
    return val is not None

# update functions

def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
    # stepweight decay
    p.data.mul_(1 - lr * wd)

    # weight update
    update = exp_avg.clone().mul_(beta1).add(grad, alpha = 1 - beta1).sign_()
    p.add_(update, alpha = -lr)

    # decay the momentum running average coefficient
    exp_avg.mul_(beta2).add_(grad, alpha = 1 - beta2)

# class

class Lion(Optimizer):
    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
        use_triton: bool = False
    ):
        assert lr > 0.
        assert all([0. <= beta <= 1. for beta in betas])

        defaults = dict(
            lr = lr,
            betas = betas,
            weight_decay = weight_decay
        )

        super().__init__(params, defaults)

        self.update_fn = update_fn

        if use_triton:
            from lion_pytorch.triton import update_fn as triton_update_fn
            self.update_fn = triton_update_fn

    @torch.no_grad()
    def step(
        self,
        closure: Optional[Callable] = None
    ):
        loss = None
        if exists(closure):
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in filter(lambda p: exists(p.grad), group['params']):

                grad, lr, wd, beta1, beta2, state = p.grad, group['lr'], group['weight_decay'], *group['betas'], self.state[p]

                # init state - exponential moving average of gradient values

                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)

                exp_avg = state['exp_avg']

                self.update_fn(
                    p,
                    grad,
                    exp_avg,
                    lr,
                    wd,
                    beta1,
                    beta2
                )

        return loss
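
# Minimal usage sketch (illustrative, not part of the file above): drives the
# Lion optimizer defined here on a toy model, assuming a hypothetical regression
# loss; the constructor arguments mirror the defaults declared in __init__.
if __name__ == '__main__':
    model = torch.nn.Linear(10, 2)
    opt = Lion(model.parameters(), lr = 1e-4, weight_decay = 1e-2)

    for _ in range(3):
        loss = model(torch.randn(8, 10)).pow(2).mean()
        loss.backward()
        opt.step()
        opt.zero_grad()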