From 0cf7ebadad303687cc178a1bc80668dfd0bb6289 Mon Sep 17 00:00:00 2001 From: nttstar Date: Thu, 17 Jan 2019 16:42:39 +0800 Subject: [PATCH] move more to config --- recognition/sample_config.py | 17 ++++++++----- recognition/train.py | 47 +++++++++++++++--------------------- 2 files changed, 30 insertions(+), 34 deletions(-) diff --git a/recognition/sample_config.py b/recognition/sample_config.py index 47a63fc..3e55a5e 100644 --- a/recognition/sample_config.py +++ b/recognition/sample_config.py @@ -13,6 +13,15 @@ config.net_input = 1 config.net_output = 'E' config.net_multiplier = 1.0 config.val_targets = ['lfw', 'cfp_fp', 'agedb_30'] +config.ce_loss = False +config.fc7_lr_mult = 1.0 +config.fc7_wd_mult = 1.0 +config.fc7_no_bias = False +config.max_steps = 0 +config.data_rand_mirror = True +config.data_cutoff = False +config.data_color = 0 +config.data_images_filter = 0 # network settings @@ -51,7 +60,7 @@ dataset = edict() dataset.emore = edict() dataset.emore.dataset = 'emore' -dataset.emore.dataset_path = './faces_emore' +dataset.emore.dataset_path = '../datasets/faces_emore' dataset.emore.num_classes = 85742 dataset.emore.image_shape = (112,112,3) dataset.emore.val_targets = ['lfw', 'cfp_fp', 'agedb_30'] @@ -59,10 +68,6 @@ dataset.emore.val_targets = ['lfw', 'cfp_fp', 'agedb_30'] loss = edict() loss.softmax = edict() loss.softmax.loss_name = 'softmax' -loss.softmax.loss_s = -1.0 -loss.softmax.loss_m1 = 0.0 -loss.softmax.loss_m2 = 0.0 -loss.softmax.loss_m3 = 0.0 loss.nsoftmax = edict() loss.nsoftmax.loss_name = 'margin_softmax' @@ -116,7 +121,7 @@ default = edict() # default network default.network = 'r100' default.pretrained = '' -default.pretrained_epoch = 0 +default.pretrained_epoch = 1 # default dataset default.dataset = 'emore' default.loss = 'arcface' diff --git a/recognition/train.py b/recognition/train.py index 3f106c7..84c6693 100644 --- a/recognition/train.py +++ b/recognition/train.py @@ -78,25 +78,16 @@ def parse_args(): args, rest = 
parser.parse_known_args() generate_config(args.network, args.dataset, args.loss) parser.add_argument('--models-root', default=default.models_root, help='root directory to save model.') - parser.add_argument('--pretrained', default='', help='pretrained model to load') + parser.add_argument('--pretrained', default=default.pretrained, help='pretrained model to load') + parser.add_argument('--pretrained-epoch', type=int, default=default.pretrained_epoch, help='pretrained epoch to load') parser.add_argument('--ckpt', type=int, default=default.ckpt, help='checkpoint saving option. 0: discard saving. 1: save when necessary. 2: always save') parser.add_argument('--verbose', type=int, default=default.verbose, help='do verification testing and model saving every verbose batches') - parser.add_argument('--max-steps', type=int, default=0, help='max training batches') - parser.add_argument('--end-epoch', type=int, default=100000, help='training epoch size.') parser.add_argument('--lr', type=float, default=default.lr, help='start learning rate') parser.add_argument('--lr-steps', type=str, default=default.lr_steps, help='steps of lr changing') parser.add_argument('--wd', type=float, default=default.wd, help='weight decay') parser.add_argument('--mom', type=float, default=default.mom, help='momentum') parser.add_argument('--frequent', type=int, default=default.frequent, help='') - parser.add_argument('--fc7-wd-mult', type=float, default=1.0, help='weight decay mult for fc7') - parser.add_argument('--fc7-lr-mult', type=float, default=1.0, help='lr mult for fc7') - parser.add_argument("--fc7-no-bias", default=False, action="store_true" , help="fc7 no bias flag") parser.add_argument('--per-batch-size', type=int, default=default.per_batch_size, help='batch size in each context') - parser.add_argument('--rand-mirror', type=int, default=1, help='if do random mirror in training') - parser.add_argument('--cutoff', type=int, default=0, help='cut off aug') - parser.add_argument('--color', type=int, 
default=0, help='color jittering aug') - parser.add_argument('--images-filter', type=int, default=0, help='minimum images per identity filter') - parser.add_argument('--ce-loss', default=False, action='store_true', help='if output ce loss') args = parser.parse_args() return args @@ -107,14 +98,16 @@ def get_symbol(args): gt_label = all_label is_softmax = True if config.loss_name=='softmax': #softmax - _weight = mx.symbol.Variable("fc7_weight", shape=(config.num_classes, config.emb_size), lr_mult=args.fc7_lr_mult, wd_mult=args.fc7_wd_mult) - if args.fc7_no_bias: + _weight = mx.symbol.Variable("fc7_weight", shape=(config.num_classes, config.emb_size), + lr_mult=config.fc7_lr_mult, wd_mult=config.fc7_wd_mult, init=mx.init.Normal(0.01)) + if config.fc7_no_bias: fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, no_bias = True, num_hidden=config.num_classes, name='fc7') else: _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0) fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, bias = _bias, num_hidden=config.num_classes, name='fc7') elif config.loss_name=='margin_softmax': - _weight = mx.symbol.Variable("fc7_weight", shape=(config.num_classes, config.emb_size), lr_mult=args.fc7_lr_mult, wd_mult=args.fc7_wd_mult) + _weight = mx.symbol.Variable("fc7_weight", shape=(config.num_classes, config.emb_size), + lr_mult=config.fc7_lr_mult, wd_mult=config.fc7_wd_mult, init=mx.init.Normal(0.01)) s = config.loss_s _weight = mx.symbol.L2Normalization(_weight, mode='instance') nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s @@ -170,7 +163,7 @@ def get_symbol(args): if is_softmax: softmax = mx.symbol.SoftmaxOutput(data=fc7, label = gt_label, name='softmax', normalization='valid') out_list.append(softmax) - if args.ce_loss: + if config.ce_loss: #ce_loss = mx.symbol.softmax_cross_entropy(data=fc7, label = gt_label, name='ce_loss')/args.per_batch_size body = mx.symbol.SoftmaxActivation(data=fc7) body = 
mx.symbol.log(body) @@ -200,7 +193,6 @@ def train_net(args): print('prefix', prefix) if not os.path.exists(prefix_dir): os.makedirs(prefix_dir) - end_epoch = args.end_epoch args.ctx_num = len(ctx) args.batch_size = args.per_batch_size*args.ctx_num args.rescale_threshold = 0 @@ -229,9 +221,8 @@ def train_net(args): data_shape_dict = {'data' : (args.per_batch_size,)+data_shape} spherenet.init_weights(sym, data_shape_dict, args.num_layers) else: - vec = args.pretrained.split(',') - print('loading', vec) - _, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1])) + print('loading', args.pretrained, args.pretrained_epoch) + _, arg_params, aux_params = mx.model.load_checkpoint(args.pretrained, args.pretrained_epoch) sym = get_symbol(args) #label_name = 'softmax_label' @@ -250,9 +241,9 @@ def train_net(args): data_shape = data_shape, path_imgrec = path_imgrec, shuffle = True, - rand_mirror = args.rand_mirror, + rand_mirror = config.data_rand_mirror, mean = mean, - cutoff = args.cutoff, + cutoff = config.data_cutoff, ctx_num = args.ctx_num, images_per_identity = config.images_per_identity, triplet_params = triplet_params, @@ -267,15 +258,15 @@ def train_net(args): data_shape = data_shape, path_imgrec = path_imgrec, shuffle = True, - rand_mirror = args.rand_mirror, + rand_mirror = config.data_rand_mirror, mean = mean, - cutoff = args.cutoff, - color_jittering = args.color, - images_filter = args.images_filter, + cutoff = config.data_cutoff, + color_jittering = config.data_color, + images_filter = config.data_images_filter, ) metric1 = AccMetric() eval_metrics = [mx.metric.create(metric1)] - if args.ce_loss: + if config.ce_loss: metric2 = LossValueMetric() eval_metrics.append( mx.metric.create(metric2) ) @@ -370,7 +361,7 @@ def train_net(args): arg, aux = model.get_params() mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux) print('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[-1])) - if args.max_steps>0 and mbatch>args.max_steps: + if 
config.max_steps>0 and mbatch>config.max_steps: sys.exit(0) epoch_cb = None @@ -378,7 +369,7 @@ def train_net(args): model.fit(train_dataiter, begin_epoch = begin_epoch, - num_epoch = end_epoch, + num_epoch = 999999, eval_data = val_dataiter, eval_metric = eval_metrics, kvstore = 'device',