# THIS FILE IS FOR EXPERIMENTS, USE train_softmax.py FOR NORMAL TRAINING.
#
# Experimental MXNet face-recognition training script. get_symbol() builds an
# embedding backbone (selected by the first character of --network) and attaches
# one of several classification / metric-learning heads (selected by
# --loss-type). train_net() trains it with Module.fit and periodically evaluates
# the verification sets named in --target, saving checkpoints on improvement.
#
# NOTE(review): `xrange` is used in get_symbol and train_net, so as written this
# module only runs under Python 2 despite the __future__ imports.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import math
import random
import logging
import pickle
import numpy as np
from data import FaceImageIter
from data import FaceImageIterList
import mxnet as mx
from mxnet import ndarray as nd
import argparse
import mxnet.optimizer as optimizer
sys.path.append(os.path.join(os.path.dirname(__file__), 'common'))
import face_image
from noise_sgd import NoiseSGD
sys.path.append(os.path.join(os.path.dirname(__file__), 'eval'))
sys.path.append(os.path.join(os.path.dirname(__file__), 'symbols'))
import fresnet
import finception_resnet_v2
import fmobilenet
import fmobilenetv2
import fxception
import fdensenet
import fdpn
import fnasnet
import spherenet
#import lfw
import verification
import sklearn
sys.path.append(os.path.join(os.path.dirname(__file__), 'losses'))
# Imported for its side effects only; presumably registers the 'centerloss'
# custom operator used by --loss-type 8 (TODO confirm in losses/center_loss.py).
import center_loss


logger = logging.getLogger()
logger.setLevel(logging.INFO)


# Parsed command-line namespace. main() assigns it, and AccMetric.update()
# reads it as a module global.
args = None


class AccMetric(mx.metric.EvalMetric):
    """Training accuracy metric with optional diagnostic printing.

    Accuracy is taken from the second network output, which get_symbol places
    as the SoftmaxOutput for loss types < 10. When --margin-verbose or
    --logits-verbose is set, the trailing diagnostic outputs that get_symbol
    appends are also printed periodically.
    """

    def __init__(self):
        self.axis = 1
        super(AccMetric, self).__init__(
            'acc', axis=self.axis,
            output_names=None, label_names=None)
        # Only referenced by the commented-out loss averaging in update().
        self.losses = []
        # Number of update() calls seen so far.
        self.count = 0

    def update(self, labels, preds):
        """Accumulate accuracy for one call and maybe print diagnostics.

        Args:
            labels: list of label NDArrays. A 2-D label keeps only column 0,
                which matches how get_symbol slices the class id out of the
                label when --output-c2c is on.
            preds: list of network outputs, in the order get_symbol groups them.
        """
        self.count+=1
        # Margin diagnostics are only produced by the margin losses 2..7.
        if args.loss_type>=2 and args.loss_type<=7 and args.margin_verbose>0:
            # NOTE(review): assumes update() runs once per device per batch, so
            # count//ctx_num approximates the global batch index -- confirm.
            if self.count%args.ctx_num==0:
                mbatch = self.count//args.ctx_num
                _verbose = args.margin_verbose
                if mbatch==1 or mbatch%_verbose==0:
                    a = 0.0
                    b = 0.0
                    # The last two outputs are the before/after-margin means
                    # appended by get_symbol. With only one diagnostic output,
                    # print it twice.
                    if len(preds)>=4:
                        a = preds[-2].asnumpy()[0]
                        b = preds[-1].asnumpy()[0]
                    elif len(preds)==3:
                        a = preds[-1].asnumpy()[0]
                        b = a
                    print('[%d][MARGIN]%f,%f'%(mbatch,a,b))
        if args.logits_verbose>0:
            if self.count%args.ctx_num==0:
                mbatch = self.count//args.ctx_num
                _verbose = args.logits_verbose
                if mbatch==1 or mbatch%_verbose==0:
                    a = 0.0
                    b = 0.0
                    if len(preds)>=3:
                        # preds[-1] is the per-sample softmax probability of the
                        # true class (appended last by get_symbol). Print the
                        # means of the lowest and highest 10% of those values.
                        # NOTE(review): with fewer than 10 samples num == 0, so
                        # v[0:0] is empty (mean is nan) and v[-0:] is the whole
                        # array.
                        v = preds[-1].asnumpy()
                        v = np.sort(v)
                        num = len(v)//10
                        a = np.mean(v[0:num])
                        b = np.mean(v[-1*num:])
                    print('[LOGITS] %d,%f,%f'%(mbatch,a,b))
        #loss = preds[2].asnumpy()[0]
        #if len(self.losses)==20:
        #  print('ce loss', sum(self.losses)/len(self.losses))
        #  self.losses = []
        #self.losses.append(loss)
        preds = [preds[1]] #use softmax output
        for label, pred_label in zip(labels, preds):
            #print(pred_label)
            #print(label.shape, pred_label.shape)
            # Probabilities -> predicted class ids along the class axis.
            if pred_label.shape != label.shape:
                pred_label = mx.ndarray.argmax(pred_label, axis=self.axis)
            pred_label = pred_label.asnumpy().astype('int32').flatten()
            label = label.asnumpy()
            if label.ndim==2:
                label = label[:,0]
            label = label.astype('int32').flatten()
            #print(label)
            #print('label',label)
            #print('pred_label', pred_label)
            assert label.shape==pred_label.shape
            self.sum_metric += (pred_label.flat == label.flat).sum()
            self.num_inst += len(pred_label.flat)


class LossValueMetric(mx.metric.EvalMetric):
    """Running mean of the scalar loss in the last network output.

    Used for the metric-learning losses (loss type >= 10). For those losses,
    get_symbol ends the output list with BlockGrad(gt_label) followed by the
    MakeLoss output.
    """

    def __init__(self):
        self.axis = 1
        super(LossValueMetric, self).__init__(
            'lossvalue', axis=self.axis,
            output_names=None, label_names=None)
        self.losses = []

    def update(self, labels, preds):
        """Add the first element of preds[-1] (the loss value) to the average."""
        loss = preds[-1].asnumpy()[0]
        self.sum_metric += loss
        self.num_inst += 1.0
        # NOTE(review): gt_label is copied to host but only used by the
        # commented-out print below.
        gt_label = preds[-2].asnumpy()
        #print(gt_label)


def parse_args():
    """Define and parse the command-line options.

    Many options only matter for particular --loss-type values; see
    get_symbol and train_net for where each one is read.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description='Train face network')
    # general
    parser.add_argument('--data-dir', default='', help='training set directory')
    # Used as a checkpoint path prefix (its dirname is created in train_net),
    # not as a directory by itself.
    parser.add_argument('--prefix', default='../model/model', help='directory to save model.')
    parser.add_argument('--pretrained', default='', help='pretrained model to load')
    parser.add_argument('--ckpt', type=int, default=1, help='checkpoint saving option. 0: discard saving. 1: save when necessary. 2: always save')
    parser.add_argument('--network', default='r50', help='specify network')
    parser.add_argument('--version-se', type=int, default=0, help='whether to use se in network')
    parser.add_argument('--version-input', type=int, default=1, help='network input config')
    parser.add_argument('--version-output', type=str, default='E', help='network embedding output config')
    parser.add_argument('--version-unit', type=int, default=3, help='resnet unit config')
    parser.add_argument('--version-act', type=str, default='prelu', help='network activation config')
    parser.add_argument('--end-epoch', type=int, default=100000, help='training epoch size.')
    parser.add_argument('--noise-sgd', type=float, default=0.0, help='')
    parser.add_argument('--lr', type=float, default=0.1, help='start learning rate')
    parser.add_argument('--wd', type=float, default=0.0005, help='weight decay')
    parser.add_argument('--mom', type=float, default=0.9, help='momentum')
    parser.add_argument('--emb-size', type=int, default=512, help='embedding length')
    parser.add_argument('--per-batch-size', type=int, default=128, help='batch size in each context')
    parser.add_argument('--margin-m', type=float, default=0.5, help='')
    parser.add_argument('--margin-s', type=float, default=64.0, help='')
    # NOTE(review): default 0.0 (not 1.0) -- loss type 5 multiplies the angle
    # by this value whenever it is != 1.0.
    parser.add_argument('--margin-a', type=float, default=0.0, help='')
    parser.add_argument('--margin-b', type=float, default=0.0, help='')
    parser.add_argument('--easy-margin', type=int, default=0, help='')
    parser.add_argument('--margin-verbose', type=int, default=0, help='')
    parser.add_argument('--logits-verbose', type=int, default=0, help='')
    parser.add_argument('--c2c-threshold', type=float, default=0.0, help='')
    parser.add_argument('--c2c-mode', type=int, default=-10, help='')
    parser.add_argument('--output-c2c', type=int, default=0, help='')
    parser.add_argument('--train-limit', type=int, default=0, help='')
    parser.add_argument('--margin', type=int, default=4, help='')
    parser.add_argument('--beta', type=float, default=1000., help='')
    parser.add_argument('--beta-min', type=float, default=5., help='')
    parser.add_argument('--beta-freeze', type=int, default=0, help='')
    parser.add_argument('--gamma', type=float, default=0.12, help='')
    parser.add_argument('--power', type=float, default=1.0, help='')
    parser.add_argument('--scale', type=float, default=0.9993, help='')
    parser.add_argument('--center-alpha', type=float, default=0.5, help='')
    parser.add_argument('--center-scale', type=float, default=0.003, help='')
    parser.add_argument('--images-per-identity', type=int, default=0, help='')
    parser.add_argument('--triplet-bag-size', type=int, default=3600, help='')
    parser.add_argument('--triplet-alpha', type=float, default=0.3, help='')
    parser.add_argument('--triplet-max-ap', type=float, default=0.0, help='')
    parser.add_argument('--verbose', type=int, default=2000, help='')
    parser.add_argument('--loss-type', type=int, default=4, help='')
    parser.add_argument('--incay', type=float, default=0.0, help='feature incay')
    parser.add_argument('--use-deformable', type=int, default=0, help='')
    parser.add_argument('--rand-mirror', type=int, default=1, help='')
    parser.add_argument('--cutoff', type=int, default=0, help='')
    parser.add_argument('--patch', type=str, default='0_0_96_112_0',help='')
    parser.add_argument('--lr-steps', type=str, default='', help='')
    parser.add_argument('--max-steps', type=int, default=0, help='')
    parser.add_argument('--target', type=str, default='lfw,cfp_fp,agedb_30,cplfw,calfw', help='')
    args = parser.parse_args()
    return args


def get_symbol(args, arg_params, aux_params):
    """Build the training symbol: embedding backbone + loss head.

    Reads fields that train_net fills in before calling this:
    image_channel/image_h/image_w, num_layers, num_classes, and, for loss
    types >= 9, per_identities / images_per_identity / coco_scale.

    Loss heads (--loss-type):
        0  plain softmax FC with bias
        1  LSoftmax operator on L2-normalized weights
        8  softmax FC plus a 'centerloss' Custom op as an extra loss
        2  normalized cosine logits with an additive margin on the target
        3  multiplicative angular margin cos(m*theta)
        4  additive angular margin cos(theta+m) (optionally per-sample via c2c)
        5  combined margin cos(a*theta+m)-b
        6  adaptive additive angular margin
        7  additive angular margin drawn uniformly from [margin_b, margin_m)
        10 marginal loss on pairwise squared distances (needs 'extra' mask)
        11 N-pair loss (needs 'extra' mask)
        12 triplet loss on squared L2 distances
        13 triplet loss on angles
        9  COCO loss against per-batch identity centroids
        other  LSoftmax on rescaled embedding/weights

    Returns:
        (out, arg_params, aux_params). arg_params and aux_params are passed
        through unchanged. out is a Group of, in order:
        BlockGrad(embedding); SoftmaxOutput (loss type < 10) or
        BlockGrad(gt_label) otherwise; extra_loss if any; then each diagnostic
        symbol wrapped in BlockGrad.
    """
    # NOTE(review): data_shape and image_shape are computed but not used here.
    data_shape = (args.image_channel,args.image_h,args.image_w)
    image_shape = ",".join([str(x) for x in data_shape])
    # Diagnostic symbols (margin statistics / true-class probability) that are
    # appended to the outputs for AccMetric to print.
    margin_symbols = []
    # Backbone is selected by the first character of --network.
    if args.network[0]=='d':
        embedding = fdensenet.get_symbol(args.emb_size, args.num_layers,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0]=='m':
        print('init mobilenet', args.num_layers)
        if args.num_layers==1:
            embedding = fmobilenet.get_symbol(args.emb_size,
                version_se=args.version_se, version_input=args.version_input,
                version_output=args.version_output, version_unit=args.version_unit)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif args.network[0]=='i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(args.emb_size,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0]=='x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0]=='p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size, args.num_layers,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0]=='n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif args.network[0]=='s':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size, args.num_layers,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit,
            version_act=args.version_act)
    all_label = mx.symbol.Variable('softmax_label')
    if not args.output_c2c:
        gt_label = all_label
    else:
        # With --output-c2c the label has two columns: column 0 is the class
        # id and column 1 is the c2c value consumed by loss type 4.
        gt_label = mx.symbol.slice_axis(all_label, axis=1, begin=0, end=1)
        gt_label = mx.symbol.reshape(gt_label, (args.per_batch_size,))
        c2c_label = mx.symbol.slice_axis(all_label, axis=1, begin=1, end=2)
        c2c_label = mx.symbol.reshape(c2c_label, (args.per_batch_size,))
    assert args.loss_type>=0
    extra_loss = None
    if args.loss_type==0: #softmax
        _weight = mx.symbol.Variable('fc7_weight')
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, bias = _bias, num_hidden=args.num_classes, name='fc7')
    elif args.loss_type==1: #sphere
        # NOTE(review): LSoftmax is not a stock MXNet operator; presumably it is
        # a custom-built op -- confirm it is available in the MXNet build used.
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes,
                              weight = _weight,
                              beta=args.beta, margin=args.margin, scale=args.scale,
                              beta_min=args.beta_min, verbose=1000, name='fc7')
    elif args.loss_type==8: #centerloss, TODO
        _weight = mx.symbol.Variable('fc7_weight')
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, bias = _bias, num_hidden=args.num_classes, name='fc7')
        print('center-loss', args.center_alpha, args.center_scale)
        extra_loss = mx.symbol.Custom(data=embedding, label=gt_label, name='center_loss', op_type='centerloss',\
              num_class=args.num_classes, alpha=args.center_alpha, scale=args.center_scale, batchsize=args.per_batch_size)
    elif args.loss_type==2:
        # Normalized weights. With s > 0 the embedding is also normalized and
        # scaled by s, and s*m is subtracted from the true-class logit.
        # Otherwise the raw embedding is used and m*||x|| is subtracted.
        s = args.margin_s
        m = args.margin_m
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        if s>0.0:
            nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
            fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
            if m>0.0:
                if args.margin_verbose>0:
                    zy = mx.sym.pick(fc7, gt_label, axis=1)
                    cos_t = zy/s
                    margin_symbols.append(mx.symbol.mean(cos_t))

                s_m = s*m
                gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = s_m, off_value = 0.0)
                fc7 = fc7-gt_one_hot

                if args.margin_verbose>0:
                    new_zy = mx.sym.pick(fc7, gt_label, axis=1)
                    new_cos_t = new_zy/s
                    margin_symbols.append(mx.symbol.mean(new_cos_t))
        else:
            fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
            if m>0.0:
                # Per-sample embedding norm times m, placed on the true class.
                body = embedding*embedding
                body = mx.sym.sum_axis(body, axis=1, keepdims=True)
                body = mx.sym.sqrt(body)
                body = body*m
                gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
                body = mx.sym.broadcast_mul(gt_one_hot, body)
                fc7 = fc7-body
    elif args.loss_type==3:
        # Multiplicative angular margin: the true-class logit becomes
        # s*cos(m*theta) when m > 1.
        # NOTE(review): the assert checks --margin, but the multiplier actually
        # used is --margin-m.
        s = args.margin_s
        m = args.margin_m
        assert args.margin==2 or args.margin==4
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy/s
        if args.margin_verbose>0:
            margin_symbols.append(mx.symbol.mean(cos_t))
        if m>1.0:
            t = mx.sym.arccos(cos_t)
            t = t*m
            body = mx.sym.cos(t)
            new_zy = body*s
            if args.margin_verbose>0:
                new_cos_t = new_zy/s
                margin_symbols.append(mx.symbol.mean(new_cos_t))
            # Replace the true-class logit: add (new_zy - zy) at the gt column.
            diff = new_zy - zy
            diff = mx.sym.expand_dims(diff, 1)
            gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
            body = mx.sym.broadcast_mul(gt_one_hot, diff)
            fc7 = fc7+body
        #threshold = math.cos(args.margin_m)
        #cond_v = cos_t - threshold
        #cond = mx.symbol.Activation(data=cond_v, act_type='relu')
        #body = cos_t
        #for i in xrange(args.margin//2):
        #  body = body*body
        #  body = body*2-1
        #new_zy = body*s
        #zy_keep = zy
        #new_zy = mx.sym.where(cond, new_zy, zy_keep)
        #if args.margin_verbose>0:
        #  new_cos_t = new_zy/s
        #  margin_symbols.append(mx.symbol.mean(new_cos_t))
        #diff = new_zy - zy
        #diff = mx.sym.expand_dims(diff, 1)
        #gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
        #body = mx.sym.broadcast_mul(gt_one_hot, diff)
        #fc7 = fc7+body
    elif args.loss_type==4:
        # Additive angular margin: the true-class logit becomes s*cos(theta+m),
        # computed as cos_t*cos_m - sin_t*sin_m.
        s = args.margin_s
        m = args.margin_m
        assert s>0.0
        assert m>=0.0
        assert m<(math.pi/2)
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy/s
        if args.margin_verbose>0:
            margin_symbols.append(mx.symbol.mean(cos_t))
        if args.output_c2c==0:
            cos_m = math.cos(m)
            sin_m = math.sin(m)
            mm = math.sin(math.pi-m)*m
            #threshold = 0.0
            threshold = math.cos(math.pi-m)
            # cond is non-zero where the margin is applied. With easy_margin
            # that is cos_t > 0; otherwise it is cos_t > cos(pi-m), i.e.
            # theta + m < pi.
            if args.easy_margin:
                cond = mx.symbol.Activation(data=cos_t, act_type='relu')
            else:
                cond_v = cos_t - threshold
                cond = mx.symbol.Activation(data=cond_v, act_type='relu')
            body = cos_t*cos_t
            body = 1.0-body
            sin_t = mx.sym.sqrt(body)
            new_zy = cos_t*cos_m
            b = sin_t*sin_m
            new_zy = new_zy - b
            new_zy = new_zy*s
            # Fallback logit where the margin is not applied.
            if args.easy_margin:
                zy_keep = zy
            else:
                zy_keep = zy - s*mm
            new_zy = mx.sym.where(cond, new_zy, zy_keep)
        else:
            #set c2c as cosm^2 in data.py
            # Per-sample margin: cos_m = sqrt(c2c), sin_m = sqrt(1 - c2c).
            cos_m = mx.sym.sqrt(c2c_label)
            sin_m = 1.0-c2c_label
            sin_m = mx.sym.sqrt(sin_m)
            body = cos_t*cos_t
            body = 1.0-body
            sin_t = mx.sym.sqrt(body)
            new_zy = cos_t*cos_m
            b = sin_t*sin_m
            new_zy = new_zy - b
            new_zy = new_zy*s

        if args.margin_verbose>0:
            new_cos_t = new_zy/s
            margin_symbols.append(mx.symbol.mean(new_cos_t))
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7+body
    elif args.loss_type==5:
        # Combined margin: the true-class logit becomes s*(cos(a*theta + m) - b).
        # If only b is set (a == 1, m == 0), s*b is subtracted directly.
        s = args.margin_s
        m = args.margin_m
        assert s>0.0
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
        if args.margin_a!=1.0 or args.margin_m!=0.0 or args.margin_b!=0.0:
            if args.margin_a==1.0 and args.margin_m==0.0:
                s_m = s*args.margin_b
                gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = s_m, off_value = 0.0)
                fc7 = fc7-gt_one_hot
            else:
                zy = mx.sym.pick(fc7, gt_label, axis=1)
                cos_t = zy/s
                t = mx.sym.arccos(cos_t)
                # NOTE(review): with the parse_args default --margin-a 0.0, t is
                # zeroed here and the true-class logit no longer depends on the
                # input; pass --margin-a 1.0 for the usual combined margin.
                if args.margin_a!=1.0:
                    t = t*args.margin_a
                if args.margin_m>0.0:
                    t = t+args.margin_m
                body = mx.sym.cos(t)
                if args.margin_b>0.0:
                    body = body - args.margin_b
                new_zy = body*s
                diff = new_zy - zy
                diff = mx.sym.expand_dims(diff, 1)
                gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
                body = mx.sym.broadcast_mul(gt_one_hot, diff)
                fc7 = fc7+body
    elif args.loss_type==6:
        # Adaptive additive angular margin: theta += r*m, where
        # r = max(min_theta_in_batch / theta, margin_a), and r = 0 for samples
        # whose theta > 1.0 rad.
        s = args.margin_s
        m = args.margin_m
        assert s>0.0
        assert m>=0.0
        assert m<(math.pi/2)
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy/s
        t = mx.sym.arccos(cos_t)
        if args.margin_verbose>0:
            margin_symbols.append(mx.symbol.mean(t))
        t_min = mx.sym.min(t)
        ta = mx.sym.broadcast_div(t_min, t)
        a1 = args.margin_a
        # relu(ta - a1) + a1 == max(ta, a1)
        r1 = ta-a1
        r1 = mx.symbol.Activation(data=r1, act_type='relu')
        r1 = r1+a1
        r2 = mx.symbol.zeros(shape=(args.per_batch_size,))
        cond = t-1.0
        cond = mx.symbol.Activation(data=cond, act_type='relu')
        r = mx.sym.where(cond, r2, r1)
        var_m = r*m
        t = t+var_m
        body = mx.sym.cos(t)
        new_zy = body*s
        if args.margin_verbose>0:
            #new_cos_t = new_zy/s
            #margin_symbols.append(mx.symbol.mean(new_cos_t))
            margin_symbols.append(mx.symbol.mean(t))
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7+body
    elif args.loss_type==7:
        # Random additive angular margin: one value drawn from
        # U[margin_b, margin_m) (shape (1,)) is broadcast-added to every theta.
        s = args.margin_s
        m = args.margin_m
        assert s>0.0
        assert m>=0.0
        assert m<(math.pi/2)
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy/s
        t = mx.sym.arccos(cos_t)
        if args.margin_verbose>0:
            margin_symbols.append(mx.symbol.mean(t))
        var_m = mx.sym.random.uniform(low=args.margin_b, high=args.margin_m, shape=(1,))
        t = mx.sym.broadcast_add(t,var_m)
        body = mx.sym.cos(t)
        new_zy = body*s
        if args.margin_verbose>0:
            #new_cos_t = new_zy/s
            #margin_symbols.append(mx.symbol.mean(new_cos_t))
            margin_symbols.append(mx.symbol.mean(t))
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7+body
    elif args.loss_type==10: #marginal loss
        # mean over N*(N-1) pairs of relu(0.3 + mask_ij * (d_ij^2 - 1.2)),
        # where d_ij is the distance between normalized embeddings.
        # NOTE(review): mask comes from the 'extra' data input; presumably +1
        # for same-identity pairs and -1 otherwise (train_net starts it at
        # -1.0) -- confirm.
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
        params = [1.2, 0.3, 1.0]
        n1 = mx.sym.expand_dims(nembedding, axis=1) #N,1,C
        n2 = mx.sym.expand_dims(nembedding, axis=0) #1,N,C
        body = mx.sym.broadcast_sub(n1, n2) #N,N,C
        body = body * body
        body = mx.sym.sum(body, axis=2) # N,N
        #body = mx.sym.sqrt(body)
        body = body - params[0]
        mask = mx.sym.Variable('extra')
        body = body*mask
        body = body+params[1]
        #body = mx.sym.maximum(body, 0.0)
        body = mx.symbol.Activation(data=body, act_type='relu')
        body = mx.sym.sum(body)
        body = body/(args.per_batch_size*args.per_batch_size-args.per_batch_size)
        extra_loss = mx.symbol.MakeLoss(body, grad_scale=params[2])
    elif args.loss_type==11: #npair loss
        # Reshapes the batch into (images_per_identity, per_identities, emb)
        # and uses the first two images of each identity as the pair.
        params = [0.9, 0.2]
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
        nembedding = mx.sym.transpose(nembedding)
        nembedding = mx.symbol.reshape(nembedding, (args.emb_size, args.per_identities, args.images_per_identity))
        nembedding = mx.sym.transpose(nembedding, axes=(2,1,0)) #2*id*512
        #nembedding = mx.symbol.reshape(nembedding, (args.emb_size, args.images_per_identity, args.per_identities))
        #nembedding = mx.sym.transpose(nembedding, axes=(1,2,0)) #2*id*512
        n1 = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=1)
        n2 = mx.symbol.slice_axis(nembedding, axis=0, begin=1, end=2)
        #n1 = []
        #n2 = []
        #for i in xrange(args.per_identities):
        #  _n1 = mx.symbol.slice_axis(nembedding, axis=0, begin=2*i, end=2*i+1)
        #  _n2 = mx.symbol.slice_axis(nembedding, axis=0, begin=2*i+1, end=2*i+2)
        #  n1.append(_n1)
        #  n2.append(_n2)
        #n1 = mx.sym.concat(*n1, dim=0)
        #n2 = mx.sym.concat(*n2, dim=0)
        #rembeddings = mx.symbol.reshape(nembedding, (args.images_per_identity, args.per_identities, 512))
        #n1 = mx.symbol.slice_axis(rembeddings, axis=0, begin=0, end=1)
        #n2 = mx.symbol.slice_axis(rembeddings, axis=0, begin=1, end=2)
        n1 = mx.symbol.reshape(n1, (args.per_identities, args.emb_size))
        n2 = mx.symbol.reshape(n2, (args.per_identities, args.emb_size))
        cosine_matrix = mx.symbol.dot(lhs=n1, rhs=n2, transpose_b = True) #id*id, id=N of N-pair
        # NOTE(review): presumably 'extra' holds an identity-matrix mask here,
        # so fii is the diagonal (positive-pair) similarity -- confirm.
        data_extra = mx.sym.Variable('extra')
        data_extra = mx.sym.slice_axis(data_extra, axis=0, begin=0, end=args.per_identities)
        mask = cosine_matrix * data_extra
        #body = mx.sym.mean(mask)
        fii = mx.sym.sum_axis(mask, axis=1)
        # NOTE(review): fii has shape (id,). Broadcasting it against the
        # (id, id) matrix may subtract f_jj from column j rather than f_ii from
        # row i -- verify against the N-pair formulation.
        fij_fii = mx.sym.broadcast_sub(cosine_matrix, fii)
        fij_fii = mx.sym.exp(fij_fii)
        row = mx.sym.sum_axis(fij_fii, axis=1)
        row = mx.sym.log(row)
        body = mx.sym.mean(row)
        extra_loss = mx.sym.MakeLoss(body)
    elif args.loss_type==12: #triplet loss
        # The batch is laid out as [anchors | positives | negatives], each a
        # third of per_batch_size.
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
        anchor = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=args.per_batch_size//3)
        positive = mx.symbol.slice_axis(nembedding, axis=0, begin=args.per_batch_size//3, end=2*args.per_batch_size//3)
        negative = mx.symbol.slice_axis(nembedding, axis=0, begin=2*args.per_batch_size//3, end=args.per_batch_size)
        ap = anchor - positive
        an = anchor - negative
        ap = ap*ap
        an = an*an
        ap = mx.symbol.sum(ap, axis=1, keepdims=1) #(T,1)
        an = mx.symbol.sum(an, axis=1, keepdims=1) #(T,1)
        triplet_loss = mx.symbol.Activation(data = (ap-an+args.triplet_alpha), act_type='relu')
        triplet_loss = mx.symbol.mean(triplet_loss)
        #triplet_loss = mx.symbol.sum(triplet_loss)/(args.per_batch_size//3)
        extra_loss = mx.symbol.MakeLoss(triplet_loss)
    elif args.loss_type==13: #triplet loss with angular margin
        # Same batch layout as loss type 12, but compares angles:
        # relu(theta_ap - theta_an + margin_m).
        # sin_m / cos_m are only used by the commented-out variant below.
        m = args.margin_m
        sin_m = math.sin(m)
        cos_m = math.cos(m)
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
        anchor = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=args.per_batch_size//3)
        positive = mx.symbol.slice_axis(nembedding, axis=0, begin=args.per_batch_size//3, end=2*args.per_batch_size//3)
        negative = mx.symbol.slice_axis(nembedding, axis=0, begin=2*args.per_batch_size//3, end=args.per_batch_size)
        ap = anchor * positive
        an = anchor * negative
        ap = mx.symbol.sum(ap, axis=1, keepdims=1) #(T,1)
        an = mx.symbol.sum(an, axis=1, keepdims=1) #(T,1)
        ap = mx.symbol.arccos(ap)
        an = mx.symbol.arccos(an)
        triplet_loss = mx.symbol.Activation(data = (ap-an+args.margin_m), act_type='relu')

        #body = ap*ap
        #body = 1.0-body
        #body = mx.symbol.sqrt(body)
        #body = body*sin_m
        #ap = ap*cos_m
        #ap = ap-body
        #triplet_loss = mx.symbol.Activation(data = (an-ap), act_type='relu')

        triplet_loss = mx.symbol.mean(triplet_loss)
        extra_loss = mx.symbol.MakeLoss(triplet_loss)
    elif args.loss_type==9: #coco loss
        # One L2-normalized centroid per identity, taken from each run of
        # images_per_identity consecutive samples. Logits are the scaled,
        # normalized embeddings dotted with the centroids.
        centroids = []
        for i in xrange(args.per_identities):
            xs = mx.symbol.slice_axis(embedding, axis=0, begin=i*args.images_per_identity, end=(i+1)*args.images_per_identity)
            mean = mx.symbol.mean(xs, axis=0, keepdims=True)
            mean = mx.symbol.L2Normalization(mean, mode='instance')
            centroids.append(mean)
        centroids = mx.symbol.concat(*centroids, dim=0)
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*args.coco_scale
        fc7 = mx.symbol.dot(nembedding, centroids, transpose_b = True) #(batchsize, per_identities)
        #extra_loss = mx.symbol.softmax_cross_entropy(fc7, gt_label, name='softmax_ce')/args.per_batch_size
        #extra_loss = mx.symbol.BlockGrad(extra_loss)
    else:
        # Any other loss type: LSoftmax with the embedding scaled by 5 and the
        # normalized weights scaled by 2.
        #embedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*float(args.loss_type)
        embedding = embedding * 5
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance') * 2
        fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes,
                              weight = _weight,
                              beta=args.beta, margin=args.margin, scale=args.scale,
                              beta_min=args.beta_min, verbose=100, name='fc7')

        #fc7 = mx.sym.Custom(data=embedding, label=gt_label, weight=_weight, num_hidden=args.num_classes,
        #                       beta=args.beta, margin=args.margin, scale=args.scale,
        #                       op_type='ASoftmax', name='fc7')
    if args.loss_type<=1 and args.incay>0.0:
        # Feature incay: mean(correct / ||x||^2) * incay, where correct is 1
        # for samples whose argmax prediction equals the label. This pushes up
        # the norm of correctly classified features.
        params = [1.e-10]
        sel = mx.symbol.argmax(data = fc7, axis=1)
        sel = (sel==gt_label)
        norm = embedding*embedding
        norm = mx.symbol.sum(norm, axis=1)
        norm = norm+params[0]
        feature_incay = sel/norm
        feature_incay = mx.symbol.mean(feature_incay) * args.incay
        extra_loss = mx.symbol.MakeLoss(feature_incay)
    #out = softmax
    #l2_embedding = mx.symbol.L2Normalization(embedding)

    #ce = mx.symbol.softmax_cross_entropy(fc7, gt_label, name='softmax_ce')/args.per_batch_size
    #out = mx.symbol.Group([mx.symbol.BlockGrad(embedding), softmax, mx.symbol.BlockGrad(ce)])
    out_list = [mx.symbol.BlockGrad(embedding)]
    softmax = None
    if args.loss_type<10:
        softmax = mx.symbol.SoftmaxOutput(data=fc7, label = gt_label, name='softmax', normalization='valid')
        out_list.append(softmax)
        if args.logits_verbose>0:
            # Per-sample softmax probability of the true class, printed by
            # AccMetric.update.
            logits = mx.symbol.softmax(data = fc7)
            logits = mx.sym.pick(logits, gt_label, axis=1)
            margin_symbols.append(logits)
            #logit_max = mx.sym.max(logits)
            #logit_min = mx.sym.min(logits)
            #margin_symbols.append(logit_max)
            #margin_symbols.append(logit_min)
    if softmax is None:
        out_list.append(mx.sym.BlockGrad(gt_label))
    if extra_loss is not None:
        out_list.append(extra_loss)
    for _sym in margin_symbols:
        _sym = mx.sym.BlockGrad(_sym)
        out_list.append(_sym)
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)

def train_net(args):
    """Set up contexts, data, symbol and optimizer, then run Module.fit.

    Also mutates args, adding ctx_num, num_layers, batch_size, image_h/w,
    num_classes, coco_scale, and (for loss types >= 9) images_per_identity and
    per_identities. These are read later by get_symbol and AccMetric.

    NOTE(review): part of this function is missing in this copy; see the
    NOTE at the `while c0.0:` line below.
    """
    ctx = []
    # NOTE(review): indexing os.environ raises KeyError when
    # CUDA_VISIBLE_DEVICES is unset, so the CPU fallback below is only reached
    # when it is set to an empty string.
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd)>0:
        for i in xrange(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx)==0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    # NOTE(review): if --prefix has no directory component, prefix_dir is ''
    # and os.makedirs('') raises.
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    # e.g. 'r50' -> 50; the first character selects the backbone family.
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size==0:
        args.per_batch_size = 128
        if args.loss_type==10:
            args.per_batch_size = 256
    args.batch_size = args.per_batch_size*args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3
    # --patch must have five '_'-separated integers; only validated here.
    ppatch = [int(x) for x in args.patch.split('_')]
    assert len(ppatch)==5

    # Initial beta handed over through the environment; it is updated each
    # batch in _batch_callback (presumably read by the LSoftmax op -- confirm).
    os.environ['BETA'] = str(args.beta)
    data_dir_list = args.data_dir.split(',')
    if args.loss_type!=12 and args.loss_type!=13:
        assert len(data_dir_list)==1
    data_dir = data_dir_list[0]
    args.use_val = False
    path_imgrec = None
    path_imglist = None
    val_rec = None
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    image_size = prop.image_size
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)

    assert(args.num_classes>0)
    print('num_classes', args.num_classes)
    # Logit scale for the COCO loss (loss type 9).
    args.coco_scale = 0.5*math.log(float(args.num_classes-1))+3

    #path_imglist = "/raid5data/dplearn/MS-Celeb-Aligned/lst2"
    path_imgrec = os.path.join(data_dir, "train.rec")
    val_rec = os.path.join(data_dir, "val.rec")
    if os.path.exists(val_rec) and args.loss_type<10:
        args.use_val = True
    else:
        val_rec = None
    #args.use_val = False

    if args.loss_type==1 and args.num_classes>20000:
        args.beta_freeze = 5000
        args.gamma = 0.06

    # Identity-grouped batches are only used by loss types >= 9.
    if args.loss_type<9:
        assert args.images_per_identity==0
    else:
        if args.images_per_identity==0:
            if args.loss_type==11:
                args.images_per_identity = 2
            elif args.loss_type==10 or args.loss_type==9:
                args.images_per_identity = 16
            elif args.loss_type==12 or args.loss_type==13:
                args.images_per_identity = 5
                assert args.per_batch_size%3==0
        assert args.images_per_identity>=2
        args.per_identities = int(args.per_batch_size/args.images_per_identity)

    print('Called with argument:', args)
    data_shape = (args.image_channel,image_size[0],image_size[1])
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    if len(args.pretrained)==0:
        arg_params = None
        aux_params = None
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
    else:
        # --pretrained is "<checkpoint prefix>,<epoch>".
        vec = args.pretrained.split(',')
        print('loading', vec)
        _, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)

    data_extra = None
    hard_mining = False
    triplet_params = None
    coco_mode = False
    if args.loss_type==10:
        hard_mining = True
        _shape = (args.batch_size, args.per_batch_size)
        data_extra = np.full(_shape, -1.0, dtype=np.float32)
        c = 0
    # NOTE(review): this copy of the file is damaged here. The text between
    # "while c" and "0.0:" is missing (it looks as if a span between a '<' and
    # a '>' was stripped), so the next line is not valid Python. The lost span
    # must have included:
    #   - the rest of the data_extra setup,
    #   - the definitions of names used further down (model, train_dataiter,
    #     val_dataiter, eval_metrics, initializer, _rescale, label_shape),
    #   - the header of the condition that picks NoiseSGD (presumably
    #     `if args.noise_sgd > 0.0:`, given the --noise-sgd option).
    # Restore it from version control before running this script.
    while c0.0:
        print('use noise sgd')
        opt = NoiseSGD(scale = args.noise_sgd, learning_rate=base_lr, momentum=base_mom, wd=base_wd, rescale_grad=_rescale)
    else:
        opt = optimizer.SGD(learning_rate=base_lr, momentum=base_mom, wd=base_wd, rescale_grad=_rescale)
    # Speedometer logging frequency (in batches).
    som = 20
    if args.loss_type==12 or args.loss_type==13:
        som = 2
    _cb = mx.callback.Speedometer(args.batch_size, som)

    # Load each verification set named in --target that exists as
    # <data_dir>/<name>.bin.
    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
        path = os.path.join(data_dir,name+".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)



    def ver_test(nbatch):
        """Run every loaded verification set; return the flip accuracies."""
        results = []
        for i in xrange(len(ver_list)):
            # NOTE(review): 10 is presumably the number of cross-validation
            # folds -- confirm in eval/verification.py.
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(ver_list[i], model, args.batch_size, 10, data_extra, label_shape)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results


    def val_test():
        """Run the validation iterator through the model and print accuracy."""
        acc = AccMetric()
        val_metric = mx.metric.create(acc)
        val_metric.reset()
        val_dataiter.reset()
        for i, eval_batch in enumerate(val_dataiter):
            model.forward(eval_batch, is_train=False)
            model.update_metric(val_metric, eval_batch.label)
        acc_value = val_metric.get_name_value()[0][1]
        print('VACC: %f'%(acc_value))


    # [best score on the first target, best score on the last target].
    highest_acc = [0.0, 0.0]  #lfw and target
    #for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)
    # One-element lists so the nested callback can mutate them.
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps)==0:
        # Default schedule, expressed for batch size 512 and rescaled to the
        # actual total batch size.
        lr_steps = [40000, 60000, 80000]
        if args.loss_type>=1 and args.loss_type<=7:
            lr_steps = [100000, 140000, 160000]
        p = 512.0/args.batch_size
        for l in xrange(len(lr_steps)):
            lr_steps[l] = int(lr_steps[l]*p)
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)
    def _batch_callback(param):
        """Per-batch hook passed to Module.fit.

        Handles LR decay, speed logging, periodic verification and
        checkpointing, the beta schedule, and the --max-steps exit.
        """
        #global global_step
        global_step[0]+=1
        mbatch = global_step[0]
        # LR steps are counted after the beta-freeze period.
        for _lr in lr_steps:
            if mbatch==args.beta_freeze+_lr:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break

        _cb(param)
        if mbatch%1000==0:
            print('lr-batch-epoch:',opt.lr,param.nbatch,param.epoch)

        if mbatch>=0 and mbatch%args.verbose==0:
            acc_list = ver_test(mbatch)
            save_step[0]+=1
            msave = save_step[0]
            do_save = False
            # Save when the first target sets a new best >= 0.998, or when the
            # last target matches or beats its best while the first is >= 0.99.
            if len(acc_list)>0:
                lfw_score = acc_list[0]
                if lfw_score>highest_acc[0]:
                    highest_acc[0] = lfw_score
                    if lfw_score>=0.998:
                        do_save = True
                if acc_list[-1]>=highest_acc[-1]:
                    highest_acc[-1] = acc_list[-1]
                    if lfw_score>=0.99:
                        do_save = True
            # --ckpt 0 never saves; --ckpt >= 2 always saves.
            if args.ckpt==0:
                do_save = False
            elif args.ckpt>1:
                do_save = True
            #for i in xrange(len(acc_list)):
            #  acc = acc_list[i]
            #  if acc>=highest_acc[i]:
            #    highest_acc[i] = acc
            #    if lfw_score>=0.99:
            #      do_save = True
            #if args.loss_type==1 and mbatch>lr_steps[-1] and mbatch%10000==0:
            #  do_save = True
            if do_save:
                print('saving', msave)
                if val_dataiter is not None:
                    val_test()
                arg, aux = model.get_params()
                mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
                #if acc>=highest_acc[0]:
                #  lfw_npy = "%s-lfw-%04d" % (prefix, msave)
                #  X = np.concatenate(embeddings_list, axis=0)
                #  print('saving lfw npy', X.shape)
                #  np.save(lfw_npy, X)
            print('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[-1]))
        # beta = max(beta_min, beta * (1 + gamma*steps_after_freeze) ** -power),
        # held at --beta during the freeze period.
        if mbatch<=args.beta_freeze:
            _beta = args.beta
        else:
            move = max(0, mbatch-args.beta_freeze)
            _beta = max(args.beta_min, args.beta*math.pow(1+args.gamma*move, -1.0*args.power))
        #print('beta', _beta)
        os.environ['BETA'] = str(_beta)
        if args.max_steps>0 and mbatch>args.max_steps:
            sys.exit(0)

    #epoch_cb = mx.callback.do_checkpoint(prefix, 1)
    epoch_cb = None



    #def _epoch_callback(epoch, sym, arg_params, aux_params):
    #  print('epoch-end', epoch)

    model.fit(train_dataiter,
        begin_epoch        = begin_epoch,
        num_epoch          = end_epoch,
        eval_data          = val_dataiter,
        eval_metric        = eval_metrics,
        kvstore            = 'device',
        optimizer          = opt,
        #optimizer_params   = optimizer_params,
        initializer        = initializer,
        arg_params         = arg_params,
        aux_params         = aux_params,
        allow_missing      = True,
        batch_end_callback = _batch_callback,
        epoch_end_callback = epoch_cb )

def main():
    """Parse arguments into the module-level `args` and start training."""
    #time.sleep(3600*6.5)
    global args
    args = parse_args()
    train_net(args)

if __name__ == '__main__':
    main()