# Files
# insightface/src/train_softmax.py
# nttstar b3aef14030 tiny
# 2018-01-21 22:34:26 +08:00
#
# 838 lines
# 32 KiB
# Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import math
import random
import logging
import pickle
import numpy as np
from data import FaceImageIter
from data import FaceImageIterList
import mxnet as mx
from mxnet import ndarray as nd
import argparse
import mxnet.optimizer as optimizer
sys.path.append(os.path.join(os.path.dirname(__file__), 'common'))
import face_image
sys.path.append(os.path.join(os.path.dirname(__file__), 'eval'))
sys.path.append(os.path.join(os.path.dirname(__file__), 'symbols'))
import fresnet
import finception_resnet_v2
import fmobilenet
import fmobilenetv2
import fxception
import fdensenet
import fdpn
import fnasnet
#import lfw
import verification
import sklearn
sys.path.append(os.path.join(os.path.dirname(__file__), 'losses'))
import center_loss
# Configure the root logger to pass INFO-level records.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# Parsed command-line namespace. It is assigned in main() and read as a
# module-level global by AccMetric.update().
args = None
class AccMetric(mx.metric.EvalMetric):
    """Top-1 accuracy metric scored on the softmax output (``preds[1]``).

    When margin logging is enabled (loss types 2-4 with
    ``args.margin_verbose > 0``) it also prints the trailing network outputs
    every ``args.margin_verbose`` mini-batches. Reads the module-level
    ``args`` namespace, so ``args`` must be set before ``update`` is called.
    """

    def __init__(self):
        self.axis = 1
        super(AccMetric, self).__init__(
            'acc', axis=self.axis,
            output_names=None, label_names=None)
        self.losses = []
        # Number of update() calls; count // args.ctx_num is used as the
        # mini-batch index when printing margins.
        self.count = 0

    def update(self, labels, preds):
        """Accumulate correct/total counts from one set of outputs.

        labels -- list of label arrays; only the first is paired with the
                  softmax output.
        preds  -- list of network outputs; index 1 is used as the softmax
                  output and the last one or two entries as margin values.
        """
        self.count += 1

        margin_logging = 2 <= args.loss_type <= 4 and args.margin_verbose > 0
        if margin_logging and self.count % args.ctx_num == 0:
            step = self.count // args.ctx_num
            if step == 1 or step % args.margin_verbose == 0:
                n_outputs = len(preds)
                if n_outputs >= 4:
                    before = preds[-2].asnumpy()[0]
                    after = preds[-1].asnumpy()[0]
                elif n_outputs == 3:
                    before = after = preds[-1].asnumpy()[0]
                else:
                    before = after = 0.0
                print('[%d][MARGIN]%f,%f' % (step, before, after))

        # Only the softmax output takes part in the accuracy computation.
        softmax_out = preds[1]
        for truth, predicted in zip(labels, [softmax_out]):
            if predicted.shape != truth.shape:
                # Class scores -> predicted class index.
                predicted = mx.ndarray.argmax(predicted, axis=self.axis)
            predicted = predicted.asnumpy().astype('int32').flatten()
            truth = truth.asnumpy().astype('int32').flatten()
            assert truth.shape == predicted.shape
            hits = (predicted == truth).sum()
            self.sum_metric += hits
            self.num_inst += predicted.size
class LossValueMetric(mx.metric.EvalMetric):
    """Running mean of the scalar loss taken from the last network output.

    train_net() uses this metric instead of AccMetric for loss types >= 10.
    """

    def __init__(self):
        self.axis = 1
        super(LossValueMetric, self).__init__(
            'lossvalue', axis=self.axis,
            output_names=None, label_names=None)
        self.losses = []

    def update(self, labels, preds):
        """Add the loss value in ``preds[-1]`` to the running sum.

        labels -- unused.
        preds  -- list of network outputs; the first element of the last
                  output is taken as the loss value.
        """
        # The previous version also copied preds[-2] to the host on every
        # call, solely to feed a commented-out debug print; that unused
        # fetch has been dropped.
        loss = preds[-1].asnumpy()[0]
        self.sum_metric += loss
        self.num_inst += 1.0
def parse_args(argv=None):
    """Parse the training command line.

    argv -- optional list of argument strings. When None (the default, and
            what the original zero-argument call did) argparse reads
            ``sys.argv[1:]``.

    Returns the populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Train face network')

    # --- general -----------------------------------------------------------
    # data-dir: train_net() reads train.rec / val.rec / <target>.bin from it.
    parser.add_argument('--data-dir', default='',
                        help='')
    # prefix: checkpoint prefix handed to mx.model.save_checkpoint.
    parser.add_argument('--prefix', default='../model/model',
                        help='directory to save model.')
    # pretrained: 'prefix,epoch' string, split and loaded by train_net().
    parser.add_argument('--pretrained', default='',
                        help='')
    parser.add_argument('--retrain', action='store_true', default=False,
                        help='true means continue training.')
    # ckpt: 0 = never save, 1 = save on accuracy criteria, >1 = always save
    # at every verification step (see train_net's batch callback).
    parser.add_argument('--ckpt', type=int, default=1, help='')

    # --- network -----------------------------------------------------------
    # network: first character selects the architecture, the remainder is
    # parsed as the number of layers.
    parser.add_argument('--network', default='s20', help='')
    parser.add_argument('--version-se', type=int, default=0, help='')
    parser.add_argument('--version-input', type=int, default=1, help='')
    parser.add_argument('--version-output', type=str, default='E', help='')
    parser.add_argument('--version-unit', type=int, default=3, help='')

    # --- optimisation ------------------------------------------------------
    parser.add_argument('--end-epoch', type=int, default=100000,
                        help='training epoch size.')
    parser.add_argument('--lr', type=float, default=0.1,
                        help='')
    parser.add_argument('--wd', type=float, default=0.0005,
                        help='')
    parser.add_argument('--mom', type=float, default=0.9,
                        help='')
    parser.add_argument('--emb-size', type=int, default=512,
                        help='')
    parser.add_argument('--per-batch-size', type=int, default=128,
                        help='')

    # --- margin-based losses (loss types 1-4) ------------------------------
    parser.add_argument('--margin-m', type=float, default=0.35,
                        help='')
    parser.add_argument('--margin-s', type=float, default=64.0,
                        help='')
    parser.add_argument('--easy-margin', type=int, default=0,
                        help='')
    parser.add_argument('--margin-verbose', type=int, default=0,
                        help='')
    parser.add_argument('--margin', type=int, default=4,
                        help='')
    # beta / beta-min / beta-freeze / gamma / power drive the BETA schedule
    # computed in train_net's batch callback.
    parser.add_argument('--beta', type=float, default=1000.,
                        help='')
    parser.add_argument('--beta-min', type=float, default=5.,
                        help='')
    parser.add_argument('--beta-freeze', type=int, default=0,
                        help='')
    parser.add_argument('--gamma', type=float, default=0.12,
                        help='')
    parser.add_argument('--power', type=float, default=1.0,
                        help='')
    parser.add_argument('--scale', type=float, default=0.9993,
                        help='')

    # --- centre / triplet / identity-grouped losses ------------------------
    parser.add_argument('--center-alpha', type=float, default=0.5, help='')
    parser.add_argument('--center-scale', type=float, default=0.003, help='')
    parser.add_argument('--images-per-identity', type=int, default=0, help='')
    parser.add_argument('--triplet-bag-size', type=int, default=3600, help='')
    parser.add_argument('--triplet-alpha', type=float, default=0.3, help='')
    parser.add_argument('--triplet-max-ap', type=float, default=0.0, help='')

    # --- misc --------------------------------------------------------------
    # verbose: run verification every this many mini-batches.
    parser.add_argument('--verbose', type=int, default=2000, help='')
    parser.add_argument('--loss-type', type=int, default=1,
                        help='')
    parser.add_argument('--incay', type=float, default=0.0,
                        help='feature incay')
    parser.add_argument('--use-deformable', type=int, default=0,
                        help='')
    parser.add_argument('--rand-mirror', type=int, default=1,
                        help='')
    # patch: five '_'-separated integers (validated in train_net()).
    parser.add_argument('--patch', type=str, default='0_0_96_112_0',
                        help='')
    # lr-steps: comma-separated step numbers; empty selects built-in defaults.
    parser.add_argument('--lr-steps', type=str, default='', help='')
    # max-steps: when > 0, training exits after this many mini-batches.
    parser.add_argument('--max-steps', type=int, default=0, help='')
    # target: comma-separated verification set names (<name>.bin in data-dir).
    parser.add_argument('--target', type=str, default='lfw,cfp_ff,cfp_fp,agedb_30', help='')

    return parser.parse_args(argv)
def get_symbol(args, arg_params, aux_params):
    """Build the training symbol: backbone embedding plus the selected loss head.

    args       -- namespace produced by parse_args() and extended by
                  train_net() (num_layers, num_classes, per_identities,
                  coco_scale, ...).
    arg_params -- passed through unchanged.
    aux_params -- passed through unchanged.

    Returns ``(out, arg_params, aux_params)`` where ``out`` is a grouped
    symbol whose outputs are, in order:
      0. the gradient-blocked embedding,
      1. the softmax output (loss types < 10) or the gradient-blocked label,
      2. the extra loss, if one was built,
      3. any margin-monitoring symbols (gradient-blocked).
    AccMetric and LossValueMetric index into the outputs using this order.
    """
    margin_symbols = []

    # ---- backbone, selected by the first character of --network ----------
    if args.network[0] == 'd':
        embedding = fdensenet.get_symbol(args.emb_size, args.num_layers,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'm':
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            embedding = fmobilenet.get_symbol(args.emb_size,
                version_se=args.version_se, version_input=args.version_input,
                version_output=args.version_output, version_unit=args.version_unit)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif args.network[0] == 'i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(args.emb_size,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size, args.num_layers,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    else:
        # Any other prefix falls back to the resnet builder.
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size, args.num_layers,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit)

    gt_label = mx.symbol.Variable('softmax_label')
    assert args.loss_type >= 0
    extra_loss = None

    # ---- loss head, selected by --loss-type -------------------------------
    if args.loss_type == 0:  # softmax
        _weight = mx.symbol.Variable('fc7_weight')
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias, num_hidden=args.num_classes, name='fc7')
    elif args.loss_type == 1:  # sphere
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes,
                              weight=_weight,
                              beta=args.beta, margin=args.margin, scale=args.scale,
                              beta_min=args.beta_min, verbose=1000, name='fc7')
    elif args.loss_type == 8:  # centerloss, TODO
        _weight = mx.symbol.Variable('fc7_weight')
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias, num_hidden=args.num_classes, name='fc7')
        print('center-loss', args.center_alpha, args.center_scale)
        extra_loss = mx.symbol.Custom(data=embedding, label=gt_label, name='center_loss', op_type='centerloss',
                                      num_class=args.num_classes, alpha=args.center_alpha, scale=args.center_scale, batchsize=args.per_batch_size)
    elif args.loss_type == 2:
        # Additive margin on the target-class logit.
        s = args.margin_s
        m = args.margin_m
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        if s > 0.0:
            # Normalised, s-scaled features: logits are s*cos(theta).
            nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
            fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
            if m > 0.0:
                if args.margin_verbose > 0:
                    zy = mx.sym.pick(fc7, gt_label, axis=1)
                    cos_t = zy/s
                    margin_symbols.append(mx.symbol.mean(cos_t))
                # Subtract s*m from the ground-truth logit only.
                s_m = s*m
                gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0)
                fc7 = fc7-gt_one_hot
                if args.margin_verbose > 0:
                    new_zy = mx.sym.pick(fc7, gt_label, axis=1)
                    new_cos_t = new_zy/s
                    margin_symbols.append(mx.symbol.mean(new_cos_t))
        else:
            # Un-normalised features: the margin is scaled by the feature norm.
            fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
            if m > 0.0:
                body = embedding*embedding
                body = mx.sym.sum_axis(body, axis=1, keepdims=True)
                body = mx.sym.sqrt(body)
                body = body*m
                gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
                body = mx.sym.broadcast_mul(gt_one_hot, body)
                fc7 = fc7-body
    elif args.loss_type == 3:
        # NOTE(review): in this revision no margin is applied to fc7 for
        # loss type 3 (the margin code was commented out upstream); only the
        # normalised, scaled logits and the optional cos(theta) monitor remain.
        s = args.margin_s
        assert args.margin == 2 or args.margin == 4
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy/s
        if args.margin_verbose > 0:
            margin_symbols.append(mx.symbol.mean(cos_t))
    elif args.loss_type == 4:
        # Additive angular margin: the target logit becomes s*cos(theta+m).
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m > 0.0
        assert m < (math.pi/2)
        cos_m = math.cos(m)
        sin_m = math.sin(m)
        mm = math.sin(math.pi-m)*m
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        threshold = math.cos(math.pi-m)
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy/s
        if args.margin_verbose > 0:
            margin_symbols.append(mx.symbol.mean(cos_t))
        # cond > 0 selects the samples that receive the angular margin.
        if args.easy_margin:
            cond = mx.symbol.Activation(data=cos_t, act_type='relu')
        else:
            cond_v = cos_t - threshold
            cond = mx.symbol.Activation(data=cond_v, act_type='relu')
        # cos(theta+m) = cos(theta)cos(m) - sin(theta)sin(m)
        body = cos_t*cos_t
        body = 1.0-body
        sin_t = mx.sym.sqrt(body)
        new_zy = cos_t*cos_m
        b = sin_t*sin_m
        new_zy = new_zy - b
        new_zy = new_zy*s
        # Fallback logit for samples where cond == 0.
        if args.easy_margin:
            zy_keep = zy
        else:
            zy_keep = zy - s*mm
        new_zy = mx.sym.where(cond, new_zy, zy_keep)
        if args.margin_verbose > 0:
            new_cos_t = new_zy/s
            margin_symbols.append(mx.symbol.mean(new_cos_t))
        # Write the adjusted target logit back into fc7 via a one-hot mask.
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7+body
    elif args.loss_type == 10:  # marginal loss
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
        params = [1.2, 0.3, 1.0]
        n1 = mx.sym.expand_dims(nembedding, axis=1)  # N,1,C
        n2 = mx.sym.expand_dims(nembedding, axis=0)  # 1,N,C
        body = mx.sym.broadcast_sub(n1, n2)  # N,N,C
        body = body * body
        body = mx.sym.sum(body, axis=2)  # N,N
        body = body - params[0]
        # 'extra' is the +1/-1 pair mask built in train_net().
        mask = mx.sym.Variable('extra')
        body = body*mask
        body = body+params[1]
        body = mx.symbol.Activation(data=body, act_type='relu')
        body = mx.sym.sum(body)
        body = body/(args.per_batch_size*args.per_batch_size-args.per_batch_size)
        extra_loss = mx.symbol.MakeLoss(body, grad_scale=params[2])
    elif args.loss_type == 11:  # npair loss
        params = [0.9, 0.2]
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
        nembedding = mx.sym.transpose(nembedding)
        nembedding = mx.symbol.reshape(nembedding, (args.emb_size, args.per_identities, args.images_per_identity))
        nembedding = mx.sym.transpose(nembedding, axes=(2, 1, 0))  # 2*id*512
        n1 = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=1)
        n2 = mx.symbol.slice_axis(nembedding, axis=0, begin=1, end=2)
        n1 = mx.symbol.reshape(n1, (args.per_identities, args.emb_size))
        n2 = mx.symbol.reshape(n2, (args.per_identities, args.emb_size))
        cosine_matrix = mx.symbol.dot(lhs=n1, rhs=n2, transpose_b=True)  # id*id, id=N of N-pair
        data_extra = mx.sym.Variable('extra')
        data_extra = mx.sym.slice_axis(data_extra, axis=0, begin=0, end=args.per_identities)
        mask = cosine_matrix * data_extra
        fii = mx.sym.sum_axis(mask, axis=1)
        fij_fii = mx.sym.broadcast_sub(cosine_matrix, fii)
        fij_fii = mx.sym.exp(fij_fii)
        row = mx.sym.sum_axis(fij_fii, axis=1)
        row = mx.sym.log(row)
        body = mx.sym.mean(row)
        extra_loss = mx.sym.MakeLoss(body)
    elif args.loss_type == 12:  # triplet loss
        # The batch is split into three equal thirds: anchor, positive, negative.
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
        anchor = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=args.per_batch_size//3)
        positive = mx.symbol.slice_axis(nembedding, axis=0, begin=args.per_batch_size//3, end=2*args.per_batch_size//3)
        negative = mx.symbol.slice_axis(nembedding, axis=0, begin=2*args.per_batch_size//3, end=args.per_batch_size)
        ap = anchor - positive
        an = anchor - negative
        ap = ap*ap
        an = an*an
        ap = mx.symbol.sum(ap, axis=1, keepdims=1)  # (T,1)
        an = mx.symbol.sum(an, axis=1, keepdims=1)  # (T,1)
        triplet_loss = mx.symbol.Activation(data=(ap-an+args.triplet_alpha), act_type='relu')
        triplet_loss = mx.symbol.mean(triplet_loss)
        extra_loss = mx.symbol.MakeLoss(triplet_loss)
    elif args.loss_type == 9:  # coco loss
        # One L2-normalised centroid per identity group in the batch.
        # range() instead of the Python-2-only xrange().
        centroids = []
        for i in range(args.per_identities):
            xs = mx.symbol.slice_axis(embedding, axis=0, begin=i*args.images_per_identity, end=(i+1)*args.images_per_identity)
            mean = mx.symbol.mean(xs, axis=0, keepdims=True)
            mean = mx.symbol.L2Normalization(mean, mode='instance')
            centroids.append(mean)
        centroids = mx.symbol.concat(*centroids, dim=0)
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*args.coco_scale
        fc7 = mx.symbol.dot(nembedding, centroids, transpose_b=True)  # (batchsize, per_identities)
    else:
        # Remaining loss types (e.g. 5-7) share this LSoftmax head.
        embedding = embedding * 5
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance') * 2
        fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes,
                              weight=_weight,
                              beta=args.beta, margin=args.margin, scale=args.scale,
                              beta_min=args.beta_min, verbose=100, name='fc7')

    # ---- optional feature-incay term (loss types 0 and 1 only) -----------
    if args.loss_type <= 1 and args.incay > 0.0:
        params = [1.e-10]
        sel = mx.symbol.argmax(data=fc7, axis=1)
        sel = (sel == gt_label)
        norm = embedding*embedding
        norm = mx.symbol.sum(norm, axis=1)
        norm = norm+params[0]  # avoid division by zero
        feature_incay = sel/norm
        feature_incay = mx.symbol.mean(feature_incay) * args.incay
        extra_loss = mx.symbol.MakeLoss(feature_incay)

    # ---- assemble outputs -------------------------------------------------
    out_list = [mx.symbol.BlockGrad(embedding)]
    softmax = None
    if args.loss_type < 10:
        softmax = mx.symbol.SoftmaxOutput(data=fc7, label=gt_label, name='softmax', normalization='valid')
        out_list.append(softmax)
    if softmax is None:
        # Keep the label in the graph when no softmax head consumes it.
        out_list.append(mx.sym.BlockGrad(gt_label))
    if extra_loss is not None:
        out_list.append(extra_loss)
    for _sym in margin_symbols:
        _sym = mx.sym.BlockGrad(_sym)
        out_list.append(_sym)
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
def train_net(args):
    """Set up contexts, data, model and callbacks from ``args`` and run training.

    args -- namespace from parse_args(). This function also stores derived
            values on it (ctx_num, num_layers, batch_size, image_h/w,
            num_classes, coco_scale, use_val, per_identities, ...), which
            get_symbol() and AccMetric read.

    One GPU context is created per entry of the CUDA_VISIBLE_DEVICES
    environment variable; when it is unset or empty, training runs on CPU.
    Calls ``sys.exit(0)`` from the batch callback once ``--max-steps`` is
    exceeded. Returns None.
    """
    # .get() so an unset variable reaches the CPU fallback below instead of
    # raising KeyError.
    cvd = os.environ.get('CUDA_VISIBLE_DEVICES', '').strip()
    ctx = []
    if len(cvd) > 0:
        ctx = [mx.gpu(i) for i in range(len(cvd.split(',')))]
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))

    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    # dirname() is '' for a prefix without a directory part; makedirs('')
    # would raise, so only create a directory when there is one to create.
    if prefix_dir and not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size == 0:
        args.per_batch_size = 128
        if args.loss_type == 10:
            args.per_batch_size = 256
    args.batch_size = args.per_batch_size*args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3
    # --patch is only validated here (five '_'-separated integers).
    ppatch = [int(x) for x in args.patch.split('_')]
    assert len(ppatch) == 5

    # Presumably read by the custom LSoftmax operator - TODO confirm.
    os.environ['BETA'] = str(args.beta)
    data_dir_list = args.data_dir.split(',')
    if args.loss_type != 12:
        assert len(data_dir_list) == 1
    data_dir = data_dir_list[0]
    args.use_val = False
    path_imgrec = None
    val_rec = None
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    image_size = prop.image_size
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)
    assert(args.num_classes > 0)
    print('num_classes', args.num_classes)
    args.coco_scale = 0.5*math.log(float(args.num_classes-1))+3

    path_imgrec = os.path.join(data_dir, "train.rec")
    val_rec = os.path.join(data_dir, "val.rec")
    # Validation is only used when val.rec exists and a softmax head is present.
    if os.path.exists(val_rec) and args.loss_type < 10:
        args.use_val = True
    else:
        val_rec = None

    if args.loss_type == 1 and args.num_classes > 20000:
        args.beta_freeze = 5000
        args.gamma = 0.06

    # Identity-grouped batches are only meaningful for loss types >= 9.
    if args.loss_type < 9:
        assert args.images_per_identity == 0
    else:
        if args.images_per_identity == 0:
            if args.loss_type == 11:
                args.images_per_identity = 2
            elif args.loss_type == 10 or args.loss_type == 9:
                args.images_per_identity = 16
            elif args.loss_type == 12:
                args.images_per_identity = 5
                assert args.per_batch_size % 3 == 0
        assert args.images_per_identity >= 2
        args.per_identities = int(args.per_batch_size/args.images_per_identity)

    print('Called with argument:', args)

    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    if len(args.pretrained) == 0:
        arg_params = None
        aux_params = None
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
    else:
        # --pretrained is 'prefix,epoch'.
        vec = args.pretrained.split(',')
        print('loading', vec)
        _, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)

    # ---- per-loss extra inputs -------------------------------------------
    data_extra = None
    hard_mining = False
    triplet_params = None
    coco_mode = False
    if args.loss_type == 10:
        hard_mining = True
        # Pair mask per device slice: +1 inside an identity's block, -1 elsewhere.
        _shape = (args.batch_size, args.per_batch_size)
        data_extra = np.full(_shape, -1.0, dtype=np.float32)
        c = 0
        while c < args.batch_size:
            a = 0
            while a < args.per_batch_size:
                b = a+args.images_per_identity
                data_extra[(c+a):(c+b), a:b] = 1.0
                a = b
            c += args.per_batch_size
    elif args.loss_type == 11:
        # Identity matrix in the first per_identities rows of each device slice.
        data_extra = np.zeros((args.batch_size, args.per_identities), dtype=np.float32)
        c = 0
        while c < args.batch_size:
            for i in range(args.per_identities):
                data_extra[c+i][i] = 1.0
            c += args.per_batch_size
    elif args.loss_type == 12:
        triplet_params = [args.triplet_bag_size, args.triplet_alpha, args.triplet_max_ap]
    elif args.loss_type == 9:
        coco_mode = True

    label_name = 'softmax_label'
    if data_extra is None:
        model = mx.mod.Module(
            context=ctx,
            symbol=sym,
        )
    else:
        # The symbol has an additional 'extra' input that must be declared.
        data_names = ('data', 'extra')
        model = mx.mod.Module(
            context=ctx,
            symbol=sym,
            data_names=data_names,
            label_names=(label_name,),
        )

    # ---- data iterators ---------------------------------------------------
    if args.use_val:
        val_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=val_rec,
            shuffle=False,
            rand_mirror=False,
            mean=mean,
            ctx_num=args.ctx_num,
            data_extra=data_extra,
        )
    else:
        val_dataiter = None

    if len(data_dir_list) == 1 and args.loss_type != 12:
        train_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=path_imgrec,
            shuffle=True,
            rand_mirror=args.rand_mirror,
            mean=mean,
            ctx_num=args.ctx_num,
            images_per_identity=args.images_per_identity,
            data_extra=data_extra,
            hard_mining=hard_mining,
            triplet_params=triplet_params,
            coco_mode=coco_mode,
            mx_model=model,
            label_name=label_name,
        )
    else:
        iter_list = []
        for _data_dir in data_dir_list:
            _path_imgrec = os.path.join(_data_dir, "train.rec")
            _dataiter = FaceImageIter(
                batch_size=args.batch_size,
                data_shape=data_shape,
                path_imgrec=_path_imgrec,
                shuffle=True,
                rand_mirror=args.rand_mirror,
                mean=mean,
                ctx_num=args.ctx_num,
                images_per_identity=args.images_per_identity,
                data_extra=data_extra,
                hard_mining=hard_mining,
                triplet_params=triplet_params,
                coco_mode=coco_mode,
                mx_model=model,
                label_name=label_name,
            )
            iter_list.append(_dataiter)
        # NOTE(review): the last iterator is appended a second time, exactly
        # as in the original statement sequence. Whether this duplication is
        # intentional (e.g. to weight the last dataset) cannot be told from
        # this file - confirm against FaceImageIterList before removing.
        iter_list.append(_dataiter)
        train_dataiter = FaceImageIterList(iter_list)

    # ---- metric, initializer, optimizer -----------------------------------
    if args.loss_type < 10:
        _metric = AccMetric()
    else:
        _metric = LossValueMetric()
    eval_metrics = [mx.metric.create(_metric)]

    if args.network[0] == 'r':
        initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2)  # resnet style
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2)  # inception
    else:
        initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)
    _rescale = 1.0/args.ctx_num
    opt = optimizer.SGD(learning_rate=base_lr, momentum=base_mom, wd=base_wd, rescale_grad=_rescale)
    som = 20
    if args.loss_type == 12:
        som = 2
    _cb = mx.callback.Speedometer(args.batch_size, som)

    # ---- verification sets (only those whose .bin file exists) ------------
    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
        path = os.path.join(data_dir, name+".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        """Run every loaded verification set; return the flip accuracies."""
        results = []
        for ver_name, ver_set in zip(ver_name_list, ver_list):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(ver_set, model, args.batch_size, data_extra)
            print('[%s][%d]XNorm: %f' % (ver_name, nbatch, xnorm))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name, nbatch, acc2, std2))
            results.append(acc2)
        return results

    def val_test():
        """Print classification accuracy over the validation iterator."""
        acc = AccMetric()
        val_metric = mx.metric.create(acc)
        val_metric.reset()
        val_dataiter.reset()
        for i, eval_batch in enumerate(val_dataiter):
            model.forward(eval_batch, is_train=False)
            model.update_metric(val_metric, eval_batch.label)
        acc_value = val_metric.get_name_value()[0][1]
        print('VACC: %f' % (acc_value))

    highest_acc = [0.0, 0.0]  # lfw and target
    # One-element lists so the nested callback can mutate the counters.
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps) == 0:
        lr_steps = [40000, 60000, 80000]
        if args.loss_type >= 1 and args.loss_type <= 5:
            lr_steps = [100000, 140000, 160000]
        # Default steps are defined for a total batch size of 512; rescale.
        p = 512.0/args.batch_size
        lr_steps = [int(step*p) for step in lr_steps]
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        """Per-batch hook: LR decay, periodic verification/checkpoint, BETA update."""
        global_step[0] += 1
        mbatch = global_step[0]
        # Step decay: divide the learning rate by 10 at each configured step.
        for _lr in lr_steps:
            if mbatch == args.beta_freeze+_lr:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            if len(acc_list) > 0:
                # The first verification target is treated as LFW.
                lfw_score = acc_list[0]
                if lfw_score > highest_acc[0]:
                    highest_acc[0] = lfw_score
                    if lfw_score >= 0.998:
                        do_save = True
                if acc_list[-1] >= highest_acc[-1]:
                    highest_acc[-1] = acc_list[-1]
                    if lfw_score >= 0.99:
                        do_save = True
            # --ckpt overrides: 0 never saves, >1 always saves.
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt > 1:
                do_save = True
            if do_save:
                print('saving', msave)
                if val_dataiter is not None:
                    val_test()
                arg, aux = model.get_params()
                mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))

        # BETA stays at args.beta until beta_freeze, then decays towards beta_min.
        if mbatch <= args.beta_freeze:
            _beta = args.beta
        else:
            move = max(0, mbatch-args.beta_freeze)
            _beta = max(args.beta_min, args.beta*math.pow(1+args.gamma*move, -1.0*args.power))
        os.environ['BETA'] = str(_beta)
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    epoch_cb = None

    model.fit(train_dataiter,
              begin_epoch=begin_epoch,
              num_epoch=end_epoch,
              eval_data=val_dataiter,
              eval_metric=eval_metrics,
              kvstore='device',
              optimizer=opt,
              initializer=initializer,
              arg_params=arg_params,
              aux_params=aux_params,
              allow_missing=True,
              batch_end_callback=_batch_callback,
              epoch_end_callback=epoch_cb)
def main():
    """Parse the command line, publish it as the global ``args``, then train."""
    global args
    parsed = parse_args()
    # AccMetric.update() reads the module-level ``args``, so it must be set
    # before training starts.
    args = parsed
    train_net(parsed)


if __name__ == '__main__':
    main()