|
|
|
import copy |
|
import inspect |
|
|
|
import torch |
|
|
|
from ...utils import Registry, build_from_cfg |
|
|
|
# Registry of optimizer classes. register_torch_optimizers() (called at
# import time below) fills it with the optimizer classes found in torch.optim.
OPTIMIZERS = Registry('optimizer')

# Registry of optimizer *constructor* classes; build_optimizer_constructor()
# builds instances out of this registry from a config dict.
OPTIMIZER_BUILDERS = Registry('optimizer builder')
|
|
|
|
|
def register_torch_optimizers():
    """Register the optimizer classes exposed by ``torch.optim``.

    Every non-dunder attribute of ``torch.optim`` that is a class deriving
    from ``torch.optim.Optimizer`` is added to the ``OPTIMIZERS`` registry.
    Because ``issubclass`` is reflexive, the ``Optimizer`` base class itself
    is registered as well.

    Returns:
        list[str]: The ``torch.optim`` attribute names of the registered
        classes, in ``dir()`` order.
    """
    registered = []
    for attr_name in dir(torch.optim):
        # Skip dunder attributes such as ``__name__`` / ``__file__``.
        if attr_name.startswith('__'):
            continue
        candidate = getattr(torch.optim, attr_name)
        # ``isclass`` must be checked first: ``issubclass`` raises on
        # non-class objects (e.g. submodules like ``lr_scheduler``).
        is_optimizer = (inspect.isclass(candidate)
                        and issubclass(candidate, torch.optim.Optimizer))
        if not is_optimizer:
            continue
        OPTIMIZERS.register_module()(candidate)
        registered.append(attr_name)
    return registered
|
|
|
|
|
# Import-time side effect: registers torch.optim optimizers into OPTIMIZERS
# and keeps the list of their attribute names.
TORCH_OPTIMIZERS = register_torch_optimizers()
|
|
|
|
|
def build_optimizer_constructor(cfg):
    """Build an optimizer constructor object from ``cfg``.

    Args:
        cfg (dict): Config handed to ``build_from_cfg`` together with the
            ``OPTIMIZER_BUILDERS`` registry. When called from
            ``build_optimizer`` it carries ``type``, ``optimizer_cfg`` and
            ``paramwise_cfg`` keys.

    Returns:
        The object produced by ``build_from_cfg`` (a constructor that
        ``build_optimizer`` then calls with the model).
    """
    constructor = build_from_cfg(cfg, OPTIMIZER_BUILDERS)
    return constructor
|
|
|
|
|
def build_optimizer(model, cfg):
    """Build an optimizer for ``model`` from an optimizer config.

    ``cfg`` is deep-copied first, so the caller's config is left unmodified.
    Two keys are removed from the copy before it is forwarded:

    - ``constructor``: registry name of the constructor class to use,
      ``'DefaultOptimizerConstructor'`` when absent.
    - ``paramwise_cfg``: forwarded separately as ``paramwise_cfg``
      (``None`` when absent).

    Whatever remains is passed to the constructor as ``optimizer_cfg``.

    Args:
        model: Object passed straight to the built constructor.
        cfg (dict): Optimizer config (see above for the special keys).

    Returns:
        The result of calling the constructor on ``model`` (presumably a
        ``torch.optim.Optimizer`` instance -- depends on the constructor).
    """
    remaining_cfg = copy.deepcopy(cfg)
    constructor_name = remaining_cfg.pop('constructor',
                                         'DefaultOptimizerConstructor')
    paramwise_options = remaining_cfg.pop('paramwise_cfg', None)

    constructor_cfg = dict(
        type=constructor_name,
        optimizer_cfg=remaining_cfg,
        paramwise_cfg=paramwise_options)
    constructor = build_optimizer_constructor(constructor_cfg)
    return constructor(model)
|
|