EasyFace/modelscope/models/base/base_torch_model.py

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from torch import nn
from torch.nn.parallel import DataParallel, DistributedDataParallel

from modelscope.utils.checkpoint import (save_checkpoint, save_configuration,
                                         save_pretrained)
from modelscope.utils.file_utils import func_receive_dict_inputs
from modelscope.utils.logger import get_logger

from .base_model import Model

logger = get_logger()


class TorchModel(Model, torch.nn.Module):
    """Base model interface for PyTorch."""

    def __init__(self, model_dir=None, *args, **kwargs):
        super().__init__(model_dir, *args, **kwargs)
        torch.nn.Module.__init__(self)

    def __call__(self, *args, **kwargs) -> Dict[str, Any]:
        # Adapt to models whose forward takes a single dict argument;
        # that argument must be named `input` or `inputs`.
        if func_receive_dict_inputs(self.forward):
            return self.postprocess(self.forward(args[0], **kwargs))
        else:
            return self.postprocess(self.forward(*args, **kwargs))
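
    # Illustrative sketch (hypothetical subclass, not part of this file):
    # a forward that accepts a single dict named `inputs` is detected by
    # func_receive_dict_inputs, so the dict passed to __call__ is handed
    # to forward as one positional argument rather than being unpacked:
    #
    #     class MyModel(TorchModel):
    #         def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
    #             return {'y': inputs['x'] * 2}
    #
    #     MyModel()({'x': torch.ones(1)})  # routed to forward(inputs)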

    def _load_pretrained(self,
                         net,
                         load_path,
                         strict=True,
                         param_key='params'):
        if isinstance(net, (DataParallel, DistributedDataParallel)):
            net = net.module
        load_net = torch.load(
            load_path, map_location=lambda storage, loc: storage)
        if param_key is not None:
            if param_key not in load_net and 'params' in load_net:
                # Log the missing key before falling back to 'params'.
                logger.info(
                    f'Loading: {param_key} does not exist, use params.')
                param_key = 'params'
            if param_key in load_net:
                load_net = load_net[param_key]
        logger.info(
            f'Loading {net.__class__.__name__} model from {load_path}, '
            f'with param key: [{param_key}].')
        # Remove the 'module.' prefix that (Distributed)DataParallel adds
        # to state dict keys.
        for k, v in deepcopy(load_net).items():
            if k.startswith('module.'):
                load_net[k[7:]] = v
                load_net.pop(k)
        net.load_state_dict(load_net, strict=strict)
        logger.info('load model done.')
        return net
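
    # Sketch of the checkpoint layout this expects (an assumption based on
    # the param_key handling above; 'weights.pth' is a hypothetical name):
    #
    #     torch.save({'params': net.state_dict()}, 'weights.pth')
    #     net = self._load_pretrained(net, 'weights.pth', param_key='params')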

    def forward(self, *args, **kwargs) -> Dict[str, Any]:
        raise NotImplementedError

    def post_init(self):
        """
        Executed at the end of model initialization, for code that requires
        the model's modules to be properly initialized (such as weight
        initialization).
        """
        self.init_weights()
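
    # Usage note: a subclass calls self.post_init() as the last statement of
    # its own __init__, so that init_weights() sees the fully-built module
    # tree (see the sketch at the end of this file).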

    def init_weights(self):
        # Initialize weights by applying _init_weights to every submodule.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize the weights of a single module."""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version, which uses
            # truncated_normal for initialization;
            # cf. https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
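
    # Illustration: nn.Module.apply walks the whole module tree, so calling
    # model.init_weights() re-runs this scheme (normal(0, 0.02) for
    # Linear/Embedding weights, ones/zeros for LayerNorm) over every nested
    # submodule, e.g. after new layers have been added.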

    def save_pretrained(self,
                        target_folder: Union[str, os.PathLike],
                        save_checkpoint_names: Union[str, List[str]] = None,
                        save_function: Callable = save_checkpoint,
                        config: Optional[dict] = None,
                        save_config_function: Callable = save_configuration,
                        **kwargs):
        """Save the pretrained model, its configuration and other related
        files to a directory, so that it can be re-loaded.

        Args:
            target_folder (Union[str, os.PathLike]):
                Directory to which to save. Will be created if it doesn't exist.
            save_checkpoint_names (Union[str, List[str]]):
                The checkpoint names to be saved in the target_folder.
            save_function (Callable, optional):
                The function to use to save the state dictionary.
            config (Optional[dict], optional):
                The config for the configuration.json, which may not be
                identical to model.cfg.
            save_config_function (Callable, optional):
                The function to use to save the configuration.
        """
        if config is None and hasattr(self, 'cfg'):
            config = self.cfg
        save_pretrained(self, target_folder, save_checkpoint_names,
                        save_function, **kwargs)
        if config is not None:
            save_config_function(target_folder, config)
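

# A minimal sketch of subclassing (hypothetical model and dimensions; assumes
# modelscope is importable and that Model's default postprocess returns its
# input unchanged). The dict argument is routed through __call__ to forward.
if __name__ == '__main__':

    class TinyModel(TorchModel):

        def __init__(self, model_dir=None, *args, **kwargs):
            super().__init__(model_dir, *args, **kwargs)
            self.linear = nn.Linear(4, 2)
            self.post_init()  # run init_weights over the built modules

        def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
            return {'logits': self.linear(inputs['x'])}

    model = TinyModel()
    out = model({'x': torch.randn(1, 4)})
    print(out['logits'].shape)  # torch.Size([1, 2])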