You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
56 lines
1.6 KiB
56 lines
1.6 KiB
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from mmcv.runner import build_optimizer
|
|
from mmcv.utils import Registry
|
|
|
|
# Registry named 'optimizers', created with mmcv's ``Registry``.
# NOTE(review): nothing visible in this module registers into or reads from
# it; ``build_optimizers`` delegates to ``mmcv.runner.build_optimizer``, which
# presumably consults mmcv's own optimizer registry instead. Confirm whether
# this one is used elsewhere (e.g. for registering custom optimizers).
OPTIMIZERS = Registry('optimizers')
|
|
|
|
|
|
def build_optimizers(model, cfgs):
    """Build one optimizer, or several keyed optimizers, from configs.

    If every value in ``cfgs`` is itself a dict, ``cfgs`` is treated as a
    mapping from submodule name to optimizer config. In that case a dict of
    constructed optimizers with the same keys is returned. Otherwise ``cfgs``
    is treated as a single optimizer config, and the constructed optimizer
    itself is returned.

    For example,

    1) Multiple optimizer configs:

    .. code-block:: python

        optimizer_cfg = dict(
            model1=dict(type='SGD', lr=lr),
            model2=dict(type='SGD', lr=lr))

    The return dict is
    ``dict(model1=torch.optim.Optimizer, model2=torch.optim.Optimizer)``,
    where each optimizer is built on ``model.model1`` / ``model.model2``.

    2) Single optimizer config:

    .. code-block:: python

        optimizer_cfg = dict(type='SGD', lr=lr)

    The return is ``torch.optim.Optimizer``.

    Args:
        model (:obj:`nn.Module`): The model with parameters to be optimized.
            If it has a ``module`` attribute (presumably a wrapper such as
            ``DataParallel``), that inner module is used instead.
        cfgs (dict): Either a single optimizer config, or a dict mapping
            attribute (submodule) names of ``model`` to optimizer configs.

    Returns:
        dict[:obj:`torch.optim.Optimizer`] | :obj:`torch.optim.Optimizer`:
            The initialized optimizers. An empty ``cfgs`` yields an empty
            dict, because ``all()`` over no values is ``True``.
    """
    # Unwrap models that expose the real network as ``.module`` so the
    # per-key attribute lookups below resolve on the underlying model.
    if hasattr(model, 'module'):
        model = model.module

    # A multi-optimizer config is a dict whose values are all dicts; a single
    # optimizer config has non-dict entries such as ``type='SGD'``.
    if all(isinstance(v, dict) for v in cfgs.values()):
        # Hand the builder a shallow copy rather than the caller's own
        # per-module config dict.
        return {
            key: build_optimizer(getattr(model, key), cfg.copy())
            for key, cfg in cfgs.items()
        }

    return build_optimizer(model, cfgs)
|