mirror of
https://gitcode.com/gh_mirrors/eas/EasyFace.git
synced 2026-04-08 21:30:17 +00:00
24 lines
754 B
Python
24 lines
754 B
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
from .builder import PARALLEL
|
|
|
|
|
|
def is_parallel(module):
    """Check if a module is wrapped by a parallel object.

    The following modules are regarded as parallel objects:

    - torch.nn.parallel.DataParallel
    - torch.nn.parallel.distributed.DistributedDataParallel

    You may add your own parallel object by registering it to
    `modelscope.parallel.PARALLEL`.

    Args:
        module (nn.Module): The module to be checked.

    Returns:
        bool: True if ``module`` is an instance of any class registered in
        ``PARALLEL`` (across all registry groups), False otherwise.
    """
    # ``PARALLEL.modules`` maps each registry group to a dict whose values are
    # the registered wrapper classes; collect the classes from every group.
    # isinstance() needs a tuple of types; an empty tuple simply yields False.
    module_wrappers = tuple(
        wrapper
        for module_dict in PARALLEL.modules.values()
        for wrapper in module_dict.values())
    return isinstance(module, module_wrappers)
|