Files
EasyFace/modelscope/trainers/optimizer/builder.py
2023-03-02 11:17:26 +08:00

55 lines
1.8 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import inspect
from typing import Iterable, Union
import torch
from modelscope.utils.config import ConfigDict
from modelscope.utils.registry import Registry, build_from_cfg, default_group
# Registry of optimizer classes that build_optimizer() builds from; it is
# populated at import time with the torch.optim optimizers by
# register_torch_optimizers() at the bottom of this module.
OPTIMIZERS = Registry('optimizer')
def build_optimizer(model: Union[torch.nn.Module,
                                 Iterable[torch.nn.parameter.Parameter]],
                    cfg: ConfigDict,
                    default_args: Union[dict, None] = None):
    """Build an optimizer from an optimizer config dict.

    Args:
        model: A ``torch.nn.Module``, an object wrapping one under a
            ``module`` attribute, or an iterable of parameters (e.g.
            user-defined parameter groups) handed to the optimizer as-is.
        cfg (:obj:`ConfigDict`): config dict for the optimizer object.
        default_args (dict, optional): Default initialization arguments.
            The dict is copied and never modified; its ``params`` entry is
            always set to the parameters derived from ``model``.

    Returns:
        The object built by ``build_from_cfg`` from ``cfg`` using the
        ``OPTIMIZERS`` registry.
    """
    # Work on a copy so the caller's dict is not mutated (a dict reused
    # across calls would otherwise carry a stale 'params' entry).
    default_args = dict(default_args) if default_args is not None else {}

    if isinstance(model, torch.nn.Module):
        # nn.Module wrappers such as DataParallel / DistributedDataParallel
        # register the wrapped model as a submodule, so parameters() already
        # yields exactly its parameters. Unwrapping `.module` here would drop
        # parameters of a model that merely owns a submodule named `module`,
        # and crash if `.module` is not a Module at all.
        params = model.parameters()
    elif isinstance(getattr(model, 'module', None), torch.nn.Module):
        # A non-Module wrapper that exposes the real model as `.module`.
        params = model.module.parameters()
    else:
        # Input is an iterable of parameters; this case fits the scenario of
        # user-defined parameter groups.
        params = model
    default_args['params'] = params

    return build_from_cfg(cfg,
                          OPTIMIZERS,
                          group_key=default_group,
                          default_args=default_args)
def register_torch_optimizers():
    """Register every optimizer class exposed by ``torch.optim``.

    Each class attribute of ``torch.optim`` that subclasses
    ``torch.optim.Optimizer`` is registered in ``OPTIMIZERS`` under
    ``default_group``, keyed by its attribute name in ``torch.optim``.
    Dunder-named members are ignored. Because ``issubclass`` is true for a
    class and itself, the ``Optimizer`` base class is registered as well.
    """
    base_cls = torch.optim.Optimizer
    for attr_name, candidate in inspect.getmembers(torch.optim,
                                                   inspect.isclass):
        if attr_name.startswith('__') or not issubclass(candidate, base_cls):
            continue
        OPTIMIZERS.register_module(default_group,
                                   module_name=attr_name,
                                   module_cls=candidate)
# Executed at import time: importing this module is enough for all
# torch.optim optimizers to be available in OPTIMIZERS.
register_torch_optimizers()