Mirror of https://github.com/deepinsight/insightface.git (synced 2025-12-30 08:02:27 +00:00)

Commit: fix bug

@@ -161,58 +161,59 @@ class Checkpoint(object):
             else:
                 opt_state_dict[name] = tensor

-        meta_file = os.path.join(checkpoint_dir, 'meta.json')
-        if not os.path.exists(meta_file):
-            logging.error(
-                "Please make sure the checkpoint dir {} exists, and "
-                "parameters in that dir are validating.".format(
-                    checkpoint_dir))
-            exit()
+        if classifier is not None and for_train:
+            meta_file = os.path.join(checkpoint_dir, 'meta.json')
+            if not os.path.exists(meta_file):
+                logging.error(
+                    "Please make sure the checkpoint dir {} exists, and "
+                    "parameters in that dir are validating.".format(
+                        checkpoint_dir))
+                exit()

-        with open(meta_file, 'r') as handle:
-            extra_info = json.load(handle)
+            with open(meta_file, 'r') as handle:
+                extra_info = json.load(handle)

-        # Preporcess distributed parameters.
-        pretrain_world_size = extra_info['pretrain_world_size']
-        assert pretrain_world_size > 0
-        embedding_size = extra_info['embedding_size']
-        assert embedding_size == self.embedding_size
-        num_classes = extra_info['num_classes']
-        assert num_classes == self.num_classes
+            # Preporcess distributed parameters.
+            pretrain_world_size = extra_info['pretrain_world_size']
+            assert pretrain_world_size > 0
+            embedding_size = extra_info['embedding_size']
+            assert embedding_size == self.embedding_size
+            num_classes = extra_info['num_classes']
+            assert num_classes == self.num_classes

-        logging.info(
-            "Parameters for pre-training: pretrain_world_size ({}), "
-            "embedding_size ({}), and num_classes ({}).".format(
-                pretrain_world_size, embedding_size, num_classes))
-        logging.info("Parameters for inference or fine-tuning: "
-                     "world_size ({}).".format(self.world_size))
+            logging.info(
+                "Parameters for pre-training: pretrain_world_size ({}), "
+                "embedding_size ({}), and num_classes ({}).".format(
+                    pretrain_world_size, embedding_size, num_classes))
+            logging.info("Parameters for inference or fine-tuning: "
+                         "world_size ({}).".format(self.world_size))

-        rank_str = '%05d' % self.rank
+            rank_str = '%05d' % self.rank

-        dist_weight_state_dict = rearrange_weight(
-            dist_weight_state_dict, pretrain_world_size, self.world_size)
-        dist_bias_state_dict = rearrange_weight(
-            dist_bias_state_dict, pretrain_world_size, self.world_size)
-        for name, value in dist_weight_state_dict.items():
-            if rank_str in name:
-                dist_param_state_dict[name] = value
-        for name, value in dist_bias_state_dict.items():
-            if rank_str in name:
-                dist_param_state_dict[name] = value
+            dist_weight_state_dict = rearrange_weight(
+                dist_weight_state_dict, pretrain_world_size, self.world_size)
+            dist_bias_state_dict = rearrange_weight(
+                dist_bias_state_dict, pretrain_world_size, self.world_size)
+            for name, value in dist_weight_state_dict.items():
+                if rank_str in name:
+                    dist_param_state_dict[name] = value
+            for name, value in dist_bias_state_dict.items():
+                if rank_str in name:
+                    dist_param_state_dict[name] = value

-        if for_train:
-            dist_weight_velocity_state_dict = rearrange_weight(
-                dist_weight_velocity_state_dict, pretrain_world_size,
-                self.world_size)
-            dist_bias_velocity_state_dict = rearrange_weight(
-                dist_bias_velocity_state_dict, pretrain_world_size,
-                self.world_size)
-            for name, value in dist_weight_velocity_state_dict.items():
-                if rank_str in name:
-                    opt_state_dict[name] = value
-            for name, value in dist_bias_velocity_state_dict.items():
-                if rank_str in name:
-                    opt_state_dict[name] = value
+            if for_train:
+                dist_weight_velocity_state_dict = rearrange_weight(
+                    dist_weight_velocity_state_dict, pretrain_world_size,
+                    self.world_size)
+                dist_bias_velocity_state_dict = rearrange_weight(
+                    dist_bias_velocity_state_dict, pretrain_world_size,
+                    self.world_size)
+                for name, value in dist_weight_velocity_state_dict.items():
+                    if rank_str in name:
+                        opt_state_dict[name] = value
+                for name, value in dist_bias_velocity_state_dict.items():
+                    if rank_str in name:
+                        opt_state_dict[name] = value

         def map_actual_param_name(state_dict, load_state_dict):
             for name, param in state_dict.items():
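
For reference, this code path only reads three keys from the meta.json consumed in the hunk above (pretrain_world_size, embedding_size, num_classes); the real file may carry additional fields. A minimal illustrative file, with made-up values, could be written like this:

import json

# Illustrative values only; a real checkpoint records whatever configuration
# the pre-training job actually used.
meta = {
    'pretrain_world_size': 8,   # number of workers that saved the checkpoint
    'embedding_size': 512,      # must equal self.embedding_size at load time
    'num_classes': 93431,       # must equal self.num_classes at load time
}

with open('meta.json', 'w') as handle:
    json.dump(meta, handle)
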
@@ -231,7 +232,7 @@ class Checkpoint(object):
         assert optimizer is not None
         optimizer.set_state_dict(opt_state_dict)

-        if for_train:
+        if classifier is not None and for_train:
             return extra_info
         else:
             return {}
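
rearrange_weight() itself is not part of this diff. As a rough mental model, it has to take the per-rank shards of the distributed classifier that were saved with pretrain_world_size workers and re-split them for the current world_size. The sketch below only illustrates that idea; it assumes the weight is sharded by class along the first axis, and the made-up 'rank@NNNNN' key format stands in for the repository's real parameter names:

import numpy as np

def reshard_classifier_weight(shards, pretrain_world_size, world_size):
    # Rebuild the full classifier matrix from the shards saved at pre-training
    # time, then split it again for the current number of workers. Assumes the
    # class dimension divides world_size evenly.
    full = np.concatenate(
        [shards['rank@%05d' % r] for r in range(pretrain_world_size)], axis=0)
    new_shards = np.split(full, world_size, axis=0)
    return {'rank@%05d' % r: shard for r, shard in enumerate(new_shards)}

# Example: 8 pre-training shards of a 1000-class, 128-d classifier, re-split for 4 workers.
old_shards = {'rank@%05d' % r: np.zeros((125, 128)) for r in range(8)}
new_shards = reshard_classifier_weight(old_shards, pretrain_world_size=8, world_size=4)
assert new_shards['rank@00000'].shape == (250, 128)
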
@@ -135,59 +135,60 @@ class Checkpoint(object):
            else:
                state_dict[name] = tensor

        meta_file = os.path.join(checkpoint_dir, 'meta.json')
        if not os.path.exists(meta_file):
            logging.error(
                "Please make sure the checkpoint dir {} exists, and "
                "parameters in that dir are validating.".format(
                    checkpoint_dir))
            exit()

        with open(meta_file, 'r') as handle:
            extra_info = json.load(handle)

        # Preporcess distributed parameters.
        pretrain_world_size = extra_info['pretrain_world_size']
        assert pretrain_world_size > 0
        embedding_size = extra_info['embedding_size']
        assert embedding_size == self.embedding_size
        num_classes = extra_info['num_classes']
        assert num_classes == self.num_classes

        logging.info(
            "Parameters for pre-training: pretrain_world_size ({}), "
            "embedding_size ({}), and num_classes ({}).".format(
                pretrain_world_size, embedding_size, num_classes))
        logging.info("Parameters for inference or fine-tuning: "
                     "world_size ({}).".format(self.world_size))

        rank_str = '%05d' % self.rank

        dist_weight_state_dict = rearrange_weight(
            dist_weight_state_dict, pretrain_world_size, self.world_size)
        dist_bias_state_dict = rearrange_weight(
            dist_bias_state_dict, pretrain_world_size, self.world_size)
        for name, value in dist_weight_state_dict.items():
            if rank_str in name:
                state_dict[name] = value
        for name, value in dist_bias_state_dict.items():
            if rank_str in name:
                state_dict[name] = value

        if for_train:
            dist_weight_velocity_state_dict = rearrange_weight(
                dist_weight_velocity_state_dict, pretrain_world_size,
                self.world_size)
            dist_bias_velocity_state_dict = rearrange_weight(
                dist_bias_velocity_state_dict, pretrain_world_size,
                self.world_size)
            for name, value in dist_weight_velocity_state_dict.items():
                if rank_str in name:
                    state_dict[name] = value
            for name, value in dist_bias_velocity_state_dict.items():
                if rank_str in name:
                    state_dict[name] = value

        program.set_state_dict(state_dict)
        logging.info("Load checkpoint from '{}'. ".format(checkpoint_dir))
        if for_train:
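
Both versions of the loader keep only the re-sharded entries whose parameter name contains the current worker's zero-padded rank ('%05d' % self.rank). A standalone sketch of that filtering pattern, with made-up parameter names (the repository's real naming scheme may differ):

rank = 3
rank_str = '%05d' % rank  # -> '00003'

# Hypothetical re-sharded parameters keyed by the rank that owns them.
resharded = {
    'dist_fc.rank@00000.w_0': 'shard for rank 0',
    'dist_fc.rank@00003.w_0': 'shard for rank 3',
}

# Same pattern as the loops above: keep only this worker's entries.
local = {name: value for name, value in resharded.items() if rank_str in name}
print(local)  # {'dist_fc.rank@00003.w_0': 'shard for rank 3'}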