from typing import Tuple, Optional, Callable

import torch
from torch.optim.optimizer import Optimizer
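
# Lion optimizer: a sign-based momentum update with decoupled weight decay,
# proposed in "Symbolic Discovery of Optimization Algorithms" (Chen et al., 2023).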

def exists(val):
    return val is not None

# update function

def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
    # decoupled weight decay, applied directly to the weights

    p.data.mul_(1 - lr * wd)

    # weight update - sign of the interpolation between momentum and current gradient

    update = exp_avg.clone().mul_(beta1).add_(grad, alpha = 1 - beta1).sign_()
    p.add_(update, alpha = -lr)

    # update the exponential moving average of gradients

    exp_avg.mul_(beta2).add_(grad, alpha = 1 - beta2)
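
# For reference, the three steps above implement the Lion rule: with momentum
# m, gradient g, and parameters p,
#
#   update = sign(beta1 * m + (1 - beta1) * g)
#   p     <- p * (1 - lr * wd) - lr * update
#   m     <- beta2 * m + (1 - beta2) * g
#
# Only the sign of the interpolated direction is used, so every coordinate
# moves by exactly lr in magnitude (before weight decay).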

# class

class Lion(Optimizer):
    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
        use_triton: bool = False
    ):
        assert lr > 0.
        assert all([0. <= beta <= 1. for beta in betas])

        defaults = dict(
            lr = lr,
            betas = betas,
            weight_decay = weight_decay
        )

        super().__init__(params, defaults)

        self.update_fn = update_fn

        if use_triton:
            # optionally swap in the Triton kernel version of the update step
            from lion_pytorch.triton import update_fn as triton_update_fn
            self.update_fn = triton_update_fn

    @torch.no_grad()
    def step(
        self,
        closure: Optional[Callable] = None
    ):
        loss = None
        if exists(closure):
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in filter(lambda p: exists(p.grad), group['params']):

                grad, lr, wd, beta1, beta2, state = p.grad, group['lr'], group['weight_decay'], *group['betas'], self.state[p]

                # init state - exponential moving average of gradient values

                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)

                exp_avg = state['exp_avg']

                self.update_fn(
                    p,
                    grad,
                    exp_avg,
                    lr,
                    wd,
                    beta1,
                    beta2
                )

        return loss
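
# Minimal usage sketch (not part of the optimizer itself): a toy model and a
# single Lion step. The model, data, and hyperparameters below are illustrative
# placeholders, not taken from the original module.

if __name__ == '__main__':
    import torch.nn as nn

    model = nn.Linear(10, 1)
    opt = Lion(model.parameters(), lr = 1e-4, weight_decay = 1e-2)

    # one optimization step on random data
    loss = model(torch.randn(8, 10)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()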