from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import sys import math import random import logging import pickle import numpy as np from data import FaceIter from data import FaceImageIter from data import FaceImageIter2 from data import FaceImageIter4 from data import FaceImageIter5 import mxnet as mx from mxnet import ndarray as nd import argparse import mxnet.optimizer as optimizer sys.path.append(os.path.join(os.path.dirname(__file__), 'common')) import face_image sys.path.append(os.path.join(os.path.dirname(__file__), 'eval')) sys.path.append(os.path.join(os.path.dirname(__file__), 'symbols')) import fresnet import finception_resnet_v2 import fmobilenet import fxception import fdensenet #import lfw import verification import sklearn logger = logging.getLogger() logger.setLevel(logging.INFO) class AccMetric(mx.metric.EvalMetric): def __init__(self): self.axis = 1 super(AccMetric, self).__init__( 'acc', axis=self.axis, output_names=None, label_names=None) self.losses = [] def update(self, labels, preds): #loss = preds[2].asnumpy()[0] #if len(self.losses)==20: # print('ce loss', sum(self.losses)/len(self.losses)) # self.losses = [] #self.losses.append(loss) preds = [preds[1]] #use softmax output for label, pred_label in zip(labels, preds): #print(pred_label) #print(label.shape, pred_label.shape) if pred_label.shape != label.shape: pred_label = mx.ndarray.argmax(pred_label, axis=self.axis) pred_label = pred_label.asnumpy().astype('int32').flatten() label = label.asnumpy().astype('int32').flatten() #print(label) #print('label',label) #print('pred_label', pred_label) assert label.shape==pred_label.shape self.sum_metric += (pred_label.flat == label.flat).sum() self.num_inst += len(pred_label.flat) def parse_args(): parser = argparse.ArgumentParser(description='Train face network') # general parser.add_argument('--data-dir', default='', help='') parser.add_argument('--prefix', default='../model/model', help='directory to save model.') parser.add_argument('--pretrained', default='../model/resnet-152', help='') parser.add_argument('--network', default='s20', help='') parser.add_argument('--use-se', action='store_true', default=False, help='') parser.add_argument('--version-input', type=int, default=1, help='') parser.add_argument('--version-output', type=str, default='A', help='') parser.add_argument('--version-unit', type=int, default=1, help='') parser.add_argument('--load-epoch', type=int, default=0, help='load epoch.') parser.add_argument('--end-epoch', type=int, default=1000, help='training epoch size.') parser.add_argument('--retrain', action='store_true', default=False, help='true means continue training.') parser.add_argument('--lr', type=float, default=0.1, help='') parser.add_argument('--wd', type=float, default=0.0005, help='') parser.add_argument('--images-per-identity', type=int, default=16, help='') parser.add_argument('--embedding-dim', type=int, default=512, help='') parser.add_argument('--per-batch-size', type=int, default=0, help='') parser.add_argument('--margin', type=int, default=4, help='') parser.add_argument('--beta', type=float, default=1000., help='') parser.add_argument('--beta-min', type=float, default=5., help='') parser.add_argument('--beta-freeze', type=int, default=0, help='') parser.add_argument('--gamma', type=float, default=0.12, help='') parser.add_argument('--power', type=float, default=1.0, help='') parser.add_argument('--scale', type=float, default=0.9993, help='') parser.add_argument('--verbose', type=int, default=2000, help='') parser.add_argument('--loss-type', type=int, default=1, help='') parser.add_argument('--incay', type=float, default=0.0, help='feature incay') parser.add_argument('--use-deformable', type=int, default=0, help='') parser.add_argument('--patch', type=str, default='0_0_96_112_0', help='') parser.add_argument('--lr-steps', type=str, default='', help='') args = parser.parse_args() return args def get_symbol(args, arg_params, aux_params): if args.retrain: new_args = arg_params else: new_args = None data_shape = (args.image_channel,args.image_h,args.image_w) image_shape = ",".join([str(x) for x in data_shape]) if args.network[0]=='d': embedding = fdensenet.get_symbol(512, args.num_layers, use_se=args.use_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) elif args.network[0]=='m': print('init mobilenet', args.num_layers) embedding = fmobilenet.get_symbol(512, use_se=args.use_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) elif args.network[0]=='i': print('init inception-resnet-v2', args.num_layers) embedding = finception_resnet_v2.get_symbol(512, use_se=args.use_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) elif args.network[0]=='x': print('init xception', args.num_layers) embedding = fxception.get_symbol(512, use_se=args.use_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) else: print('init resnet', args.num_layers) embedding = fresnet.get_symbol(512, args.num_layers, use_se=args.use_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) gt_label = mx.symbol.Variable('softmax_label') assert args.loss_type>=0 extra_loss = None if args.loss_type==0: _weight = mx.symbol.Variable('fc7_weight') _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0) fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, bias = _bias, num_hidden=args.num_classes, name='fc7') elif args.loss_type==1: _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, 512), lr_mult=1.0) _weight = mx.symbol.L2Normalization(_weight, mode='instance') fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes, weight = _weight, beta=args.beta, margin=args.margin, scale=args.scale, beta_min=args.beta_min, verbose=1000, name='fc7') elif args.loss_type==10: _weight = mx.symbol.Variable('fc7_weight') _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0) fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, bias = _bias, num_hidden=args.num_classes, name='fc7') nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') params = [1.2, 0.3, 1.0] n1 = mx.sym.expand_dims(nembedding, axis=1) n2 = mx.sym.expand_dims(nembedding, axis=0) body = mx.sym.broadcast_sub(n1, n2) #N,N,C body = body * body body = mx.sym.sum(body, axis=2) # N,N #body = mx.sym.sqrt(body) body = body - params[0] mask = mx.sym.Variable('extra') body = body*mask body = body+params[1] #body = mx.sym.maximum(body, 0.0) body = mx.symbol.Activation(data=body, act_type='relu') body = mx.sym.sum(body) body = body/(args.per_batch_size*args.per_batch_size-args.per_batch_size) extra_loss = mx.symbol.MakeLoss(body, grad_scale=params[2]) elif args.loss_type==11: _weight = mx.symbol.Variable('fc7_weight') _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0) fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, bias = _bias, num_hidden=args.num_classes, name='fc7') params = [0.9, 0.2] nembedding = mx.symbol.slice_axis(embedding, axis=0, begin=0, end=args.images_per_identity) nembedding = mx.symbol.L2Normalization(nembedding, mode='instance', name='fc1n') n1 = mx.sym.expand_dims(nembedding, axis=1) n2 = mx.sym.expand_dims(nembedding, axis=0) body = mx.sym.broadcast_sub(n1, n2) #N,N,C body = body * body body = mx.sym.sum(body, axis=2) # N,N body = body - params[0] body = mx.symbol.Activation(data=body, act_type='relu') body = mx.sym.sum(body) n = args.images_per_identity body = body/(n*n-n) extra_loss = mx.symbol.MakeLoss(body, grad_scale=params[1]) #extra_loss = None else: #embedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*float(args.loss_type) embedding = embedding * 5 _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, 512), lr_mult=1.0) _weight = mx.symbol.L2Normalization(_weight, mode='instance') * 2 fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes, weight = _weight, beta=args.beta, margin=args.margin, scale=args.scale, beta_min=args.beta_min, verbose=100, name='fc7') #fc7 = mx.sym.Custom(data=embedding, label=gt_label, weight=_weight, num_hidden=args.num_classes, # beta=args.beta, margin=args.margin, scale=args.scale, # op_type='ASoftmax', name='fc7') softmax = mx.symbol.SoftmaxOutput(data=fc7, label = gt_label, name='softmax', normalization='valid') if args.loss_type<=1 and args.incay>0.0: params = [1.e-10] sel = mx.symbol.argmax(data = fc7, axis=1) sel = (sel==gt_label) norm = embedding*embedding norm = mx.symbol.sum(norm, axis=1) norm = norm+params[0] feature_incay = sel/norm feature_incay = mx.symbol.mean(feature_incay) * args.incay extra_loss = mx.symbol.MakeLoss(feature_incay) #out = softmax #l2_embedding = mx.symbol.L2Normalization(embedding) #ce = mx.symbol.softmax_cross_entropy(fc7, gt_label, name='softmax_ce')/args.per_batch_size #out = mx.symbol.Group([mx.symbol.BlockGrad(embedding), softmax, mx.symbol.BlockGrad(ce)]) if extra_loss is not None: out = mx.symbol.Group([mx.symbol.BlockGrad(embedding), softmax, extra_loss]) else: out = mx.symbol.Group([mx.symbol.BlockGrad(embedding), softmax]) return (out, new_args, aux_params) def train_net(args): ctx = [] cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip() if len(cvd)>0: for i in xrange(len(cvd.split(','))): ctx.append(mx.gpu(i)) if len(ctx)==0: ctx = [mx.cpu()] print('use cpu') else: print('gpu num:', len(ctx)) prefix = args.prefix prefix_dir = os.path.dirname(prefix) if not os.path.exists(prefix_dir): os.makedirs(prefix_dir) end_epoch = args.end_epoch pretrained = args.pretrained load_epoch = args.load_epoch args.ctx_num = len(ctx) args.num_layers = int(args.network[1:]) print('num_layers', args.num_layers) if args.per_batch_size==0: args.per_batch_size = 128 if args.network[0]=='r': args.per_batch_size = 128 else: if args.num_layers>=64: args.per_batch_size = 120 if args.ctx_num==2: args.per_batch_size *= 2 elif args.ctx_num==3: args.per_batch_size = 170 if args.network[0]=='m': args.per_batch_size = 128 args.batch_size = args.per_batch_size*args.ctx_num args.rescale_threshold = 0 args.image_channel = 3 ppatch = [int(x) for x in args.patch.split('_')] assert len(ppatch)==5 os.environ['BETA'] = str(args.beta) args.use_val = False path_imgrec = None path_imglist = None val_rec = None prop = face_image.load_property(args.data_dir) args.num_classes = prop.num_classes image_size = prop.image_size args.image_h = image_size[0] args.image_w = image_size[1] print('image_size', image_size) assert(args.num_classes>0) print('num_classes', args.num_classes) #path_imglist = "/raid5data/dplearn/MS-Celeb-Aligned/lst2" path_imgrec = os.path.join(args.data_dir, "train.rec") val_rec = os.path.join(args.data_dir, "val.rec") if os.path.exists(val_rec): args.use_val = True else: val_rec = None #args.num_classes = 10572 #webface #args.num_classes = 81017 #args.num_classes = 82395 if args.loss_type==1 and args.num_classes>40000: args.beta_freeze = 5000 args.gamma = 0.06 print('Called with argument:', args) data_shape = (args.image_channel,image_size[0],image_size[1]) mean = None if args.use_val: val_dataiter = FaceImageIter( batch_size = args.batch_size, data_shape = data_shape, path_imgrec = val_rec, #path_imglist = val_path, shuffle = False, rand_mirror = False, mean = mean, ) else: val_dataiter = None begin_epoch = 0 base_lr = args.lr base_wd = args.wd base_mom = 0.9 if not args.retrain: arg_params = None aux_params = None sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params) else: _, arg_params, aux_params = mx.model.load_checkpoint(pretrained, load_epoch) sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params) if args.loss_type!=10: model = mx.mod.Module( context = ctx, symbol = sym, ) else: data_names = ('data', 'extra') model = mx.mod.Module( context = ctx, symbol = sym, data_names = data_names, ) if args.loss_type<=9: train_dataiter = FaceImageIter( batch_size = args.batch_size, data_shape = data_shape, path_imgrec = path_imgrec, shuffle = True, rand_mirror = True, mean = mean, ) elif args.loss_type==10: train_dataiter = FaceImageIter4( batch_size = args.batch_size, ctx_num = args.ctx_num, images_per_identity = args.images_per_identity, data_shape = data_shape, path_imglist = path_imglist, shuffle = True, rand_mirror = True, mean = mean, patch = ppatch, use_extra = True, model = model, ) elif args.loss_type==11: train_dataiter = FaceImageIter5( batch_size = args.batch_size, ctx_num = args.ctx_num, images_per_identity = args.images_per_identity, data_shape = data_shape, path_imglist = path_imglist, shuffle = True, rand_mirror = True, mean = mean, patch = ppatch, ) _acc = AccMetric() eval_metrics = [mx.metric.create(_acc)] if args.network[0]=='r': initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style elif args.network[0]=='i' or args.network[0]=='x': initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2) #inception else: initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2) _rescale = 1.0/args.ctx_num opt = optimizer.SGD(learning_rate=base_lr, momentum=base_mom, wd=base_wd, rescale_grad=_rescale) _cb = mx.callback.Speedometer(args.batch_size, 20) ver_list = [] ver_name_list = [] for name in ['lfw','cfp_ff','cfp_fp','agedb_30']: path = os.path.join(args.data_dir,name+".bin") if os.path.exists(path): data_set = verification.load_bin(path, image_size) ver_list.append(data_set) ver_name_list.append(name) print('ver', name) def ver_test(nbatch): results = [] for i in xrange(len(ver_list)): acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(ver_list[i], model, args.batch_size) #print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm)) #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1)) print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc2, std2)) results.append(acc2) return results def val_test(): acc = AccMetric() val_metric = mx.metric.create(acc) val_metric.reset() val_dataiter.reset() for i, eval_batch in enumerate(val_dataiter): model.forward(eval_batch, is_train=False) model.update_metric(val_metric, eval_batch.label) acc_value = val_metric.get_name_value()[0][1] print('VACC: %f'%(acc_value)) highest_acc = [] for i in xrange(len(ver_list)): highest_acc.append(0.0) global_step = [0] save_step = [0] if len(args.lr_steps)==0: lr_steps = [40000, 60000, 80000] if args.loss_type==1: lr_steps = [80000, 120000, 140000] p = 512.0/args.batch_size for l in xrange(len(lr_steps)): lr_steps[l] = int(lr_steps[l]*p) else: lr_steps = [int(x) for x in args.lr_steps.split(',')] print('lr_steps', lr_steps) def _batch_callback(param): #global global_step global_step[0]+=1 mbatch = global_step[0] for _lr in lr_steps: if mbatch==args.beta_freeze+_lr: opt.lr *= 0.1 print('lr change to', opt.lr) break _cb(param) if mbatch%1000==0: print('lr-batch-epoch:',opt.lr,param.nbatch,param.epoch) if mbatch>=0 and mbatch%args.verbose==0: acc_list = ver_test(mbatch) save_step[0]+=1 msave = save_step[0] do_save = False lfw_score = acc_list[0] for i in xrange(len(acc_list)): acc = acc_list[i] if acc>=highest_acc[i]: highest_acc[i] = acc if lfw_score>=0.99: do_save = True if args.loss_type==1 and mbatch>lr_steps[-1] and mbatch%10000==0: do_save = True if do_save: print('saving', msave, acc) if val_dataiter is not None: val_test() arg, aux = model.get_params() mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux) #if acc>=highest_acc[0]: # lfw_npy = "%s-lfw-%04d" % (prefix, msave) # X = np.concatenate(embeddings_list, axis=0) # print('saving lfw npy', X.shape) # np.save(lfw_npy, X) #print('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[0])) if mbatch<=args.beta_freeze: _beta = args.beta else: move = max(0, mbatch-args.beta_freeze) _beta = max(args.beta_min, args.beta*math.pow(1+args.gamma*move, -1.0*args.power)) #print('beta', _beta) os.environ['BETA'] = str(_beta) #epoch_cb = mx.callback.do_checkpoint(prefix, 1) epoch_cb = None #def _epoch_callback(epoch, sym, arg_params, aux_params): # print('epoch-end', epoch) model.fit(train_dataiter, begin_epoch = begin_epoch, num_epoch = end_epoch, eval_data = val_dataiter, eval_metric = eval_metrics, kvstore = 'device', optimizer = opt, #optimizer_params = optimizer_params, initializer = initializer, arg_params = arg_params, aux_params = aux_params, allow_missing = True, batch_end_callback = _batch_callback, epoch_end_callback = epoch_cb ) def main(): #time.sleep(3600*6.5) args = parse_args() train_net(args) if __name__ == '__main__': main()