mirror of
https://gitcode.com/gh_mirrors/eas/EasyFace.git
synced 2026-05-18 13:17:53 +00:00
55 lines
1.8 KiB
Python
55 lines
1.8 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import inspect
|
|
from typing import Iterable, Union
|
|
|
|
import torch
|
|
|
|
from modelscope.utils.config import ConfigDict
|
|
from modelscope.utils.registry import Registry, build_from_cfg, default_group
|
|
|
|
# Registry of optimizer classes: filled by register_torch_optimizers() and
# passed to build_from_cfg() by build_optimizer().
OPTIMIZERS = Registry('optimizer')
|
|
|
|
|
|
def build_optimizer(model: Union[torch.nn.Module,
                                 Iterable[torch.nn.parameter.Parameter]],
                    cfg: ConfigDict,
                    default_args: dict = None):
    """Build an optimizer from an optimizer config dict.

    Args:
        model: A torch.nn.Module, an object exposing a torch.nn.Module through
            a ``module`` attribute, or an iterable of parameters (e.g.
            user-defined parameter groups).
        cfg (:obj:`ConfigDict`): config dict for optimizer object.
        default_args (dict, optional): Default initialization arguments.
            The passed dict is left untouched; a shallow copy receives the
            ``params`` entry.

    Returns:
        The object produced by ``build_from_cfg`` for ``cfg`` using the
        ``OPTIMIZERS`` registry.
    """
    # Copy before inserting 'params' so the caller's dict is never mutated
    # (a reused default_args dict would otherwise keep a stale entry).
    build_args = {} if default_args is None else dict(default_args)

    if isinstance(model, torch.nn.Module) or (hasattr(
            model, 'module') and isinstance(model.module, torch.nn.Module)):
        # Unwrap objects that hold the real network in a `module` attribute.
        if hasattr(model, 'module'):
            model = model.module

        build_args['params'] = model.parameters()
    else:
        # Input is an iterable of parameters; this case fits the scenario of
        # user-defined parameter groups.
        build_args['params'] = model

    return build_from_cfg(cfg,
                          OPTIMIZERS,
                          group_key=default_group,
                          default_args=build_args)
|
|
|
|
|
|
def register_torch_optimizers():
    """Register the optimizer classes found in ``torch.optim``.

    Every member of ``torch.optim`` that is a class deriving from
    ``torch.optim.Optimizer`` and whose name does not start with a double
    underscore is registered in ``OPTIMIZERS`` under ``default_group``,
    using the member name as the module name.
    """

    def _is_optimizer_class(candidate):
        # Only classes can be passed to issubclass(), so check that first.
        return inspect.isclass(candidate) and issubclass(
            candidate, torch.optim.Optimizer)

    members = inspect.getmembers(torch.optim, _is_optimizer_class)
    for attr_name, optimizer_cls in members:
        # Dunder-named members are never registered.
        if attr_name.startswith('__'):
            continue
        OPTIMIZERS.register_module(default_group,
                                   module_name=attr_name,
                                   module_cls=optimizer_cls)
|
|
|
|
|
|
# Runs at import time so the torch.optim optimizer classes are registered in
# OPTIMIZERS as soon as this module is loaded.
register_torch_optimizers()
|