# Page residue from the hosted file view this copy was taken from:
# author ttxskk, commit d7e58f0 ("update"), file size 1.63 kB.
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import build_optimizer
from mmcv.utils import Registry
# Module-level registry named 'optimizers'.
# NOTE(review): build_optimizers below delegates to
# mmcv.runner.build_optimizer and never consults this registry; presumably
# it is imported and populated elsewhere in the project — verify before
# removing.
OPTIMIZERS = Registry('optimizers')
def build_optimizers(model, cfgs):
    """Build one optimizer, or one optimizer per sub-module, from configs.

    If every value in ``cfgs`` is itself a dict, ``cfgs`` is treated as a
    mapping from sub-module attribute name to optimizer config. In that case
    a dict of constructed optimizers, keyed by the same names, is returned.
    Otherwise ``cfgs`` is treated as a single optimizer config, and the
    constructed optimizer itself is returned. For example,

    1) Multiple optimizer configs:

    .. code-block:: python

        optimizer_cfg = dict(
            model1=dict(type='SGD', lr=lr),
            model2=dict(type='SGD', lr=lr))

    The return dict is
    ``dict(model1=torch.optim.Optimizer, model2=torch.optim.Optimizer)``

    2) Single optimizer config:

    .. code-block:: python

        optimizer_cfg = dict(type='SGD', lr=lr)

    The return is ``torch.optim.Optimizer``.

    Args:
        model (:obj:`nn.Module`): The model with parameters to be optimized.
            If it has a ``module`` attribute (as wrappers such as
            ``DataParallel`` do), that inner module is used instead.
        cfgs (dict): The config dict of the optimizer(s).

    Returns:
        dict[:obj:`torch.optim.Optimizer`] | :obj:`torch.optim.Optimizer`:
            The initialized optimizers. An empty ``cfgs`` yields an empty
            dict, because ``all()`` over no values is ``True``.

    Raises:
        AttributeError: If a key of a multi-optimizer config does not name
            an attribute of ``model``.
    """
    # Unwrap parallel wrappers so that the sub-module lookups below resolve
    # against the real model's attributes.
    if hasattr(model, 'module'):
        model = model.module

    # A single optimizer config normally contains non-dict values such as
    # ``type`` and ``lr``, so "every value is a dict" marks the
    # multi-optimizer form.
    if all(isinstance(v, dict) for v in cfgs.values()):
        optimizers = {}
        for key, cfg in cfgs.items():
            try:
                module = getattr(model, key)
            except AttributeError as err:
                # Re-raise with the offending config key so misconfigured
                # multi-optimizer setups are easy to diagnose.
                raise AttributeError(
                    'Optimizer config key "{}" does not match any attribute '
                    'of {}'.format(key, type(model).__name__)) from err
            # Shallow-copy so the caller's config dict is not handed over
            # directly.
            optimizers[key] = build_optimizer(module, cfg.copy())
        return optimizers
    return build_optimizer(model, cfgs)