mirror of
https://gitcode.com/gh_mirrors/eas/EasyFace.git
synced 2026-03-11 07:30:16 +00:00
587 lines
24 KiB
Python
587 lines
24 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
|
|
import io
|
|
import json
|
|
import os
|
|
import re
|
|
import time
|
|
from collections import OrderedDict
|
|
from shutil import copytree, ignore_patterns, rmtree
|
|
from typing import Callable, Dict, Optional, Union
|
|
|
|
import torch
|
|
from torch import nn
|
|
from torch.optim import Optimizer
|
|
from torch.optim.lr_scheduler import _LRScheduler
|
|
|
|
from modelscope import __version__
|
|
from modelscope.fileio import File, LocalStorage
|
|
from modelscope.utils.config import JSONIteratorEncoder
|
|
from modelscope.utils.constant import ConfigFields, ModelFile
|
|
from modelscope.utils.logger import get_logger
|
|
|
|
# Module-wide logger shared by every checkpoint helper in this file.
logger = get_logger()

# Storage backend used by ``save_configuration`` to write files.
storage = LocalStorage()
|
|
|
|
|
|
def weights_to_cpu(state_dict):
    """Return a copy of ``state_dict`` whose values have been moved to CPU.

    Args:
        state_dict (OrderedDict): Mapping from parameter name to a value
            exposing ``.cpu()`` (e.g. model weights living on GPU).

    Returns:
        OrderedDict: Same keys in the same order, each value replaced by
            ``value.cpu()``. The ``_metadata`` attribute of the input is
            carried over; an empty ``OrderedDict`` is used when the input
            has none.
    """
    cpu_weights = OrderedDict(
        (name, tensor.cpu()) for name, tensor in state_dict.items())
    # Keep metadata in state_dict
    cpu_weights._metadata = getattr(state_dict, '_metadata', OrderedDict())
    return cpu_weights
|
|
|
|
|
|
def save_checkpoint(model: torch.nn.Module,
                    filename: str,
                    optimizer: Optional[Optimizer] = None,
                    lr_scheduler: Optional[_LRScheduler] = None,
                    meta: Optional[dict] = None,
                    with_meta: bool = True) -> None:
    """Save checkpoint to file.

    With ``with_meta=True`` the checkpoint is a dict holding ``meta`` and
    ``state_dict``, plus ``optimizer`` and ``lr_scheduler`` when the
    corresponding arguments are given. ``meta`` always contains the
    modelscope version and the current time, and ``CLASSES`` when the model
    defines a non-None ``CLASSES`` attribute.

    Args:
        model (Module): Module whose params are to be saved. A
            ``DistributedDataParallel`` wrapper is unwrapped first.
        filename (str): Checkpoint filename.
        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. A dict
            mapping names to optimizers is also accepted; each entry's
            ``state_dict()`` is stored under the same name.
        lr_scheduler(:obj:`_LRScheduler`, optional): LRScheduler to be saved.
        meta (dict, optional): Metadata to be saved in checkpoint. The dict
            passed in is not modified.
        with_meta (bool, optional): When True (default) save the dict
            described above; when False save only the CPU state dict of the
            model, without meta, optimizer or lr_scheduler.

    Raises:
        TypeError: If ``meta`` is neither a dict nor None.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    else:
        # Work on a shallow copy so the caller's dict does not silently gain
        # the 'modelscope', 'time' and 'CLASSES' entries added below.
        meta = meta.copy()
    meta.update(modelscope=__version__, time=time.asctime())

    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module

    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
        # save class name to the meta
        meta.update(CLASSES=model.CLASSES)

    if with_meta:
        checkpoint = {
            'meta': meta,
            'state_dict': weights_to_cpu(model.state_dict())
        }

        # save optimizer state dict in the checkpoint
        if isinstance(optimizer, Optimizer):
            checkpoint['optimizer'] = optimizer.state_dict()
        elif isinstance(optimizer, dict):
            checkpoint['optimizer'] = {
                name: optim.state_dict()
                for name, optim in optimizer.items()
            }

        # save lr_scheduler state dict in the checkpoint
        if lr_scheduler is not None and hasattr(lr_scheduler, 'state_dict'):
            checkpoint['lr_scheduler'] = lr_scheduler.state_dict()
    else:
        checkpoint = weights_to_cpu(model.state_dict())

    # Serialize into memory first, then hand the bytes to File.write.
    with io.BytesIO() as f:
        torch.save(checkpoint, f)
        File.write(f.getvalue(), filename)
|
|
|
|
|
|
def load_checkpoint(filename,
                    model,
                    optimizer: Optimizer = None,
                    lr_scheduler: _LRScheduler = None):
    """Restore model (and optionally optimizer / lr_scheduler) state from a file.

    Args:
        filename: Path of the checkpoint file; loaded with ``torch.load`` onto
            the CPU.
        model: Object whose ``load_state_dict`` receives the weights. The
            weights are ``checkpoint['state_dict']`` when that key exists,
            otherwise the whole loaded object.
        optimizer: An ``Optimizer`` or a dict of name -> optimizer. Restored
            from ``checkpoint['optimizer']``; a warning is logged for every
            state that cannot be found.
        lr_scheduler: Scheduler restored from ``checkpoint['lr_scheduler']``;
            a warning is logged when the key is absent.

    Returns:
        The ``meta`` entry of the checkpoint, or an empty dict if there is
        none.

    Raises:
        ValueError: If ``filename`` does not exist.
    """
    if not os.path.exists(filename):
        raise ValueError(f'Checkpoint file {filename} does not exist!')
    ckpt = torch.load(filename, map_location='cpu')

    if optimizer is not None:
        if 'optimizer' not in ckpt:
            logger.warning(
                f'The state dict of optimizer cannot be found in checkpoint file: {filename}'
            )
        elif isinstance(optimizer, Optimizer):
            optimizer.load_state_dict(ckpt['optimizer'])
        elif isinstance(optimizer, dict):
            saved_states = ckpt['optimizer']
            for key, optimizer_ins in optimizer.items():
                if key not in saved_states:
                    logger.warning(
                        f'The state dict of optimizer {key} cannot be found in checkpoint file: {filename}'
                    )
                    continue
                optimizer_ins.load_state_dict(saved_states[key])

    if lr_scheduler is not None:
        if 'lr_scheduler' not in ckpt:
            logger.warning(
                f'The state dict of lr_scheduler cannot be found in checkpoint file: {filename}'
            )
        else:
            lr_scheduler.load_state_dict(ckpt['lr_scheduler'])

    # A checkpoint is either a wrapper dict with a 'state_dict' entry or the
    # bare state dict itself.
    weights = ckpt['state_dict'] if 'state_dict' in ckpt else ckpt
    model.load_state_dict(weights)
    return ckpt.get('meta', {})
|
|
|
|
|
|
def load_task_model_checkpoint(model_to_load,
                               model_local_dir,
                               default_dtype=None,
                               load_state_fn=None,
                               **kwargs):
    """
    Load model checkpoint file and feed the parameters into the model.

    The checkpoint is read from ``ModelFile.TORCH_MODEL_BIN_FILE`` inside
    ``model_local_dir``. Key names are reconciled with the model using its
    ``base_model_prefix`` and ``head_prefix`` attributes, weights whose shape
    differs from the model's are dropped from the state dict, and modules
    with missing weights are initialized through ``model._init_weights``.

    Args:
        model_to_load: The model to be load
        model_local_dir: The actual checkpoint dir on local disk.
        default_dtype: Set the default float type by 'torch.set_default_dtype'
        load_state_fn: An optional load_state_fn used to load state_dict into the model.
            It is called as ``load_state_fn(model, state_dict, prefix=...,
            head_prefix_keys=..., local_metadata=None, error_msgs=...)``.
        **kwargs: ``_keys_to_ignore_on_load_missing`` and
            ``_keys_to_ignore_on_load_unexpected`` (iterables of regex
            patterns) are consumed to filter the reported missing /
            unexpected keys.

    Returns:
        dict: With the keys ``model`` (the ``model_to_load`` argument),
            ``missing_keys``, ``unexpected_keys``, ``mismatched_keys`` (tuples
            of key, checkpoint shape, model shape) and ``error_msgs``.

    Raises:
        ValueError: If the checkpoint contains keys that belong to the
            non-base part of the model while the base model is being loaded
            without its prefix.
        RuntimeError: If loading the state dict produced error messages.
    """

    def _add_head_prefix_to_state_dict(state_dicts, head_prefix,
                                       expected_keys_without_head_prefix,
                                       missing_keys):
        # Rename checkpoint keys the model expects under ``head_prefix`` so
        # they match the model's naming; every other key is kept unchanged.
        # Note: ``expected_keys_without_head_prefix`` is consumed in place.
        new_state_dict = OrderedDict()
        for name, module in state_dicts.items():
            if name in expected_keys_without_head_prefix:
                name_with_head = '.'.join([head_prefix, name])
                new_state_dict[name_with_head] = module
                expected_keys_without_head_prefix.remove(name)
                missing_keys = list(set(missing_keys) - set([name_with_head]))
            else:
                new_state_dict[name] = module

        missing_head_keys = []
        if len(expected_keys_without_head_prefix) > 0:
            missing_head_keys = expected_keys_without_head_prefix.copy()
        return new_state_dict, missing_head_keys, missing_keys

    def _find_mismatched_keys(
        state_dicts,
        model_state_dict,
        loaded_keys,
        prefix,
        add_prefix_to_model,
        remove_prefix_from_model,
        ignore_mismatched_sizes,
    ):
        # Collect (key, checkpoint shape, model shape) for every weight whose
        # shape differs from the model's, and delete those weights from
        # ``state_dicts`` so they are not loaded.
        mismatched_key = []
        if ignore_mismatched_sizes:
            for checkpoint_key in loaded_keys:
                model_key = checkpoint_key
                if remove_prefix_from_model:
                    # The model key starts with `prefix` but `checkpoint_key` doesn't, so we add it.
                    model_key = f'{prefix}.{checkpoint_key}'
                elif add_prefix_to_model:
                    # The model key doesn't start with `prefix` but `checkpoint_key` does, so we remove it.
                    model_key = '.'.join(checkpoint_key.split('.')[1:])

                if model_key in model_state_dict:
                    model_shape = model_state_dict[model_key].shape
                    checkpoint_shape = state_dicts[checkpoint_key].shape
                    if checkpoint_shape != model_shape:
                        mismatched_key.append(
                            (checkpoint_key, state_dicts[checkpoint_key].shape,
                             model_state_dict[model_key].shape))
                        del state_dicts[checkpoint_key]
        return mismatched_key

    def _load_state_dict_into_model(
        model,
        state_dict,
        start_prefix,
        head_prefix_keys,
        load_state_fn=None,
    ):
        # Convert old format to new format if needed from a PyTorch state_dict
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        error_msgs = []

        if load_state_fn is not None:
            load_state_fn(model,
                          state_dict,
                          prefix=start_prefix,
                          head_prefix_keys=head_prefix_keys,
                          local_metadata=None,
                          error_msgs=error_msgs)
        else:

            def load(module: nn.Module, prefix=''):
                # Recursively feed ``state_dict`` into ``module`` and all of
                # its children, extending ``prefix`` with each child's name.
                local_metadata = {} if metadata is None else metadata.get(
                    prefix[:-1], {})
                args = (state_dict, prefix, local_metadata, True, [], [],
                        error_msgs)
                module._load_from_state_dict(*args)
                for name, child in module._modules.items():
                    if child is not None:
                        load(child, prefix + name + '.')

            load(model, prefix=start_prefix)

        return error_msgs

    def _load_checkpoint(
        model,
        state_dict,
        load_state_fn,
        ignore_mismatched_sizes,
        _fast_init,
    ):
        # Retrieve missing & unexpected_keys
        model_state_dict = model.state_dict()
        expected_keys = list(model_state_dict.keys())
        keys_from_pretrained = list(state_dict.keys())

        prefix = model.base_model_prefix

        # during loading stage, base model prefix is complicated, should consider remove or add
        if len(prefix) > 0:
            # nlp: encoder, decoder
            pretrained_has_prefix_module = any(
                s.startswith(prefix) for s in keys_from_pretrained)
            model_expects_prefix_module = any(
                s.startswith(prefix) for s in expected_keys)
        else:
            # nlp:encoder-decoder, cv:backbone-head,
            pretrained_has_prefix_module = False
            model_expects_prefix_module = False

        remove_prefix_from_model = not pretrained_has_prefix_module and model_expects_prefix_module
        add_prefix_to_model = pretrained_has_prefix_module and not model_expects_prefix_module

        # Always bound, so the corruption check further down cannot hit a
        # NameError when neither prefix adjustment applies.
        expected_keys_not_base_model_prefixed = []
        if remove_prefix_from_model:
            expected_keys_not_base_model_prefixed = [
                s for s in expected_keys if not s.startswith(prefix)
            ]
            expected_keys = [
                '.'.join(s.split('.')[1:]) if s.startswith(prefix) else s
                for s in expected_keys
            ]
        elif add_prefix_to_model:
            # backbone only
            expected_keys = ['.'.join([prefix, s]) for s in expected_keys]

        missing_keys = list(set(expected_keys) - set(keys_from_pretrained))
        unexpected_keys = list(set(keys_from_pretrained) - set(expected_keys))

        # during loading stage head prefix is simple, add or not add
        prefix_heads = model.head_prefix
        # Normalize a single head name to a list once, so every later loop
        # iterates head names and never the characters of a string.
        if isinstance(prefix_heads, str):
            prefix_heads = [prefix_heads] if prefix_heads else []
        expected_head_keys_without_head_prefix = []
        unexpected_head_keys = []
        pretrained_has_prefix_head = dict()
        head_prefix_keys = dict()

        # only for case of head mismatched with state-dict
        if len(prefix_heads) > 0 and len(unexpected_keys) > 0:
            # to double-check if head matched with state-dict
            for prefix_head in prefix_heads:
                pretrained_has_prefix_head[prefix_head] = any(
                    s.startswith(prefix_head) for s in keys_from_pretrained)

            for prefix_head in prefix_heads:
                expected_keys_without_head_prefix = [
                    '.'.join(s.split('.')[1:]) for s in expected_keys
                    if s.startswith(prefix_head)
                ]
                expected_head_keys_without_head_prefix.extend(
                    expected_keys_without_head_prefix)
                head_prefix_keys[
                    prefix_head] = expected_keys_without_head_prefix
            unexpected_head_keys = list(
                set(unexpected_keys) -
                set(expected_head_keys_without_head_prefix))
            unexpected_keys = list(
                set(unexpected_keys) -
                set(expected_head_keys_without_head_prefix))

        _keys_to_ignore_on_load_missing = kwargs.pop(
            '_keys_to_ignore_on_load_missing', None)
        _keys_to_ignore_on_load_unexpected = kwargs.pop(
            '_keys_to_ignore_on_load_unexpected', None)
        # Some models may have keys that are not in the state by design, removing them before needlessly warning
        # the user.
        if _keys_to_ignore_on_load_missing is not None:
            for pat in _keys_to_ignore_on_load_missing:
                missing_keys = [
                    k for k in missing_keys if re.search(pat, k) is None
                ]

        if _keys_to_ignore_on_load_unexpected is not None:
            for pat in _keys_to_ignore_on_load_unexpected:
                unexpected_keys = [
                    k for k in unexpected_keys if re.search(pat, k) is None
                ]

        # retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights.
        if _fast_init:
            uninitialized_modules = retrieve_modules_from_names(
                model,
                missing_keys,
                prefix=prefix,
                add_prefix=add_prefix_to_model,
                remove_prefix=remove_prefix_from_model)
            for module in uninitialized_modules:
                model._init_weights(module)

        # Make sure we are able to load head correctly by revise state-dict
        for head_prefix in head_prefix_keys:
            if not pretrained_has_prefix_head[head_prefix]:
                state_dict, _, missing_keys = _add_head_prefix_to_state_dict(
                    state_dict, head_prefix, head_prefix_keys[head_prefix],
                    missing_keys)

        # Make sure we are able to load base models as well as derived models (with heads)
        start_prefix = ''
        model_to_load = model
        heads_to_load = dict()
        if len(model.base_model_prefix) > 0 and not hasattr(
                model,
                model.base_model_prefix) and pretrained_has_prefix_module:
            start_prefix = model.base_model_prefix + '.'
        if len(model.base_model_prefix) > 0 and hasattr(
                model,
                model.base_model_prefix) and not pretrained_has_prefix_module:
            model_to_load = getattr(model, model.base_model_prefix)
            for head_prefix in prefix_heads:
                heads_to_load[head_prefix] = getattr(model, head_prefix)
            if any(key in expected_keys_not_base_model_prefixed
                   for key in keys_from_pretrained):
                raise ValueError(
                    'The state dictionary of the model you are trying to load is corrupted. Are you sure it was '
                    'properly saved?')

        # Whole checkpoint
        mismatched_keys = _find_mismatched_keys(
            state_dict,
            model_state_dict,
            keys_from_pretrained,
            prefix,
            add_prefix_to_model,
            remove_prefix_from_model,
            ignore_mismatched_sizes,
        )
        # ``head_prefix_keys`` and ``load_state_fn`` must each land in their
        # own parameter; passing ``load_state_fn`` as the fourth positional
        # argument would bind it to ``head_prefix_keys`` and leave the custom
        # loader unused.
        error_msgs = _load_state_dict_into_model(
            model_to_load,
            state_dict,
            start_prefix,
            head_prefix_keys,
            load_state_fn=load_state_fn)

        if len(heads_to_load) > 0:
            for head in heads_to_load:
                local_error_msgs = _load_state_dict_into_model(
                    heads_to_load[head],
                    state_dict,
                    head + '.',
                    head_prefix_keys,
                    load_state_fn=load_state_fn)
                error_msgs.extend(local_error_msgs)

        if len(error_msgs) > 0:
            error_msg = '\n\t'.join(error_msgs)
            raise RuntimeError(
                f'Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}'
            )

        if len(unexpected_keys) > 0:
            logger.warning(
                f'Some weights of the model checkpoint were not used when'
                f' initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are'
                f' initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or'
                ' with another architecture (e.g. initializing a BertForTokenClassification model from a'
                ' BertForPreTraining model).\n- This IS NOT expected if you are initializing'
                f' {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical'
                ' (initializing a BertForTokenClassification model from a BertForTokenClassification model).'
            )
        elif len(unexpected_head_keys) > 0:
            logger.warning(
                f'Some weights of the model checkpoint were not used when'
                f' initializing {model.__class__.__name__}: {unexpected_head_keys}\n- This IS Not expected if you are'
                f' initializing {model.__class__.__name__} from the checkpoint of a model with a same task while the'
                ' structure is different (e.g. initializing a BertForTokenClassification model from a'
                ' BertForTokenClassification model).')
        else:
            logger.info(
                f'All model checkpoint weights were used when initializing {model.__class__.__name__}.\n'
            )
        if len(missing_keys) > 0:
            logger.warning(
                f'Some weights of {model.__class__.__name__} were not initialized from the model checkpoint'
                f' and are newly initialized: {missing_keys}\nYou should probably'
                ' TRAIN this model on a down-stream task to be able to use it for predictions and inference.'
            )
        elif len(mismatched_keys) == 0:
            logger.info(
                f'All the weights of {model.__class__.__name__} were initialized from the model checkpoint '
                f'If your task is similar to the task the model of the checkpoint'
                f' was trained on, you can already use {model.__class__.__name__} for predictions without further'
                ' training.')
        if len(mismatched_keys) > 0:
            mismatched_warning = '\n'.join([
                f'- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated'
                for key, shape1, shape2 in mismatched_keys
            ])
            logger.warning(
                f'Some weights of {model.__class__.__name__} were not initialized from the model checkpoint'
                f' and are newly initialized because the shapes did not'
                f' match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able'
                ' to use it for predictions and inference.')

        return missing_keys, unexpected_keys, mismatched_keys, error_msgs

    def retrieve_modules_from_names(model,
                                    names,
                                    prefix=None,
                                    add_prefix=False,
                                    remove_prefix=False):
        # Map weight names to the names of the modules that own them by
        # dropping the last dotted component.
        module_keys = set(['.'.join(key.split('.')[:-1]) for key in names])

        # torch.nn.ParameterList is a special case where two parameter keywords
        # are appended to the module name, *e.g.* bert.special_embeddings.0
        module_keys = module_keys.union(
            set([
                '.'.join(key.split('.')[:-2]) for key in names
                if key[-1].isdigit()
            ]))

        retrieved_modules = []
        # retrieve all modules that has at least one missing weight name
        for name, module in model.named_modules():
            if remove_prefix:
                name = '.'.join(
                    name.split('.')[1:]) if name.startswith(prefix) else name
            elif add_prefix:
                name = '.'.join([prefix, name]) if len(name) > 0 else prefix

            if name in module_keys:
                retrieved_modules.append(module)

        return retrieved_modules

    # TODO Sharded ckpt
    ckpt_file = os.path.join(model_local_dir, ModelFile.TORCH_MODEL_BIN_FILE)
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    state_dict = torch.load(ckpt_file, map_location='cpu')
    if default_dtype is not None:
        # NOTE: this changes the process-wide default dtype and is not
        # restored afterwards.
        torch.set_default_dtype(default_dtype)

    missing_keys, unexpected_keys, mismatched_keys, error_msgs = _load_checkpoint(
        model_to_load,
        state_dict,
        load_state_fn=load_state_fn,
        ignore_mismatched_sizes=True,
        _fast_init=True,
    )

    return {
        'model': model_to_load,
        'missing_keys': missing_keys,
        'unexpected_keys': unexpected_keys,
        'mismatched_keys': mismatched_keys,
        'error_msgs': error_msgs,
    }
|
|
|
|
|
|
def save_configuration(target_folder, config: Dict):
    """Write ``config`` as JSON to ``ModelFile.CONFIGURATION`` in ``target_folder``.

    Args:
        target_folder: Directory the configuration file is written into.
        config (Dict): Configuration to save; a ``Config`` instance is
            converted with ``to_dict()`` first. When it has no
            ``ConfigFields.pipeline`` entry, one is derived from its
            ``ConfigFields.task`` entry as ``{'type': <task>}``. The mapping
            passed in is not modified.

    Raises:
        KeyError: If both the pipeline and the task entries are missing.
    """
    from modelscope.utils.config import Config
    if isinstance(config, Config):
        config = config.to_dict()
    if ConfigFields.pipeline not in config:
        # Build a new dict instead of inserting into the caller's object.
        config = {
            **config,
            ConfigFields.pipeline: {
                'type': config[ConfigFields.task]
            },
        }
    cfg_str = json.dumps(config, indent=4, cls=JSONIteratorEncoder)
    config_file = os.path.join(target_folder, ModelFile.CONFIGURATION)
    storage.write(cfg_str.encode(), config_file)
|
|
|
|
|
|
def save_pretrained(model,
                    target_folder: Union[str, os.PathLike],
                    save_checkpoint_name: str = None,
                    save_function: Callable = None,
                    **kwargs):
    """save the pretrained model, its configuration and other related files to a directory, so that it can be re-loaded

    Any existing ``target_folder`` is deleted first. When the model has a
    non-None ``model_dir``, that directory is copied into ``target_folder``,
    skipping the checkpoint file, the configuration file and dot-files. The
    checkpoint itself is then written by ``save_function``.

    Args:
        model (Model): Model whose params are to be saved.

        target_folder (Union[str, os.PathLike]):
            Directory to which to save. Will be created if it doesn't exist.

        save_checkpoint_name (str):
            The checkpoint name to be saved in the target_folder

        save_function (Callable):
            The function to use to save the state dictionary. Called as
            ``save_function(model, checkpoint_path, **kwargs)``.

        **kwargs: Forwarded unchanged to ``save_function``.

    Raises:
        ValueError: If ``target_folder`` is None or an existing file.
        Exception: If ``save_function`` is not callable, if
            ``save_checkpoint_name`` is None, or if ``save_function`` raises
            (the original error is attached as ``__cause__``).
    """

    if not callable(save_function):
        raise Exception('A valid save function must be passed in')

    if target_folder is None or os.path.isfile(target_folder):
        raise ValueError(
            f'Provided path ({target_folder}) should be a directory, not a file'
        )

    if save_checkpoint_name is None:
        raise Exception(
            'At least pass in one checkpoint name for saving method')

    # Clean the folder from a previous save
    if os.path.exists(target_folder):
        rmtree(target_folder)

    # Single ckpt path, sharded ckpt logic will be added later
    output_ckpt_path = os.path.join(target_folder, save_checkpoint_name)

    model_dir = getattr(model, 'model_dir', None)
    if model_dir is not None:
        # Save the files to be copied to the save directory, ignore the original ckpts and configuration
        ignore_file_set = {save_checkpoint_name, ModelFile.CONFIGURATION, '.*'}
        copytree(model_dir,
                 target_folder,
                 ignore=ignore_patterns(*ignore_file_set))

    # copytree only creates the folder when there is a model_dir to copy;
    # make sure it exists in every case before the checkpoint is written.
    os.makedirs(target_folder, exist_ok=True)

    # Save the ckpt to the save directory
    try:
        save_function(model, output_ckpt_path, **kwargs)
    except Exception as e:
        raise Exception(
            f'During saving checkpoints, the error of "{type(e).__name__} '
            f'with msg {e} thrown') from e
|