Files
2023-03-02 11:17:26 +08:00

24 lines
754 B
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
from .builder import PARALLEL
def is_parallel(module):
    """Check whether a module is wrapped by a registered parallel object.

    The following modules are regarded as parallel objects:

    - torch.nn.parallel.DataParallel
    - torch.nn.parallel.distributed.DistributedDataParallel

    You may add your own parallel object by registering it to
    `modelscope.parallel.PARALLEL`.

    Args:
        module (nn.Module): The module to be checked.

    Returns:
        bool: True if the module is wrapped by a parallel object.
    """
    # PARALLEL.modules maps each registry group to a dict of registered
    # entries (presumably wrapper classes, since they are passed to
    # isinstance below). Only those entries matter here, so the group
    # keys are not needed.
    module_wrappers = tuple(
        wrapper
        for module_dict in PARALLEL.modules.values()
        for wrapper in module_dict.values())
    # isinstance accepts a tuple of types. An empty tuple (nothing
    # registered) simply yields False.
    return isinstance(module, module_wrappers)