commit 0c6ba50293
parent ac456e6c88
Author: GuoxiaWang
Date:   2021-10-15 23:18:22 +08:00

2 changed files with 95 additions and 93 deletions

View File

@@ -161,58 +161,59 @@ class Checkpoint(object):
             else:
                 opt_state_dict[name] = tensor

-        meta_file = os.path.join(checkpoint_dir, 'meta.json')
-        if not os.path.exists(meta_file):
-            logging.error(
-                "Please make sure the checkpoint dir {} exists, and "
-                "parameters in that dir are valid.".format(
-                    checkpoint_dir))
-            exit()
-        with open(meta_file, 'r') as handle:
-            extra_info = json.load(handle)
-        # Preprocess distributed parameters.
-        pretrain_world_size = extra_info['pretrain_world_size']
-        assert pretrain_world_size > 0
-        embedding_size = extra_info['embedding_size']
-        assert embedding_size == self.embedding_size
-        num_classes = extra_info['num_classes']
-        assert num_classes == self.num_classes
-        logging.info(
-            "Parameters for pre-training: pretrain_world_size ({}), "
-            "embedding_size ({}), and num_classes ({}).".format(
-                pretrain_world_size, embedding_size, num_classes))
-        logging.info("Parameters for inference or fine-tuning: "
-                     "world_size ({}).".format(self.world_size))
-        rank_str = '%05d' % self.rank
-        dist_weight_state_dict = rearrange_weight(
-            dist_weight_state_dict, pretrain_world_size, self.world_size)
-        dist_bias_state_dict = rearrange_weight(
-            dist_bias_state_dict, pretrain_world_size, self.world_size)
-        for name, value in dist_weight_state_dict.items():
-            if rank_str in name:
-                dist_param_state_dict[name] = value
-        for name, value in dist_bias_state_dict.items():
-            if rank_str in name:
-                dist_param_state_dict[name] = value
-        if for_train:
-            dist_weight_velocity_state_dict = rearrange_weight(
-                dist_weight_velocity_state_dict, pretrain_world_size,
-                self.world_size)
-            dist_bias_velocity_state_dict = rearrange_weight(
-                dist_bias_velocity_state_dict, pretrain_world_size,
-                self.world_size)
-            for name, value in dist_weight_velocity_state_dict.items():
-                if rank_str in name:
-                    opt_state_dict[name] = value
-            for name, value in dist_bias_velocity_state_dict.items():
-                if rank_str in name:
-                    opt_state_dict[name] = value
+        if classifier is not None and for_train:
+            meta_file = os.path.join(checkpoint_dir, 'meta.json')
+            if not os.path.exists(meta_file):
+                logging.error(
+                    "Please make sure the checkpoint dir {} exists, and "
+                    "parameters in that dir are valid.".format(
+                        checkpoint_dir))
+                exit()
+            with open(meta_file, 'r') as handle:
+                extra_info = json.load(handle)
+            # Preprocess distributed parameters.
+            pretrain_world_size = extra_info['pretrain_world_size']
+            assert pretrain_world_size > 0
+            embedding_size = extra_info['embedding_size']
+            assert embedding_size == self.embedding_size
+            num_classes = extra_info['num_classes']
+            assert num_classes == self.num_classes
+            logging.info(
+                "Parameters for pre-training: pretrain_world_size ({}), "
+                "embedding_size ({}), and num_classes ({}).".format(
+                    pretrain_world_size, embedding_size, num_classes))
+            logging.info("Parameters for inference or fine-tuning: "
+                         "world_size ({}).".format(self.world_size))
+            rank_str = '%05d' % self.rank
+            dist_weight_state_dict = rearrange_weight(
+                dist_weight_state_dict, pretrain_world_size, self.world_size)
+            dist_bias_state_dict = rearrange_weight(
+                dist_bias_state_dict, pretrain_world_size, self.world_size)
+            for name, value in dist_weight_state_dict.items():
+                if rank_str in name:
+                    dist_param_state_dict[name] = value
+            for name, value in dist_bias_state_dict.items():
+                if rank_str in name:
+                    dist_param_state_dict[name] = value
+            if for_train:
+                dist_weight_velocity_state_dict = rearrange_weight(
+                    dist_weight_velocity_state_dict, pretrain_world_size,
+                    self.world_size)
+                dist_bias_velocity_state_dict = rearrange_weight(
+                    dist_bias_velocity_state_dict, pretrain_world_size,
+                    self.world_size)
+                for name, value in dist_weight_velocity_state_dict.items():
+                    if rank_str in name:
+                        opt_state_dict[name] = value
+                for name, value in dist_bias_velocity_state_dict.items():
+                    if rank_str in name:
+                        opt_state_dict[name] = value

         def map_actual_param_name(state_dict, load_state_dict):
             for name, param in state_dict.items():
@@ -231,7 +232,7 @@ class Checkpoint(object):
         assert optimizer is not None
         optimizer.set_state_dict(opt_state_dict)

-        if for_train:
+        if classifier is not None and for_train:
             return extra_info
         else:
             return {}
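The loading path above relies on rearrange_weight to re-shard the distributed FC weight, bias, and their optimizer velocities from the pre-training parallel degree (pretrain_world_size) to the current self.world_size, after which the '%05d' % self.rank suffix picks out the shard owned by the local rank. rearrange_weight itself is not part of this commit; the snippet below is only a rough sketch of that kind of re-sharding, assuming shard keys embed a zero-padded 5-digit rank id and that shards are split along the last (class) axis, neither of which is confirmed by the diff.

# Hypothetical sketch (not the repository's implementation) of merging shards
# saved with `pretrain_world_size` ranks and re-splitting them for `world_size`.
import re
import numpy as np

def rearrange_weight_sketch(shard_dict, pretrain_world_size, world_size):
    # Group shards by their base key, i.e. the name with the rank id masked out.
    grouped = {}
    for name, value in shard_dict.items():
        match = re.search(r'\d{5}', name)
        assert match is not None, "expected a zero-padded rank id in the key"
        rank = int(match.group(0))
        base = name.replace(match.group(0), '{rank:05d}', 1)
        grouped.setdefault(base, {})[rank] = np.asarray(value)

    rearranged = {}
    for base, shards in grouped.items():
        # Rebuild the full tensor in rank order along the class axis ...
        full = np.concatenate(
            [shards[r] for r in range(pretrain_world_size)], axis=-1)
        # ... then split it for the new parallel degree and rename the pieces
        # so that the `'%05d' % rank in name` filter above selects the right one.
        for new_rank, piece in enumerate(
                np.array_split(full, world_size, axis=-1)):
            rearranged[base.format(rank=new_rank)] = piece
    return rearranged

np.array_split is used so the sketch tolerates class counts that do not divide evenly across ranks; the actual implementation may partition or pad differently.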

View File

@@ -135,59 +135,60 @@ class Checkpoint(object):
            else:
                state_dict[name] = tensor

        meta_file = os.path.join(checkpoint_dir, 'meta.json')
        if not os.path.exists(meta_file):
            logging.error(
                "Please make sure the checkpoint dir {} exists, and "
                "parameters in that dir are valid.".format(
                    checkpoint_dir))
            exit()
        with open(meta_file, 'r') as handle:
            extra_info = json.load(handle)

        # Preprocess distributed parameters.
        pretrain_world_size = extra_info['pretrain_world_size']
        assert pretrain_world_size > 0
        embedding_size = extra_info['embedding_size']
        assert embedding_size == self.embedding_size
        num_classes = extra_info['num_classes']
        assert num_classes == self.num_classes
        logging.info(
            "Parameters for pre-training: pretrain_world_size ({}), "
            "embedding_size ({}), and num_classes ({}).".format(
                pretrain_world_size, embedding_size, num_classes))
        logging.info("Parameters for inference or fine-tuning: "
                     "world_size ({}).".format(self.world_size))

        rank_str = '%05d' % self.rank
        dist_weight_state_dict = rearrange_weight(
            dist_weight_state_dict, pretrain_world_size, self.world_size)
        dist_bias_state_dict = rearrange_weight(
            dist_bias_state_dict, pretrain_world_size, self.world_size)
        for name, value in dist_weight_state_dict.items():
            if rank_str in name:
                state_dict[name] = value
        for name, value in dist_bias_state_dict.items():
            if rank_str in name:
                state_dict[name] = value

        if for_train:
            dist_weight_velocity_state_dict = rearrange_weight(
                dist_weight_velocity_state_dict, pretrain_world_size,
                self.world_size)
            dist_bias_velocity_state_dict = rearrange_weight(
                dist_bias_velocity_state_dict, pretrain_world_size,
                self.world_size)
            for name, value in dist_weight_velocity_state_dict.items():
                if rank_str in name:
                    state_dict[name] = value
            for name, value in dist_bias_velocity_state_dict.items():
                if rank_str in name:
                    state_dict[name] = value

        program.set_state_dict(state_dict)
        logging.info("Load checkpoint from '{}'. ".format(checkpoint_dir))
        if for_train:
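Both versions of the loader read the same three fields from meta.json and assert that embedding_size and num_classes match the current model before re-sharding. For reference, a minimal meta.json satisfying those reads could be produced as below; only the key names come from the code above, and the concrete values are placeholders rather than values from any released checkpoint.

# Illustrative meta.json writer; values are placeholders.
import json

extra_info = {
    "pretrain_world_size": 8,   # number of parallel ranks used at save time
    "embedding_size": 512,      # must equal self.embedding_size when loading
    "num_classes": 93431,       # must equal self.num_classes when loading
}

with open("meta.json", "w") as handle:
    json.dump(extra_info, handle)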