mirror of
https://gitcode.com/gh_mirrors/eas/EasyFace.git
synced 2026-04-23 12:40:15 +00:00
117 lines
4.4 KiB
Python
Executable File
117 lines
4.4 KiB
Python
Executable File
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
|
|
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import os
|
|
|
|
import torch
|
|
|
|
|
|
def load_checkpoint(model,
                    load_dir,
                    tag,
                    load_module_strict=True,
                    load_optimizer_states=True,
                    load_lr_scheduler_states=True):
    r"""Restore a training checkpoint into ``model``.

    Arguments:
        model: Required. Object the checkpoint is restored into. It is passed
            on to ``_load_checkpoint`` and must provide ``zero_optimization()``
            and ``_load_zero_checkpoint(...)`` as used below.
        load_dir: Required. Directory to load the checkpoint from.
        tag: Required. Checkpoint tag used as a unique identifier for the
            checkpoint. Ex. Global Step.
        load_module_strict: Optional. Boolean to strictly enforce that the
            keys in state_dict of module and checkpoint match.
        load_optimizer_states: Optional. Boolean to load the training
            optimizer states from Checkpoint. Ex. ADAM's momentum and variance.
        load_lr_scheduler_states: Optional. Boolean to add the learning rate
            scheduler states from Checkpoint.

    Return:
        load_path: Path of the loaded checkpoint. None if the checkpoint
            file was not found.
        client_state: State dictionary used for loading required training
            states in the client code.
    """
    restore_flags = {
        'load_module_strict': load_module_strict,
        'load_optimizer_states': load_optimizer_states,
        'load_lr_scheduler_states': load_lr_scheduler_states,
    }
    load_path, client_states = _load_checkpoint(model, load_dir, tag,
                                                **restore_flags)

    # ZeRO optimizer state is restored separately, and only when optimizer
    # states were requested and the main checkpoint was actually loaded.
    wants_zero_restore = (load_optimizer_states
                          and model.zero_optimization()
                          and load_path is not None)
    if wants_zero_restore:
        model._load_zero_checkpoint(
            load_dir, tag, load_optimizer_states=load_optimizer_states)

    return load_path, client_states
|
|
|
|
|
|
def _get_ckpt_name(mp_rank, checkpoints_path, tag):
    """Build the model-states checkpoint path for one model-parallel rank.

    The result is ``<checkpoints_path>/<tag>/mp_rank_XX_model_states.pt``,
    where ``XX`` is ``mp_rank`` zero-padded to at least two digits and
    ``tag`` is converted with ``str()``.
    """
    tag_dir = str(tag)
    file_name = 'mp_rank_{:02d}_model_states.pt'.format(mp_rank)
    return os.path.join(checkpoints_path, tag_dir, file_name)
|
|
|
|
|
|
def pre_load(mp_rank, load_dir, tag=''):
    """Return the ``'module'`` entry of one rank's model-states checkpoint.

    The file path is derived from ``mp_rank``, ``load_dir`` and ``tag`` via
    ``_get_ckpt_name`` and read with ``torch.load``.
    """

    def _keep_storage(storage, loc):
        # Hand each storage back unchanged instead of remapping it to the
        # device recorded in the file (per torch.load docs this leaves the
        # data where it was first deserialized).
        return storage

    ckpt_file = _get_ckpt_name(mp_rank, load_dir, tag)
    states = torch.load(ckpt_file, map_location=_keep_storage)
    return states['module']
|
|
|
|
|
|
def _load_checkpoint(model,
                     load_dir,
                     tag,
                     load_module_strict=True,
                     load_optimizer_states=True,
                     load_lr_scheduler_states=True):
    """Load the model-states checkpoint file and restore it into ``model``.

    The path comes from ``model._get_ckpt_name(load_dir, tag)``. If that file
    does not exist, nothing is touched and ``(None, None)`` is returned.

    Otherwise the module weights are loaded, then (when requested and ZeRO is
    not in use) the optimizer state, then (when requested and a scheduler is
    present) the LR scheduler state, and finally the bookkeeping counters
    stored in the checkpoint are copied onto ``model``.

    Return:
        load_path: Path of the file that was loaded, or None.
        client_state: Every checkpoint entry whose key is not one of the
            reserved engine keys listed below. ``'global_samples'`` is not in
            that list, so it is passed through when present.
    """
    ckpt_file = model._get_ckpt_name(load_dir, tag)
    if not os.path.exists(ckpt_file):
        return None, None

    state = torch.load(ckpt_file, map_location=lambda storage, loc: storage)

    model.load_module_state_dict(state_dict=state['module'],
                                 strict=load_module_strict)

    if not model.zero_optimization() and load_optimizer_states:
        # The fp16 optimizer's load_state_dict takes an extra keyword;
        # the plain optimizer is called with the state dict only.
        extra_kwargs = {}
        if model.fp16_enabled():
            extra_kwargs['load_optimizer_states'] = load_optimizer_states
        model.optimizer.load_state_dict(state['optimizer'], **extra_kwargs)

    scheduler_requested = load_lr_scheduler_states
    if scheduler_requested and model.lr_scheduler is not None:
        model.lr_scheduler.load_state_dict(state['lr_scheduler'])

    # Restore bookkeeping. global_steps must be set before global_samples,
    # whose fallback value is computed from it.
    model.csr_tensor_module_names = state['csr_tensor_module_names']
    model.global_steps = state['global_steps']
    fallback_samples = model.global_steps * model.train_batch_size()
    model.global_samples = state.get('global_samples', fallback_samples)
    model.skipped_steps = state['skipped_steps']
    model.loaded_checkpoint_mp_world_size = state['mp_world_size']
    model.loaded_checkpoint_dp_world_size = state['dp_world_size']

    # Keys consumed by the engine itself; anything else belongs to the client.
    reserved_keys = {
        'module', 'optimizer', 'lr_scheduler', 'csr_tensor_module_names',
        'skipped_steps', 'global_steps', 'dp_world_size', 'mp_world_size',
    }
    client_state = {}
    for name, payload in state.items():
        if name not in reserved_keys:
            client_state[name] = payload

    return ckpt_file, client_state
|