# Files
# insightface/recognition/train_softmax.py
# 2019-01-14 18:18:21 +08:00
#
# 355 lines
# 13 KiB
# Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import math
import random
import logging
import sklearn
import pickle
import numpy as np
from image_iter import FaceImageIter
from image_iter import FaceImageIterList
import mxnet as mx
from mxnet import ndarray as nd
import argparse
import mxnet.optimizer as optimizer
from config import config, default, generate_config
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'common'))
import face_image
sys.path.append(os.path.join(os.path.dirname(__file__), 'eval'))
import verification
sys.path.append(os.path.join(os.path.dirname(__file__), 'symbol'))
import fresnet
import fmobilefacenet
import fmobilenet
# Set the root logger's threshold to INFO.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# Parsed command-line namespace; stays None until main() assigns it.
args = None
class AccMetric(mx.metric.EvalMetric):
    """Classification accuracy metric registered under the name 'acc'.

    Each call to ``update`` compares the argmax (along axis 1) of the second
    network output, ``preds[1]``, with the first label array, ``labels[0]``,
    and adds the number of matches to ``sum_metric`` and the number of
    compared samples to ``num_inst``.
    """

    def __init__(self):
        self.axis = 1
        super(AccMetric, self).__init__(
            'acc', axis=self.axis,
            output_names=None, label_names=None)
        self.losses = []
        # Number of update() calls made on this metric instance.
        self.count = 0

    def update(self, labels, preds):
        """Accumulate hit and sample counts for one batch."""
        self.count += 1
        expected = labels[0]
        predicted = mx.nd.argmax(preds[1], axis=1)
        predicted = predicted.asnumpy().astype('int32').flatten()
        expected = expected.asnumpy()
        if expected.ndim == 2:
            # For a 2-D label array only the first column is compared.
            expected = expected[:, 0]
        expected = expected.astype('int32').flatten()
        assert expected.shape == predicted.shape
        # Both arrays are 1-D here, so a plain element-wise compare suffices.
        self.sum_metric += (predicted == expected).sum()
        self.num_inst += len(predicted)
class LossValueMetric(mx.metric.EvalMetric):
    """Metric registered as 'lossvalue' that tracks the last network output.

    Every ``update`` adds element 0 of ``preds[-1]`` to ``sum_metric`` and
    counts the batch as one instance in ``num_inst``.
    """

    def __init__(self):
        self.axis = 1
        super(LossValueMetric, self).__init__(
            'lossvalue', axis=self.axis,
            output_names=None, label_names=None)
        self.losses = []

    def update(self, labels, preds):
        """Add this batch's value from the last output; ``labels`` is unused."""
        last_output = preds[-1].asnumpy()
        self.sum_metric += last_output[0]
        self.num_inst += 1.0
def parse_args():
    """Parse the command line in two passes and return the argument namespace.

    Pass 1 reads only ``--dataset``, ``--network`` and ``--loss`` and hands
    them to ``generate_config``. Pass 2 declares the remaining options, whose
    defaults are read from ``default`` only after ``generate_config`` has
    run, and parses the full command line.

    The first pass uses a parser without ``-h/--help`` so that asking for help
    no longer exits before the remaining options are declared; the full
    parser inherits the three selector options through ``parents``.

    Returns:
        argparse.Namespace holding every option declared below.
    """
    # general
    selector = argparse.ArgumentParser(add_help=False)
    selector.add_argument('--dataset', default=default.dataset, help='dataset config')
    selector.add_argument('--network', default=default.network, help='network config')
    selector.add_argument('--loss', default=default.loss, help='loss config')
    # Options that are not declared yet are tolerated in this pass.
    selected, _ = selector.parse_known_args()
    generate_config(selected.network, selected.dataset, selected.loss)

    parser = argparse.ArgumentParser(description='Train face network', parents=[selector])
    parser.add_argument('--models-root', default=default.models_root, help='root directory to save model.')
    parser.add_argument('--pretrained', default='', help='pretrained model to load')
    parser.add_argument('--ckpt', type=int, default=default.ckpt, help='checkpoint saving option. 0: discard saving. 1: save when necessary. 2: always save')
    parser.add_argument('--verbose', type=int, default=default.verbose, help='do verification testing and model saving every verbose batches')
    parser.add_argument('--max-steps', type=int, default=0, help='max training batches')
    parser.add_argument('--end-epoch', type=int, default=100000, help='training epoch size.')
    parser.add_argument('--lr', type=float, default=default.lr, help='start learning rate')
    parser.add_argument('--lr-steps', type=str, default=default.lr_steps, help='steps of lr changing')
    parser.add_argument('--wd', type=float, default=default.wd, help='weight decay')
    parser.add_argument('--mom', type=float, default=default.mom, help='momentum')
    parser.add_argument('--frequent', type=int, default=default.frequent, help='')
    parser.add_argument('--fc7-wd-mult', type=float, default=1.0, help='weight decay mult for fc7')
    parser.add_argument('--fc7-lr-mult', type=float, default=1.0, help='lr mult for fc7')
    parser.add_argument("--fc7-no-bias", default=False, action="store_true" , help="fc7 no bias flag")
    parser.add_argument('--per-batch-size', type=int, default=default.per_batch_size, help='batch size in each context')
    parser.add_argument('--rand-mirror', type=int, default=1, help='if do random mirror in training')
    parser.add_argument('--cutoff', type=int, default=0, help='cut off aug')
    parser.add_argument('--color', type=int, default=0, help='color jittering aug')
    parser.add_argument('--images-filter', type=int, default=0, help='minimum images per identity filter')
    parser.add_argument('--ce-loss', default=False, action='store_true', help='if output ce loss')
    return parser.parse_args()
def get_symbol(args):
    """Assemble the training symbol: backbone embedding, fc7 layer and outputs.

    The grouped outputs are, in order: the gradient-blocked embedding, the
    'softmax' SoftmaxOutput over fc7, and (only when ``args.ce_loss`` is set)
    a gradient-blocked cross-entropy value.

    For ``config.loss_name == 'softmax'`` fc7 is a plain fully connected
    layer. Otherwise weight and embedding are L2-normalised, the embedding is
    scaled by ``config.loss_s`` and the target-class logit s*cos(t) is
    replaced by s*(cos(m1*t + m2) - m3), using ``config.loss_m1/m2/m3``.

    NOTE(review): the source text carried no indentation; the nesting below
    was reconstructed from syntax and data flow - confirm against upstream.
    """
    # config.net_name is evaluated as a Python expression and must resolve to
    # an object exposing get_symbol(); it is configuration, not user input.
    embedding = eval(config.net_name).get_symbol()
    gt_label = mx.symbol.Variable('softmax_label')
    num_classes = config.num_classes
    fc7_weight = mx.symbol.Variable(
        "fc7_weight", shape=(num_classes, config.emb_size),
        lr_mult=args.fc7_lr_mult, wd_mult=args.fc7_wd_mult)

    if config.loss_name == 'softmax':  # softmax
        if args.fc7_no_bias:
            fc7 = mx.sym.FullyConnected(
                data=embedding, weight=fc7_weight, no_bias=True,
                num_hidden=num_classes, name='fc7')
        else:
            fc7_bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
            fc7 = mx.sym.FullyConnected(
                data=embedding, weight=fc7_weight, bias=fc7_bias,
                num_hidden=num_classes, name='fc7')
    else:
        scale = config.loss_s
        m1, m2, m3 = config.loss_m1, config.loss_m2, config.loss_m3
        fc7_weight = mx.symbol.L2Normalization(fc7_weight, mode='instance')
        scaled_embedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * scale
        fc7 = mx.sym.FullyConnected(
            data=scaled_embedding, weight=fc7_weight, no_bias=True,
            num_hidden=num_classes, name='fc7')
        if m1 == 1.0 and m2 == 0.0:
            if m3 != 0.0:
                # Only m3 is active: subtract s*m3 from the target logit.
                gt_one_hot = mx.sym.one_hot(
                    gt_label, depth=num_classes,
                    on_value=scale * m3, off_value=0.0)
                fc7 = fc7 - gt_one_hot
            # With m1 == 1, m2 == 0 and m3 == 0 no margin is applied.
        else:
            # Recover the target angle, apply the margins, and add the
            # resulting change back onto the target-class logit only.
            zy = mx.sym.pick(fc7, gt_label, axis=1)
            theta = mx.sym.arccos(zy / scale)
            if m1 != 1.0:
                theta = theta * m1
            if m2 > 0.0:
                theta = theta + m2
            margin_cos = mx.sym.cos(theta)
            if m3 > 0.0:
                margin_cos = margin_cos - m3
            delta = mx.sym.expand_dims(margin_cos * scale - zy, 1)
            gt_one_hot = mx.sym.one_hot(
                gt_label, depth=num_classes, on_value=1.0, off_value=0.0)
            fc7 = fc7 + mx.sym.broadcast_mul(gt_one_hot, delta)

    outputs = [
        mx.symbol.BlockGrad(embedding),
        mx.symbol.SoftmaxOutput(
            data=fc7, label=gt_label, name='softmax', normalization='valid'),
    ]
    if args.ce_loss:
        # Monitoring-only cross entropy: -sum(log softmax at target) divided
        # by the per-device batch size, with gradients blocked.
        log_prob = mx.symbol.log(mx.symbol.SoftmaxActivation(data=fc7))
        neg_one_hot = mx.sym.one_hot(
            gt_label, depth=num_classes, on_value=-1.0, off_value=0.0)
        ce_loss = mx.symbol.sum(log_prob * neg_one_hot) / args.per_batch_size
        outputs.append(mx.symbol.BlockGrad(ce_loss))
    return mx.symbol.Group(outputs)
def train_net(args):
    """Train the network selected by ``args``/``config`` with mx.mod.Module.

    Steps: choose devices from CUDA_VISIBLE_DEVICES, create the checkpoint
    directory, build the symbol (optionally loading pretrained parameters),
    create the training iterator, metrics, initializer and SGD optimizer,
    load the verification sets listed in ``config.val_targets`` and call
    ``model.fit`` with a per-batch callback that decays the learning rate,
    runs verification and saves checkpoints.

    Args:
        args: namespace returned by parse_args(). It is mutated: ``ctx_num``,
            ``batch_size``, ``rescale_threshold`` and ``image_channel`` are
            set on it.

    The per-batch callback calls ``sys.exit(0)`` once more than
    ``args.max_steps`` batches have run (when ``max_steps`` > 0).

    NOTE(review): the source text carried no indentation; the nesting below
    was reconstructed from syntax and data flow - confirm against upstream.
    """
    # One GPU context per entry of CUDA_VISIBLE_DEVICES. A missing variable is
    # treated like an empty one so the CPU fallback below is reachable instead
    # of failing with KeyError.
    cvd = os.environ.get('CUDA_VISIBLE_DEVICES', '').strip()
    ctx = []
    if len(cvd) > 0:
        ctx = [mx.gpu(i) for i in range(len(cvd.split(',')))]
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))

    # Checkpoints are written as <models_root>/<network>-<loss>-<dataset>/model-*
    prefix = os.path.join(args.models_root, '%s-%s-%s'%(args.network, args.loss, args.dataset), 'model')
    prefix_dir = os.path.dirname(prefix)
    print('prefix', prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)

    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    # Global batch size is the per-device size times the device count.
    args.batch_size = args.per_batch_size*args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = config.image_shape[2]
    data_dir = config.dataset_path
    image_size = config.image_shape[0:2]
    assert len(image_size)==2
    assert image_size[0]==image_size[1]
    print('image_size', image_size)
    print('num_classes', config.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")
    print('Called with argument:', args, config)
    data_shape = (args.image_channel,image_size[0],image_size[1])
    mean = None
    begin_epoch = 0

    if len(args.pretrained)==0:
        arg_params = None
        aux_params = None
        sym = get_symbol(args)
        if config.net_name=='spherenet':
            # NOTE(review): `spherenet` is not imported in the visible part of
            # this file and parse_args() declares no num_layers option, so
            # this branch looks like it would fail - confirm.
            data_shape_dict = {'data' : (args.per_batch_size,)+data_shape}
            spherenet.init_weights(sym, data_shape_dict, args.num_layers)
    else:
        # --pretrained has the form "<checkpoint prefix>,<epoch>".
        vec = args.pretrained.split(',')
        print('loading', vec)
        _, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))
        sym = get_symbol(args)

    model = mx.mod.Module(
        context = ctx,
        symbol = sym,
    )
    val_dataiter = None
    train_dataiter = FaceImageIter(
        batch_size = args.batch_size,
        data_shape = data_shape,
        path_imgrec = path_imgrec,
        shuffle = True,
        rand_mirror = args.rand_mirror,
        mean = mean,
        cutoff = args.cutoff,
        color_jittering = args.color,
        images_filter = args.images_filter,
    )

    metric1 = AccMetric()
    eval_metrics = [mx.metric.create(metric1)]
    if args.ce_loss:
        metric2 = LossValueMetric()
        eval_metrics.append( mx.metric.create(metric2) )

    if config.net_name=='fresnet' or config.net_name=='fmobilefacenet':
        initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    else:
        initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)

    # Gradient rescale factor handed to SGD: one over the device count.
    _rescale = 1.0/args.ctx_num
    opt = optimizer.SGD(learning_rate=args.lr, momentum=args.mom, wd=args.wd, rescale_grad=_rescale)
    _cb = mx.callback.Speedometer(args.batch_size, args.frequent)

    # Load every verification set whose <name>.bin exists in the dataset dir.
    ver_list = []
    ver_name_list = []
    for name in config.val_targets:
        path = os.path.join(data_dir,name+".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        """Run verification on every loaded set; return the flip accuracies."""
        results = []
        for ver_name, ver_set in zip(ver_name_list, ver_list):
            _acc1, _std1, acc2, std2, xnorm, _embeddings = verification.test(ver_set, model, args.batch_size, 10, None, None)
            print('[%s][%d]XNorm: %f' % (ver_name, nbatch, xnorm))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name, nbatch, acc2, std2))
            results.append(acc2)
        return results

    # [0]: best summed accuracy recorded on a tie of the last target,
    # [-1]: best accuracy of the last verification target.
    highest_acc = [0.0, 0.0]
    # One-element lists so the nested callback can mutate the counters
    # (no `nonlocal` on Python 2).
    global_step = [0]
    save_step = [0]
    lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        """Per-batch hook: lr decay, speed log, verification, checkpointing."""
        global_step[0]+=1
        mbatch = global_step[0]
        # Multiply the learning rate by 0.1 when a listed step is reached.
        if mbatch in lr_steps:
            opt.lr *= 0.1
            print('lr change to', opt.lr)
        _cb(param)
        if mbatch%1000==0:
            print('lr-batch-epoch:',opt.lr,param.nbatch,param.epoch)
        if mbatch>=0 and mbatch%args.verbose==0:
            acc_list = ver_test(mbatch)
            save_step[0]+=1
            msave = save_step[0]
            do_save = False
            is_highest = False
            if len(acc_list)>0:
                score = sum(acc_list)
                if acc_list[-1]>=highest_acc[-1]:
                    if acc_list[-1]>highest_acc[-1]:
                        is_highest = True
                    else:
                        # Tie on the last target: fall back to the summed score.
                        if score>=highest_acc[0]:
                            is_highest = True
                            highest_acc[0] = score
                    highest_acc[-1] = acc_list[-1]
            if is_highest:
                do_save = True
            # --ckpt: 0 never saves, 2 always saves, 3 saves on a new best but
            # always under checkpoint index 1; other values save on a new best.
            if args.ckpt==0:
                do_save = False
            elif args.ckpt==2:
                do_save = True
            elif args.ckpt==3:
                msave = 1
            if do_save:
                print('saving', msave)
                arg, aux = model.get_params()
                mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
            print('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[-1]))
        if args.max_steps>0 and mbatch>args.max_steps:
            sys.exit(0)

    epoch_cb = None
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)
    model.fit(train_dataiter,
        begin_epoch = begin_epoch,
        num_epoch = end_epoch,
        eval_data = val_dataiter,
        eval_metric = eval_metrics,
        kvstore = 'device',
        optimizer = opt,
        initializer = initializer,
        arg_params = arg_params,
        aux_params = aux_params,
        allow_missing = True,
        batch_end_callback = _batch_callback,
        epoch_end_callback = epoch_cb )
def main():
    """Script entry point: parse the command line, then start training.

    The parsed namespace is also stored in the module-level ``args``.
    """
    global args
    args = parse_args()
    train_net(args)


if __name__ == '__main__':
    main()