# insightface/src/train_softmax.py
# (web-view scrape header removed: "Files / 668 lines / 25 KiB / Python /
#  Raw Normal View History" and a VCS timestamp "2017-11-14 15:10:51 +08:00"
#  were page furniture, not source code.)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import math
import random
import logging
import numpy as np
from data import FaceIter
from data import FaceImageIter2
from data import FaceImageIter4
from data import FaceImageIter5
import mxnet as mx
from mxnet import ndarray as nd
import argparse
import mxnet.optimizer as optimizer
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'common'))
import spherenet
import marginalnet
import inceptions
import xception
import lfw
import sklearn
from sklearn.decomposition import PCA
#from center_loss import *
# 2017-11-16 14:29:48 +08:00  (VCS timestamp from web scrape, commented out)
#import resnet_dcn
#import asoftmax
# 2017-11-14 15:10:51 +08:00  (VCS timestamp from web scrape, commented out)
logger = logging.getLogger()
logger.setLevel(logging.INFO)
class AccMetric(mx.metric.EvalMetric):
    """Accuracy metric for the grouped training symbol.

    The symbol built by ``get_symbol`` outputs
    ``[embedding, softmax, (extra_loss)]``; only ``preds[1]`` — the softmax
    probabilities — is compared against the integer class labels.
    """

    def __init__(self):
        self.axis = 1  # class axis of the softmax output
        super(AccMetric, self).__init__(
            'acc', axis=self.axis,
            output_names=None, label_names=None)
        self.losses = []

    def update(self, labels, preds):
        """Accumulate correct-prediction counts from one batch.

        Parameters
        ----------
        labels : list of NDArray
            Ground-truth class indices.
        preds : list of NDArray
            Network outputs; index 1 must be the softmax output.
        """
        preds = [preds[1]]  # use softmax output only
        for label, pred_label in zip(labels, preds):
            # Shapes differ when pred_label still holds per-class
            # probabilities; reduce to predicted class indices.
            if pred_label.shape != label.shape:
                pred_label = mx.ndarray.argmax(pred_label, axis=self.axis)
            pred_label = pred_label.asnumpy().astype('int32').flatten()
            label = label.asnumpy().astype('int32').flatten()
            assert label.shape == pred_label.shape
            self.sum_metric += (pred_label.flat == label.flat).sum()
            self.num_inst += len(pred_label.flat)
def parse_args():
    """Parse command-line arguments for face-network training.

    Returns
    -------
    argparse.Namespace
        Training configuration (model prefix, network choice, optimizer
        hyper-parameters, margin-loss settings, data-patch spec, ...).
    """
    parser = argparse.ArgumentParser(description='Train face network')
    # general
    parser.add_argument('--prefix', default='../model/spherefacei',
                        help='directory to save model.')
    parser.add_argument('--pretrained', default='../model/resnet-152',
                        help='')
    # first letter selects the backbone (s=spherenet, m=marginalnet,
    # i=inception, x=xception, r=resnet); the rest is the layer count.
    parser.add_argument('--network', default='s20',
                        help='')
    parser.add_argument('--load-epoch', type=int, default=0,
                        help='load epoch.')
    parser.add_argument('--end-epoch', type=int, default=1000,
                        help='training epoch size.')
    parser.add_argument('--retrain', action='store_true', default=False,
                        help='true means continue training.')
    parser.add_argument('--lr', type=float, default=0.1,
                        help='')
    parser.add_argument('--images-per-identity', type=int, default=16,
                        help='')
    parser.add_argument('--embedding-dim', type=int, default=512,
                        help='')
    parser.add_argument('--per-batch-size', type=int, default=0,
                        help='')
    # LSoftmax margin and its beta-annealing schedule.
    parser.add_argument('--margin', type=int, default=4,
                        help='')
    parser.add_argument('--beta', type=float, default=1000.,
                        help='')
    parser.add_argument('--beta-min', type=float, default=5.,
                        help='')
    parser.add_argument('--beta-freeze', type=int, default=0,
                        help='')
    parser.add_argument('--gamma', type=float, default=0.12,
                        help='')
    parser.add_argument('--power', type=float, default=1.0,
                        help='')
    parser.add_argument('--scale', type=float, default=0.9993,
                        help='')
    parser.add_argument('--verbose', type=int, default=1000,
                        help='')
    parser.add_argument('--loss-type', type=int, default=1,
                        help='')
    parser.add_argument('--incay', action='store_true', default=False,
                        help='feature incay')
    parser.add_argument('--use-deformable', type=int, default=0,
                        help='')
    # patch spec "x1_y1_x2_y2_flag" parsed later by train_net.
    parser.add_argument('--patch', type=str, default='0_0_96_112_0',
                        help='')
    parser.add_argument('--lr-steps', type=str, default='',
                        help='')
    args = parser.parse_args()
    return args
def get_symbol(args, arg_params, aux_params):
    """Build the training symbol: backbone embedding plus loss head.

    The head is selected by ``args.loss_type``:
      0  : plain FullyConnected + softmax
      1  : LSoftmax (large-margin) with L2-normalized class weights
      10 : FC softmax + batch-wise pairwise penalty (mask fed via 'extra')
      11 : FC softmax + per-identity pairwise penalty
      else: scaled-embedding LSoftmax variant

    Parameters
    ----------
    args : argparse.Namespace
        Must carry num_classes, num_layers, image_channel, loss_type, etc.
    arg_params, aux_params : dict or None
        Pretrained parameters; arg_params is passed through as ``new_args``
        only when ``args.retrain`` is set.

    Returns
    -------
    (out, new_args, aux_params)
        ``out`` groups the gradient-blocked embedding, the softmax output
        and, when present, an extra MakeLoss term.
    """
    if args.retrain:
        new_args = arg_params
    else:
        new_args = None
    data_shape = (args.image_channel, 112, 96)
    image_shape = ",".join([str(x) for x in data_shape])
    # Backbone selection keys off the first letter of --network.
    if args.network[0] == 's':
        embedding = spherenet.get_symbol(512, args.num_layers)
    elif args.network[0] == 'm':
        print('init marginal', args.num_layers)
        embedding = marginalnet.get_symbol(512, args.num_layers)
    elif args.network[0] == 'i':
        print('init inception', args.num_layers)
        embedding, _ = inceptions.get_symbol_irv2(512)
    elif args.network[0] == 'x':
        print('init xception', args.num_layers)
        embedding, _ = xception.get_xception_symbol(512)
    else:
        # NOTE(review): resnet_dcn's import is commented out at the top of
        # the file, so taking this branch raises NameError. Re-enable the
        # import before using an 'r*' network.
        print('init resnet', args.num_layers)
        _, _, embedding, _ = resnet_dcn.get_symbol(512, args.num_layers)
    gt_label = mx.symbol.Variable('softmax_label')
    assert args.loss_type >= 0
    extra_loss = None
    if args.loss_type == 0:
        # Plain softmax classification head.
        _weight = mx.symbol.Variable('fc7_weight')
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias,
                                    num_hidden=args.num_classes, name='fc7')
    elif args.loss_type == 1:
        # Large-margin softmax with L2-normalized class-weight rows.
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, 512), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes,
                              weight=_weight,
                              beta=args.beta, margin=args.margin, scale=args.scale,
                              beta_min=args.beta_min, verbose=100, name='fc7')
    elif args.loss_type == 10:
        # Softmax plus a hinge penalty over all pairwise squared distances
        # of L2-normalized embeddings in the batch; the 'extra' variable
        # supplies a mask produced by the FaceImageIter4 data iterator.
        _weight = mx.symbol.Variable('fc7_weight')
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias,
                                    num_hidden=args.num_classes, name='fc7')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
        params = [1.2, 0.3, 1.0]  # [distance offset, post-mask bias, grad scale]
        n1 = mx.sym.expand_dims(nembedding, axis=1)
        n2 = mx.sym.expand_dims(nembedding, axis=0)
        body = mx.sym.broadcast_sub(n1, n2)  # (N, N, C) pairwise differences
        body = body * body
        body = mx.sym.sum(body, axis=2)  # (N, N) squared distances
        body = body - params[0]
        mask = mx.sym.Variable('extra')
        body = body * mask
        body = body + params[1]
        body = mx.symbol.Activation(data=body, act_type='relu')
        body = mx.sym.sum(body)
        # Average over the N*(N-1) off-diagonal pairs.
        body = body / (args.per_batch_size * args.per_batch_size - args.per_batch_size)
        extra_loss = mx.symbol.MakeLoss(body, grad_scale=params[2])
    elif args.loss_type == 11:
        # Softmax plus a hinge penalty on pairwise distances within the
        # first images_per_identity samples (same-identity group).
        _weight = mx.symbol.Variable('fc7_weight')
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias,
                                    num_hidden=args.num_classes, name='fc7')
        params = [0.9, 0.2]  # [distance offset, grad scale]
        nembedding = mx.symbol.slice_axis(embedding, axis=0, begin=0, end=args.images_per_identity)
        nembedding = mx.symbol.L2Normalization(nembedding, mode='instance', name='fc1n')
        n1 = mx.sym.expand_dims(nembedding, axis=1)
        n2 = mx.sym.expand_dims(nembedding, axis=0)
        body = mx.sym.broadcast_sub(n1, n2)  # (N, N, C)
        body = body * body
        body = mx.sym.sum(body, axis=2)  # (N, N)
        body = body - params[0]
        body = mx.symbol.Activation(data=body, act_type='relu')
        body = mx.sym.sum(body)
        n = args.images_per_identity
        body = body / (n * n - n)
        extra_loss = mx.symbol.MakeLoss(body, grad_scale=params[1])
    else:
        # Experimental variant: fixed scaling of embedding and normalized
        # weights before LSoftmax.
        embedding = embedding * 5
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, 512), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance') * 2
        fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes,
                              weight=_weight,
                              beta=args.beta, margin=args.margin, scale=args.scale,
                              beta_min=args.beta_min, verbose=100, name='fc7')
    if args.loss_type >= args.rescale_threshold:
        softmax = mx.symbol.SoftmaxOutput(data=fc7, label=gt_label, name='softmax',
                                          normalization='valid')
    else:
        softmax = mx.symbol.SoftmaxOutput(data=fc7, label=gt_label, name='softmax')
    if args.loss_type <= 1 and args.incay:
        # Feature-incay regularizer (--incay): reward larger embedding norms
        # on correctly classified samples (mean of correct/||f||^2).
        params = [1.e-10, 0.01]  # [epsilon to avoid div-by-zero, loss weight]
        sel = mx.symbol.argmax(data=fc7, axis=1)
        sel = (sel == gt_label)
        norm = embedding * embedding
        norm = mx.symbol.sum(norm, axis=1)
        norm += params[0]
        feature_incay = sel / norm
        feature_incay = mx.symbol.mean(feature_incay) * params[1]
        extra_loss = mx.symbol.MakeLoss(feature_incay)
    if extra_loss is not None:
        out = mx.symbol.Group([mx.symbol.BlockGrad(embedding), softmax, extra_loss])
    else:
        out = mx.symbol.Group([mx.symbol.BlockGrad(embedding), softmax])
    return (out, new_args, aux_params)
def train_net(args):
    """Train the face-recognition network described by ``args``.

    Sets up contexts, data iterators, the symbol/module, the optimizer,
    loads the LFW verification set into memory, and runs ``model.fit`` with
    a batch callback that decays the LR, anneals BETA (via the environment
    variable the LSoftmax op reads), runs periodic LFW evaluation and saves
    checkpoints.

    Side effects: mutates ``args`` (ctx_num, num_layers, batch sizes,
    num_classes, ...), sets ``os.environ['BETA']``, reads dataset paths
    hard-coded below, writes checkpoints under a prefix derived from
    ``args.prefix``.
    """
    ctx = []
    # Fix: original indexed os.environ directly and raised KeyError when
    # CUDA_VISIBLE_DEVICES was unset; default to CPU instead.
    cvd = os.environ.get('CUDA_VISIBLE_DEVICES', '').strip()
    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    prefix = "%s-%s-p%s" % (args.prefix, args.network, args.patch)
    end_epoch = args.end_epoch
    pretrained = args.pretrained
    load_epoch = args.load_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size == 0:
        # NOTE(review): source indentation was lost; this nesting of the
        # batch-size heuristics under "per_batch_size==0" is the reading
        # consistent with "0 means auto" — confirm against upstream history.
        args.per_batch_size = 128
        if args.network[0] == 'r':
            args.per_batch_size = 128
        else:
            if args.num_layers >= 64:
                args.per_batch_size = 120
        if args.ctx_num == 2:
            args.per_batch_size *= 2
        elif args.ctx_num == 3:
            args.per_batch_size = 170
        if args.network[0] == 'm':
            args.per_batch_size = 128
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3
    # patch spec: [x1, y1, x2, y2, flag]; odd flag => grayscale decode below.
    ppatch = [int(x) for x in args.patch.split('_')]
    assert len(ppatch) == 5
    # The LSoftmax custom op reads its annealed beta from this env var.
    os.environ['BETA'] = str(args.beta)
    args.use_val = False
    path_imgrec = None
    val_rec = None
    val_path = None
    # Dataset selection was done by editing these assignments in place;
    # only the last ones (faces_normed, 82395 classes) take effect.
    path_imglist = "/raid5data/dplearn/faceinsight_align_webface.lst.new"
    args.num_classes = 10572  # webface
    path_imglist = "/raid5data/dplearn/MS-Celeb-Aligned/lst2"
    args.num_classes = 81017
    path_imgrec = "/opt/jiaguo/faces_celeb/train.rec"
    path_imglist = "/raid5data/dplearn/faces_normed/train.lst"
    args.num_classes = 82395
    args.use_val = False
    val_path = "/raid5data/dplearn/faces_normed/val.lst"
    path_imgrec = "/opt/jiaguo/faces_normed/train.rec"
    val_rec = "/opt/jiaguo/faces_normed/val.rec"
    if args.loss_type == 1 and args.num_classes > 40000:
        args.beta_freeze = 5000
        args.gamma = 0.06
    print('Called with argument:', args)
    data_shape = (args.image_channel, 112, 96)
    mean = [127.5, 127.5, 127.5]
    if args.use_val:
        val_dataiter = FaceImageIter2(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=val_rec,
            path_imglist=val_path,
            shuffle=False,
            exclude_lfw=False,
            rand_mirror=False,
            mean=mean,
            patch=ppatch,
        )
    else:
        val_dataiter = None
    begin_epoch = 0
    base_lr = args.lr
    base_wd = 0.0005
    base_mom = 0.9
    if not args.retrain:
        # Fresh training: build the symbol, then let the network module
        # provide its own weight initialization.
        arg_params = None
        aux_params = None
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
        data_shape_dict = {'data': (args.batch_size,) + data_shape,
                           'softmax_label': (args.batch_size,)}
        if args.network[0] == 's':
            arg_params, aux_params = spherenet.init_weights(sym, data_shape_dict, args.num_layers)
        elif args.network[0] == 'm':
            arg_params, aux_params = marginalnet.init_weights(sym, data_shape_dict, args.num_layers)
    else:
        # Continue training from a saved checkpoint.
        _, arg_params, aux_params = mx.model.load_checkpoint(pretrained, load_epoch)
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
    if args.loss_type != 10:
        model = mx.mod.Module(
            context=ctx,
            symbol=sym,
        )
    else:
        # loss_type 10 feeds an extra mask input alongside the images.
        data_names = ('data', 'extra')
        model = mx.mod.Module(
            context=ctx,
            symbol=sym,
            data_names=data_names,
        )
    if args.loss_type <= 9:
        train_dataiter = FaceImageIter2(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=path_imgrec,
            path_imglist=path_imglist,
            shuffle=True,
            exclude_lfw=False,
            rand_mirror=True,
            brightness=0.4,
            contrast=0.4,
            saturation=0.4,
            pca_noise=0.1,
            mean=mean,
            patch=ppatch,
        )
    elif args.loss_type == 10:
        train_dataiter = FaceImageIter4(
            batch_size=args.batch_size,
            ctx_num=args.ctx_num,
            images_per_identity=args.images_per_identity,
            data_shape=data_shape,
            path_imglist=path_imglist,
            shuffle=True,
            rand_mirror=True,
            exclude_lfw=False,
            mean=mean,
            patch=ppatch,
            use_extra=True,
            model=model,
        )
    elif args.loss_type == 11:
        train_dataiter = FaceImageIter5(
            batch_size=args.batch_size,
            ctx_num=args.ctx_num,
            images_per_identity=args.images_per_identity,
            data_shape=data_shape,
            path_imglist=path_imglist,
            shuffle=True,
            rand_mirror=True,
            exclude_lfw=False,
            mean=mean,
            patch=ppatch,
        )
    _acc = AccMetric()
    eval_metrics = [mx.metric.create(_acc)]
    # Initializer choice follows the backbone family's convention.
    if args.network[0] == 'r':
        initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2)  # resnet style
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2)  # inception
    else:
        initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)
    _rescale = 1.0 / args.batch_size
    if args.loss_type >= args.rescale_threshold:
        # normalization='valid' in SoftmaxOutput already averages per device.
        _rescale = 1.0 / args.ctx_num
    opt = optimizer.SGD(learning_rate=base_lr, momentum=base_mom, wd=base_wd,
                        rescale_grad=_rescale)
    _cb = mx.callback.Speedometer(args.batch_size, 10)
    # ---- Load the LFW verification set (original + horizontally flipped). ----
    lfw_dir = '/raid5data/dplearn/lfw_mtcnn2'
    lfw_pairs = lfw.read_pairs(os.path.join(lfw_dir, 'pairs.txt'))
    lfw_paths, issame_list = lfw.get_paths(lfw_dir, lfw_pairs, 'jpg')
    imgs = []
    lfw_data_list = []
    for flip in [0, 1]:
        lfw_data = nd.empty((len(lfw_paths), args.image_channel, 112, 96))
        lfw_data_list.append(lfw_data)
    i = 0
    for path in lfw_paths:
        with open(path, 'rb') as fin:
            _bin = fin.read()
            if ppatch[4] % 2 == 1:
                # Odd patch flag: decode grayscale, broadcast to 3 channels.
                img = mx.image.imdecode(_bin, flag=0)
                if img.shape[2] < args.image_channel:
                    img = nd.broadcast_to(img, (img.shape[0], img.shape[1], 3))
            else:
                img = mx.image.imdecode(_bin)
            img = nd.transpose(img, axes=(2, 0, 1))  # HWC -> CHW
            img = img.asnumpy()
            if mean is not None:
                # Same normalization the training iterator applies.
                img = img.astype(np.float32)
                img -= np.array(mean, dtype=np.float32).reshape(args.image_channel, 1, 1)
                img *= 0.0078125
            for flip in [0, 1]:
                _img = img.copy()
                if flip == 1:
                    for c in range(_img.shape[0]):
                        _img[c, :, :] = np.fliplr(_img[c, :, :])
                # Zero everything outside the configured patch rectangle.
                nimg = np.zeros(_img.shape, dtype=np.float32)
                nimg[:, ppatch[1]:ppatch[3], ppatch[0]:ppatch[2]] = \
                    _img[:, ppatch[1]:ppatch[3], ppatch[0]:ppatch[2]]
                lfw_data_list[flip][i][:] = nd.array(nimg)
            i += 1
            if i % 1000 == 0:
                print('loading lfw', i)
    print(lfw_data_list[0].shape)
    print(lfw_data_list[1].shape)

    def lfw_test(nbatch):
        """Run LFW verification; return (flip-averaged accuracy, embeddings)."""
        print('testing lfw..')
        embeddings_list = []
        for i in range(len(lfw_data_list)):
            lfw_data = lfw_data_list[i]
            embeddings = None
            ba = 0
            _ctx = ctx[0]
            while ba < lfw_data.shape[0]:
                bb = min(ba + args.batch_size, lfw_data.shape[0])
                _data = nd.slice_axis(lfw_data, axis=0, begin=ba, end=bb)
                _label = nd.ones((bb - ba,))  # dummy labels for inference
                db = mx.io.DataBatch(data=(_data,), label=(_label,))
                model.forward(db, is_train=False)
                net_out = model.get_outputs()
                _embeddings = net_out[0].asnumpy()
                if embeddings is None:
                    embeddings = np.zeros((lfw_data.shape[0], _embeddings.shape[1]))
                embeddings[ba:bb, :] = _embeddings
                ba = bb
            embeddings_list.append(embeddings)
        # Mean embedding norm, a rough health indicator of the features.
        _xnorm = 0.0
        _xnorm_cnt = 0
        for embed in embeddings_list:
            for i in range(embed.shape[0]):
                _em = embed[i]
                _norm = np.linalg.norm(_em)
                _xnorm += _norm
                _xnorm_cnt += 1
        _xnorm /= _xnorm_cnt
        print('[%d]XNorm: %f' % (nbatch, _xnorm))
        acc_list = []
        # Accuracy on the un-flipped embeddings.
        embeddings = embeddings_list[0].copy()
        embeddings = sklearn.preprocessing.normalize(embeddings)
        _, _, accuracy, val, val_std, far = lfw.evaluate(embeddings, issame_list, nrof_folds=10)
        acc_list.append(np.mean(accuracy))
        print('[%d]Accuracy: %1.5f+-%1.5f' % (nbatch, np.mean(accuracy), np.std(accuracy)))
        # Accuracy on the sum of original and flipped embeddings.
        embeddings = embeddings_list[0] + embeddings_list[1]
        embeddings = sklearn.preprocessing.normalize(embeddings)
        print(embeddings.shape)
        _, _, accuracy, val, val_std, far = lfw.evaluate(embeddings, issame_list, nrof_folds=10)
        acc_list.append(np.mean(accuracy))
        print('[%d]Accuracy-Flip: %1.5f+-%1.5f' % (nbatch, np.mean(accuracy), np.std(accuracy)))
        racc = acc_list[1]  # report the flip-augmented accuracy
        return racc, embeddings_list

    # Mutable one-element lists so the nested callback can rebind state.
    highest_acc = [0.0]
    last_save_acc = [0.0]
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps) == 0:
        lr_steps = [30000, 50000, 70000, 90000]
        if args.loss_type == 1:
            lr_steps = [70000, 100000]
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        """Per-batch hook: LR decay, periodic LFW eval/checkpoint, BETA anneal."""
        global_step[0] += 1
        mbatch = global_step[0]
        # Step-decay the LR at each configured boundary (shifted by beta_freeze).
        for _lr in lr_steps:
            if mbatch == args.beta_freeze + _lr:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break
        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)
        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc, embeddings_list = lfw_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            if acc >= highest_acc[0]:
                highest_acc[0] = acc
            # NOTE(review): source indentation was lost; the 0.995 check may
            # originally have been nested under the new-highest branch above.
            if acc >= 0.995:
                do_save = True
            if mbatch > lr_steps[-1] and msave % 5 == 0:
                do_save = True
            if do_save:
                print('saving', msave)
                arg, aux = model.get_params()
                mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[0]))
        # Anneal BETA (read by the LSoftmax op through the environment).
        if mbatch <= args.beta_freeze:
            _beta = args.beta
        else:
            move = max(0, mbatch - args.beta_freeze)
            _beta = max(args.beta_min,
                        args.beta * math.pow(1 + args.gamma * move, -1.0 * args.power))
        os.environ['BETA'] = str(_beta)

    epoch_cb = None
    model.fit(train_dataiter,
              begin_epoch=begin_epoch,
              num_epoch=end_epoch,
              eval_data=val_dataiter,
              eval_metric=eval_metrics,
              kvstore='device',
              optimizer=opt,
              initializer=initializer,
              arg_params=arg_params,
              aux_params=aux_params,
              allow_missing=True,
              batch_end_callback=_batch_callback,
              epoch_end_callback=epoch_cb)
def main():
    """Entry point: parse CLI arguments and launch training."""
    args = parse_args()
    train_net(args)


if __name__ == '__main__':
    main()