Spaces:
Runtime error
Runtime error
File size: 747 Bytes
0102e16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import torch
class SGD(torch.optim.SGD):
    """Thin subclass of ``torch.optim.SGD`` that gives ``lr`` a default value.

    Optimizers invoked by ``AbsTask.main()`` must provide a default for
    every constructor argument except ``params``.  According to the original
    author, ``torch.optim.SGD`` did not define a default for ``lr``, so this
    wrapper binds one; all other behaviour is inherited unchanged.

    Args:
        params: Parameters (or parameter groups) to optimize; forwarded
            as-is to ``torch.optim.SGD``.
        lr: Learning rate.  Defaults to ``0.1``.
        momentum: Momentum factor.  Defaults to ``0.0``.
        dampening: Dampening applied to the momentum.  Defaults to ``0.0``.
        weight_decay: Weight-decay coefficient.  Defaults to ``0.0``.
        nesterov: Whether to use Nesterov momentum.  Defaults to ``False``.
    """

    def __init__(
        self,
        params,
        lr: float = 0.1,
        momentum: float = 0.0,
        dampening: float = 0.0,
        weight_decay: float = 0.0,
        nesterov: bool = False,
    ) -> None:
        # Pure pass-through: argument validation and all optimizer state
        # handling are left to the parent class.
        super().__init__(
            params,
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )
|