from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import sys import math import random import logging import sklearn import pickle import numpy as np import mxnet as mx from mxnet import ndarray as nd import argparse import mxnet.optimizer as optimizer from config import config, default, generate_config from metric import * sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'common')) import flops_counter sys.path.append(os.path.join(os.path.dirname(__file__), 'eval')) import verification sys.path.append(os.path.join(os.path.dirname(__file__), 'symbol')) import fresnet import fmobilefacenet import fmobilenet import fmnasnet import fdensenet logger = logging.getLogger() logger.setLevel(logging.INFO) args = None def parse_args(): parser = argparse.ArgumentParser(description='Train face network') # general parser.add_argument('--dataset', default=default.dataset, help='dataset config') parser.add_argument('--network', default=default.network, help='network config') parser.add_argument('--loss', default=default.loss, help='loss config') args, rest = parser.parse_known_args() generate_config(args.network, args.dataset, args.loss) parser.add_argument('--models-root', default=default.models_root, help='root directory to save model.') parser.add_argument('--pretrained', default=default.pretrained, help='pretrained model to load') parser.add_argument('--pretrained-epoch', type=int, default=default.pretrained_epoch, help='pretrained epoch to load') parser.add_argument('--ckpt', type=int, default=default.ckpt, help='checkpoint saving option. 0: discard saving. 1: save when necessary. 2: always save') parser.add_argument('--verbose', type=int, default=default.verbose, help='do verification testing and model saving every verbose batches') parser.add_argument('--lr', type=float, default=default.lr, help='start learning rate') parser.add_argument('--lr-steps', type=str, default=default.lr_steps, help='steps of lr changing') parser.add_argument('--wd', type=float, default=default.wd, help='weight decay') parser.add_argument('--mom', type=float, default=default.mom, help='momentum') parser.add_argument('--frequent', type=int, default=default.frequent, help='') parser.add_argument('--per-batch-size', type=int, default=default.per_batch_size, help='batch size in each context') parser.add_argument('--kvstore', type=str, default=default.kvstore, help='kvstore setting') args = parser.parse_args() return args def get_symbol(args): embedding = eval(config.net_name).get_symbol() all_label = mx.symbol.Variable('softmax_label') gt_label = all_label is_softmax = True if config.loss_name=='softmax': #softmax _weight = mx.symbol.Variable("fc7_weight", shape=(config.num_classes, config.emb_size), lr_mult=config.fc7_lr_mult, wd_mult=config.fc7_wd_mult, init=mx.init.Normal(0.01)) if config.fc7_no_bias: fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, no_bias = True, num_hidden=config.num_classes, name='fc7') else: _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0) fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, bias = _bias, num_hidden=config.num_classes, name='fc7') elif config.loss_name=='margin_softmax': _weight = mx.symbol.Variable("fc7_weight", shape=(config.num_classes, config.emb_size), lr_mult=config.fc7_lr_mult, wd_mult=config.fc7_wd_mult, init=mx.init.Normal(0.01)) s = config.loss_s _weight = mx.symbol.L2Normalization(_weight, mode='instance') nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=config.num_classes, name='fc7') if config.loss_m1!=1.0 or config.loss_m2!=0.0 or config.loss_m3!=0.0: if config.loss_m1==1.0 and config.loss_m2==0.0: s_m = s*config.loss_m3 gt_one_hot = mx.sym.one_hot(gt_label, depth = config.num_classes, on_value = s_m, off_value = 0.0) fc7 = fc7-gt_one_hot else: zy = mx.sym.pick(fc7, gt_label, axis=1) cos_t = zy/s t = mx.sym.arccos(cos_t) if config.loss_m1!=1.0: t = t*config.loss_m1 if config.loss_m2>0.0: t = t+config.loss_m2 body = mx.sym.cos(t) if config.loss_m3>0.0: body = body - config.loss_m3 new_zy = body*s diff = new_zy - zy diff = mx.sym.expand_dims(diff, 1) gt_one_hot = mx.sym.one_hot(gt_label, depth = config.num_classes, on_value = 1.0, off_value = 0.0) body = mx.sym.broadcast_mul(gt_one_hot, diff) fc7 = fc7+body elif config.loss_name.find('triplet')>=0: is_softmax = False nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') anchor = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=args.per_batch_size//3) positive = mx.symbol.slice_axis(nembedding, axis=0, begin=args.per_batch_size//3, end=2*args.per_batch_size//3) negative = mx.symbol.slice_axis(nembedding, axis=0, begin=2*args.per_batch_size//3, end=args.per_batch_size) if config.loss_name=='triplet': ap = anchor - positive an = anchor - negative ap = ap*ap an = an*an ap = mx.symbol.sum(ap, axis=1, keepdims=1) #(T,1) an = mx.symbol.sum(an, axis=1, keepdims=1) #(T,1) triplet_loss = mx.symbol.Activation(data = (ap-an+config.triplet_alpha), act_type='relu') triplet_loss = mx.symbol.mean(triplet_loss) else: ap = anchor*positive an = anchor*negative ap = mx.symbol.sum(ap, axis=1, keepdims=1) #(T,1) an = mx.symbol.sum(an, axis=1, keepdims=1) #(T,1) ap = mx.sym.arccos(ap) an = mx.sym.arccos(an) triplet_loss = mx.symbol.Activation(data = (ap-an+config.triplet_alpha), act_type='relu') triplet_loss = mx.symbol.mean(triplet_loss) triplet_loss = mx.symbol.MakeLoss(triplet_loss) out_list = [mx.symbol.BlockGrad(embedding)] if is_softmax: softmax = mx.symbol.SoftmaxOutput(data=fc7, label = gt_label, name='softmax', normalization='valid') out_list.append(softmax) if config.ce_loss: #ce_loss = mx.symbol.softmax_cross_entropy(data=fc7, label = gt_label, name='ce_loss')/args.per_batch_size body = mx.symbol.SoftmaxActivation(data=fc7) body = mx.symbol.log(body) _label = mx.sym.one_hot(gt_label, depth = config.num_classes, on_value = -1.0, off_value = 0.0) body = body*_label ce_loss = mx.symbol.sum(body)/args.per_batch_size out_list.append(mx.symbol.BlockGrad(ce_loss)) else: out_list.append(mx.sym.BlockGrad(gt_label)) out_list.append(triplet_loss) out = mx.symbol.Group(out_list) return out def train_net(args): ctx = [] cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip() if len(cvd)>0: for i in range(len(cvd.split(','))): ctx.append(mx.gpu(i)) if len(ctx)==0: ctx = [mx.cpu()] print('use cpu') else: print('gpu num:', len(ctx)) prefix = os.path.join(args.models_root, '%s-%s-%s'%(args.network, args.loss, args.dataset), 'model') prefix_dir = os.path.dirname(prefix) print('prefix', prefix) if not os.path.exists(prefix_dir): os.makedirs(prefix_dir) args.ctx_num = len(ctx) args.batch_size = args.per_batch_size*args.ctx_num args.rescale_threshold = 0 args.image_channel = config.image_shape[2] config.batch_size = args.batch_size config.per_batch_size = args.per_batch_size data_dir = config.dataset_path path_imgrec = None path_imglist = None image_size = config.image_shape[0:2] assert len(image_size)==2 assert image_size[0]==image_size[1] print('image_size', image_size) print('num_classes', config.num_classes) path_imgrec = os.path.join(data_dir, "train.rec") print('Called with argument:', args, config) data_shape = (args.image_channel,image_size[0],image_size[1]) mean = None begin_epoch = 0 if len(args.pretrained)==0: arg_params = None aux_params = None sym = get_symbol(args) if config.net_name=='spherenet': data_shape_dict = {'data' : (args.per_batch_size,)+data_shape} spherenet.init_weights(sym, data_shape_dict, args.num_layers) else: print('loading', args.pretrained, args.pretrained_epoch) _, arg_params, aux_params = mx.model.load_checkpoint(args.pretrained, args.pretrained_epoch) sym = get_symbol(args) if config.count_flops: all_layers = sym.get_internals() _sym = all_layers['fc1_output'] FLOPs = flops_counter.count_flops(_sym, data=(1,3,image_size[0],image_size[1])) _str = flops_counter.flops_str(FLOPs) print('Network FLOPs: %s'%_str) #label_name = 'softmax_label' #label_shape = (args.batch_size,) model = mx.mod.Module( context = ctx, symbol = sym, ) val_dataiter = None if config.loss_name.find('triplet')>=0: from triplet_image_iter import FaceImageIter triplet_params = [config.triplet_bag_size, config.triplet_alpha, config.triplet_max_ap] train_dataiter = FaceImageIter( batch_size = args.batch_size, data_shape = data_shape, path_imgrec = path_imgrec, shuffle = True, rand_mirror = config.data_rand_mirror, mean = mean, cutoff = config.data_cutoff, ctx_num = args.ctx_num, images_per_identity = config.images_per_identity, triplet_params = triplet_params, mx_model = model, ) _metric = LossValueMetric() eval_metrics = [mx.metric.create(_metric)] else: from image_iter import FaceImageIter train_dataiter = FaceImageIter( batch_size = args.batch_size, data_shape = data_shape, path_imgrec = path_imgrec, shuffle = True, rand_mirror = config.data_rand_mirror, mean = mean, cutoff = config.data_cutoff, color_jittering = config.data_color, images_filter = config.data_images_filter, ) metric1 = AccMetric() eval_metrics = [mx.metric.create(metric1)] if config.ce_loss: metric2 = LossValueMetric() eval_metrics.append( mx.metric.create(metric2) ) if config.net_name=='fresnet' or config.net_name=='fmobilefacenet': initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style else: initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2) #initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style _rescale = 1.0/args.ctx_num opt = optimizer.SGD(learning_rate=args.lr, momentum=args.mom, wd=args.wd, rescale_grad=_rescale) _cb = mx.callback.Speedometer(args.batch_size, args.frequent) ver_list = [] ver_name_list = [] for name in config.val_targets: path = os.path.join(data_dir,name+".bin") if os.path.exists(path): data_set = verification.load_bin(path, image_size) ver_list.append(data_set) ver_name_list.append(name) print('ver', name) def ver_test(nbatch): results = [] for i in range(len(ver_list)): acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(ver_list[i], model, args.batch_size, 10, None, None) print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm)) #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1)) print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc2, std2)) results.append(acc2) return results highest_acc = [0.0, 0.0] #lfw and target #for i in range(len(ver_list)): # highest_acc.append(0.0) global_step = [0] save_step = [0] lr_steps = [int(x) for x in args.lr_steps.split(',')] print('lr_steps', lr_steps) def _batch_callback(param): #global global_step global_step[0]+=1 mbatch = global_step[0] for step in lr_steps: if mbatch==step: opt.lr *= 0.1 print('lr change to', opt.lr) break _cb(param) if mbatch%1000==0: print('lr-batch-epoch:',opt.lr,param.nbatch,param.epoch) if mbatch>=0 and mbatch%args.verbose==0: acc_list = ver_test(mbatch) save_step[0]+=1 msave = save_step[0] do_save = False is_highest = False if len(acc_list)>0: #lfw_score = acc_list[0] #if lfw_score>highest_acc[0]: # highest_acc[0] = lfw_score # if lfw_score>=0.998: # do_save = True score = sum(acc_list) if acc_list[-1]>=highest_acc[-1]: if acc_list[-1]>highest_acc[-1]: is_highest = True else: if score>=highest_acc[0]: is_highest = True highest_acc[0] = score highest_acc[-1] = acc_list[-1] #if lfw_score>=0.99: # do_save = True if is_highest: do_save = True if args.ckpt==0: do_save = False elif args.ckpt==2: do_save = True elif args.ckpt==3: msave = 1 if do_save: print('saving', msave) arg, aux = model.get_params() if config.ckpt_embedding: all_layers = model.symbol.get_internals() _sym = all_layers['fc1_output'] _arg = {} for k in arg: if not k.startswith('fc7'): _arg[k] = arg[k] mx.model.save_checkpoint(prefix, msave, _sym, _arg, aux) else: mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux) print('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[-1])) if config.max_steps>0 and mbatch>config.max_steps: sys.exit(0) epoch_cb = None train_dataiter = mx.io.PrefetchingIter(train_dataiter) model.fit(train_dataiter, begin_epoch = begin_epoch, num_epoch = 999999, eval_data = val_dataiter, eval_metric = eval_metrics, kvstore = args.kvstore, optimizer = opt, #optimizer_params = optimizer_params, initializer = initializer, arg_params = arg_params, aux_params = aux_params, allow_missing = True, batch_end_callback = _batch_callback, epoch_end_callback = epoch_cb ) def main(): global args args = parse_args() train_net(args) if __name__ == '__main__': main()