From 0c6ba502939fd80ad2bf0fb577a29d7ffaa78a76 Mon Sep 17 00:00:00 2001
From: GuoxiaWang
Date: Fri, 15 Oct 2021 23:18:22 +0800
Subject: [PATCH] fix bug in distributed classifier checkpoint loading

---
 .../arcface_paddle/dynamic/utils/io.py        | 93 +++++++++---------
 recognition/arcface_paddle/static/utils/io.py | 95 ++++++++++---------
 2 files changed, 95 insertions(+), 93 deletions(-)

diff --git a/recognition/arcface_paddle/dynamic/utils/io.py b/recognition/arcface_paddle/dynamic/utils/io.py
index 866b0c0..a30449c 100644
--- a/recognition/arcface_paddle/dynamic/utils/io.py
+++ b/recognition/arcface_paddle/dynamic/utils/io.py
@@ -161,58 +161,59 @@ class Checkpoint(object):
             else:
                 opt_state_dict[name] = tensor
 
-        meta_file = os.path.join(checkpoint_dir, 'meta.json')
-        if not os.path.exists(meta_file):
-            logging.error(
-                "Please make sure the checkpoint dir {} exists, and "
-                "parameters in that dir are validating.".format(
-                    checkpoint_dir))
-            exit()
+        if classifier is not None and for_train:
+            meta_file = os.path.join(checkpoint_dir, 'meta.json')
+            if not os.path.exists(meta_file):
+                logging.error(
+                    "Please make sure the checkpoint dir {} exists, and "
+                    "parameters in that dir are valid.".format(
+                        checkpoint_dir))
+                exit()
 
-        with open(meta_file, 'r') as handle:
-            extra_info = json.load(handle)
+            with open(meta_file, 'r') as handle:
+                extra_info = json.load(handle)
 
-        # Preporcess distributed parameters.
-        pretrain_world_size = extra_info['pretrain_world_size']
-        assert pretrain_world_size > 0
-        embedding_size = extra_info['embedding_size']
-        assert embedding_size == self.embedding_size
-        num_classes = extra_info['num_classes']
-        assert num_classes == self.num_classes
+            # Preprocess distributed parameters.
+            pretrain_world_size = extra_info['pretrain_world_size']
+            assert pretrain_world_size > 0
+            embedding_size = extra_info['embedding_size']
+            assert embedding_size == self.embedding_size
+            num_classes = extra_info['num_classes']
+            assert num_classes == self.num_classes
 
-        logging.info(
-            "Parameters for pre-training: pretrain_world_size ({}), "
-            "embedding_size ({}), and num_classes ({}).".format(
-                pretrain_world_size, embedding_size, num_classes))
-        logging.info("Parameters for inference or fine-tuning: "
-                     "world_size ({}).".format(self.world_size))
+            logging.info(
+                "Parameters for pre-training: pretrain_world_size ({}), "
+                "embedding_size ({}), and num_classes ({}).".format(
+                    pretrain_world_size, embedding_size, num_classes))
+            logging.info("Parameters for inference or fine-tuning: "
+                         "world_size ({}).".format(self.world_size))
 
-        rank_str = '%05d' % self.rank
+            rank_str = '%05d' % self.rank
 
-        dist_weight_state_dict = rearrange_weight(
-            dist_weight_state_dict, pretrain_world_size, self.world_size)
-        dist_bias_state_dict = rearrange_weight(
-            dist_bias_state_dict, pretrain_world_size, self.world_size)
-        for name, value in dist_weight_state_dict.items():
-            if rank_str in name:
-                dist_param_state_dict[name] = value
-        for name, value in dist_bias_state_dict.items():
-            if rank_str in name:
-                dist_param_state_dict[name] = value
-
-        if for_train:
-            dist_weight_velocity_state_dict = rearrange_weight(
-                dist_weight_velocity_state_dict, pretrain_world_size,
-                self.world_size)
-            dist_bias_velocity_state_dict = rearrange_weight(
-                dist_bias_velocity_state_dict, pretrain_world_size,
-                self.world_size)
-            for name, value in dist_weight_velocity_state_dict.items():
+            dist_weight_state_dict = rearrange_weight(
+                dist_weight_state_dict, pretrain_world_size, self.world_size)
+            dist_bias_state_dict = rearrange_weight(
+                dist_bias_state_dict, pretrain_world_size, self.world_size)
+            for name, value in dist_weight_state_dict.items():
                 if rank_str in name:
-                    opt_state_dict[name] = value
-            for name, value in dist_bias_velocity_state_dict.items():
+                    dist_param_state_dict[name] = value
+            for name, value in dist_bias_state_dict.items():
                 if rank_str in name:
-                    opt_state_dict[name] = value
+                    dist_param_state_dict[name] = value
+
+            if for_train:
+                dist_weight_velocity_state_dict = rearrange_weight(
+                    dist_weight_velocity_state_dict, pretrain_world_size,
+                    self.world_size)
+                dist_bias_velocity_state_dict = rearrange_weight(
+                    dist_bias_velocity_state_dict, pretrain_world_size,
+                    self.world_size)
+                for name, value in dist_weight_velocity_state_dict.items():
+                    if rank_str in name:
+                        opt_state_dict[name] = value
+                for name, value in dist_bias_velocity_state_dict.items():
+                    if rank_str in name:
+                        opt_state_dict[name] = value
 
         def map_actual_param_name(state_dict, load_state_dict):
             for name, param in state_dict.items():
@@ -231,7 +232,7 @@ class Checkpoint(object):
             assert optimizer is not None
             optimizer.set_state_dict(opt_state_dict)
 
-        if for_train:
+        if classifier is not None and for_train:
             return extra_info
         else:
             return {}
diff --git a/recognition/arcface_paddle/static/utils/io.py b/recognition/arcface_paddle/static/utils/io.py
index 1cef6f0..acc98ce 100644
--- a/recognition/arcface_paddle/static/utils/io.py
+++ b/recognition/arcface_paddle/static/utils/io.py
@@ -135,59 +135,60 @@ class Checkpoint(object):
             else:
                 state_dict[name] = tensor
 
-        meta_file = os.path.join(checkpoint_dir, 'meta.json')
-        if not os.path.exists(meta_file):
-            logging.error(
-                "Please make sure the checkpoint dir {} exists, and "
-                "parameters in that dir are validating.".format(
-                    checkpoint_dir))
-            exit()
-
-        with open(meta_file, 'r') as handle:
-            extra_info = json.load(handle)
-
-        # Preporcess distributed parameters.
-        pretrain_world_size = extra_info['pretrain_world_size']
-        assert pretrain_world_size > 0
-        embedding_size = extra_info['embedding_size']
-        assert embedding_size == self.embedding_size
-        num_classes = extra_info['num_classes']
-        assert num_classes == self.num_classes
-
-        logging.info(
-            "Parameters for pre-training: pretrain_world_size ({}), "
-            "embedding_size ({}), and num_classes ({}).".format(
-                pretrain_world_size, embedding_size, num_classes))
-        logging.info("Parameters for inference or fine-tuning: "
-                     "world_size ({}).".format(self.world_size))
-
-        rank_str = '%05d' % self.rank
-
-        dist_weight_state_dict = rearrange_weight(
-            dist_weight_state_dict, pretrain_world_size, self.world_size)
-        dist_bias_state_dict = rearrange_weight(
-            dist_bias_state_dict, pretrain_world_size, self.world_size)
-        for name, value in dist_weight_state_dict.items():
-            if rank_str in name:
-                state_dict[name] = value
-        for name, value in dist_bias_state_dict.items():
-            if rank_str in name:
-                state_dict[name] = value
-        if for_train:
-            dist_weight_velocity_state_dict = rearrange_weight(
-                dist_weight_velocity_state_dict, pretrain_world_size,
-                self.world_size)
-            dist_bias_velocity_state_dict = rearrange_weight(
-                dist_bias_velocity_state_dict, pretrain_world_size,
-                self.world_size)
-            for name, value in dist_weight_velocity_state_dict.items():
+            meta_file = os.path.join(checkpoint_dir, 'meta.json')
+            if not os.path.exists(meta_file):
+                logging.error(
+                    "Please make sure the checkpoint dir {} exists, and "
+                    "parameters in that dir are valid.".format(
+                        checkpoint_dir))
+                exit()
+
+            with open(meta_file, 'r') as handle:
+                extra_info = json.load(handle)
+
+            # Preprocess distributed parameters.
+            pretrain_world_size = extra_info['pretrain_world_size']
+            assert pretrain_world_size > 0
+            embedding_size = extra_info['embedding_size']
+            assert embedding_size == self.embedding_size
+            num_classes = extra_info['num_classes']
+            assert num_classes == self.num_classes
+
+            logging.info(
+                "Parameters for pre-training: pretrain_world_size ({}), "
+                "embedding_size ({}), and num_classes ({}).".format(
+                    pretrain_world_size, embedding_size, num_classes))
+            logging.info("Parameters for inference or fine-tuning: "
+                         "world_size ({}).".format(self.world_size))
+
+            rank_str = '%05d' % self.rank
+
+            dist_weight_state_dict = rearrange_weight(
+                dist_weight_state_dict, pretrain_world_size, self.world_size)
+            dist_bias_state_dict = rearrange_weight(
+                dist_bias_state_dict, pretrain_world_size, self.world_size)
+            for name, value in dist_weight_state_dict.items():
                 if rank_str in name:
                     state_dict[name] = value
-            for name, value in dist_bias_velocity_state_dict.items():
+            for name, value in dist_bias_state_dict.items():
                 if rank_str in name:
                     state_dict[name] = value
+            if for_train:
+                dist_weight_velocity_state_dict = rearrange_weight(
+                    dist_weight_velocity_state_dict, pretrain_world_size,
+                    self.world_size)
+                dist_bias_velocity_state_dict = rearrange_weight(
+                    dist_bias_velocity_state_dict, pretrain_world_size,
+                    self.world_size)
+                for name, value in dist_weight_velocity_state_dict.items():
+                    if rank_str in name:
+                        state_dict[name] = value
+                for name, value in dist_bias_velocity_state_dict.items():
+                    if rank_str in name:
+                        state_dict[name] = value
+
         program.set_state_dict(state_dict)
         logging.info("Load checkpoint from '{}'. ".format(checkpoint_dir))
         if for_train:
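
Both hunks make the same change: reading meta.json and re-sharding the distributed (model-parallel) classifier state now happen only when a classifier checkpoint is actually being restored for training. The dynamic loader guards the block with `classifier is not None and for_train`; the static loader nests the same block one level deeper under a condition that lies outside this hunk. As a result, loading a backbone-only checkpoint for evaluation no longer requires meta.json or the per-rank classifier shards. The sketch below is only an illustration of what a helper like rearrange_weight has to accomplish when the saving and loading world sizes differ; the real implementation is defined elsewhere in arcface_paddle and is not part of this patch, and the zero-padded rank id in parameter names is an assumption inferred from the `rank_str = '%05d' % self.rank` lookup in the loader.

# Illustrative sketch only, not code from this patch: re-shard a
# class-partitioned FC weight (or bias) saved by `pretrain_world_size`
# ranks into `world_size` shards for the current run.
import numpy as np

def rearrange_weight(shard_dict, pretrain_world_size, world_size):
    if not shard_dict:
        return {}
    assert len(shard_dict) == pretrain_world_size
    # Stitch the per-rank shards back together along the class axis;
    # zero-padded rank ids make a plain sort put them in rank order.
    names = sorted(shard_dict.keys())
    full = np.concatenate([np.asarray(shard_dict[n]) for n in names], axis=-1)
    num_classes = full.shape[-1]
    # Split again for the new world size; the last rank takes the remainder.
    per_rank = (num_classes + world_size - 1) // world_size
    out = {}
    for rank in range(world_size):
        start = rank * per_rank
        stop = min(start + per_rank, num_classes)
        # Hypothetical naming scheme: reuse rank 0's name with the rank id swapped.
        new_name = names[0].replace('%05d' % 0, '%05d' % rank)
        out[new_name] = full[..., start:stop]
    return out

With that picture in mind, the added guard simply skips this whole re-sharding path, together with the meta.json read that supplies pretrain_world_size, whenever only the backbone is being restored.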