from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import math
import random
import logging
import pickle
import numpy as np
from data import FaceImageIter
from data import FaceImageIterList
import mxnet as mx
from mxnet import ndarray as nd
import argparse
import mxnet.optimizer as optimizer
sys.path.append(os.path.join(os.path.dirname(__file__), 'common'))
import face_image
sys.path.append(os.path.join(os.path.dirname(__file__), 'eval'))
sys.path.append(os.path.join(os.path.dirname(__file__), 'symbols'))
import fresnet
import finception_resnet_v2
import fmobilenet
import fxception
import fdensenet
import fdpn
import fnasnet
#import lfw
import verification
import sklearn
sys.path.append(os.path.join(os.path.dirname(__file__), 'losses'))
import center_loss

logger = logging.getLogger()
logger.setLevel(logging.INFO)
class AccMetric(mx.metric.EvalMetric):
  def __init__(self):
    self.axis = 1
    super(AccMetric, self).__init__(
        'acc', axis=self.axis,
        output_names=None, label_names=None)
    self.losses = []

  def update(self, labels, preds):
    #loss = preds[2].asnumpy()[0]
    #if len(self.losses)==20:
    #  print('ce loss', sum(self.losses)/len(self.losses))
    #  self.losses = []
    #self.losses.append(loss)
    preds = [preds[1]] #use softmax output
    for label, pred_label in zip(labels, preds):
      #print(pred_label)
      #print(label.shape, pred_label.shape)
      if pred_label.shape != label.shape:
        pred_label = mx.ndarray.argmax(pred_label, axis=self.axis)
      pred_label = pred_label.asnumpy().astype('int32').flatten()
      label = label.asnumpy().astype('int32').flatten()
      #print('label', label)
      #print('pred_label', pred_label)
      assert label.shape==pred_label.shape
      self.sum_metric += (pred_label.flat == label.flat).sum()
      self.num_inst += len(pred_label.flat)

class LossValueMetric(mx.metric.EvalMetric):
  def __init__(self):
    self.axis = 1
    super(LossValueMetric, self).__init__(
        'lossvalue', axis=self.axis,
        output_names=None, label_names=None)
    self.losses = []

  def update(self, labels, preds):
    loss = preds[-1].asnumpy()[0]
    self.sum_metric += loss
    self.num_inst += 1.0
    gt_label = preds[-2].asnumpy()
    #print(gt_label)

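# Loss types selected by --loss-type below (as implemented in get_symbol):
#   0  : plain softmax
#   1  : sphere loss (LSoftmax / A-Softmax, multiplicative angular margin)
#   2  : additive cosine margin on normalized features (CosineFace / AM-Softmax style)
#   3  : LSoftmax on L2-normalized features scaled by a fixed constant
#   4  : additive angular margin (ArcFace style), s*cos(theta + m)
#   8  : softmax plus center loss (custom 'centerloss' op)
#   9  : COCO-style loss against per-identity batch centroids
#   10 : marginal loss over pairwise distances (uses the 'extra' mask input)
#   11 : N-pair loss over per-identity pairs (uses the 'extra' mask input)
#   12 : triplet loss over (anchor, positive, negative) thirds of the batch
# Any other value falls through to an experimental LSoftmax variant.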
def parse_args():
  parser = argparse.ArgumentParser(description='Train face network')
  # general
  parser.add_argument('--data-dir', default='', help='training dataset directory, containing train.rec and the property file.')
  parser.add_argument('--prefix', default='../model/model', help='prefix of the saved model files.')
  parser.add_argument('--pretrained', default='', help='pretrained model to load, given as prefix,epoch.')
  parser.add_argument('--retrain', action='store_true', default=False, help='true means continue training.')
  parser.add_argument('--ckpt', type=int, default=1, help='checkpoint policy: 0 never save, 1 save on accuracy improvement, >1 save at every evaluation.')
  parser.add_argument('--network', default='s20', help='network backbone, e.g. r50; the first letter selects the architecture, the rest the depth.')
  parser.add_argument('--version-se', type=int, default=0, help='whether to use SE blocks.')
  parser.add_argument('--version-input', type=int, default=1, help='input block version.')
  parser.add_argument('--version-output', type=str, default='E', help='output (embedding) block version.')
  parser.add_argument('--version-unit', type=int, default=3, help='residual unit version.')
  parser.add_argument('--end-epoch', type=int, default=100000, help='number of training epochs.')
  parser.add_argument('--lr', type=float, default=0.1, help='start learning rate.')
  parser.add_argument('--wd', type=float, default=0.0005, help='weight decay.')
  parser.add_argument('--mom', type=float, default=0.9, help='momentum.')
  parser.add_argument('--embedding-dim', type=int, default=512, help='embedding dimension.')
  parser.add_argument('--per-batch-size', type=int, default=0, help='batch size per device/context.')
  parser.add_argument('--margin-m', type=float, default=0.35, help='margin m for the margin-based losses (loss-type 2 and 4).')
  parser.add_argument('--margin-s', type=float, default=64.0, help='feature scale s for the margin-based losses.')
  parser.add_argument('--margin', type=int, default=4, help='multiplicative angular margin for LSoftmax (loss-type 1).')
  parser.add_argument('--beta', type=float, default=1000., help='initial beta (lambda) for LSoftmax annealing.')
  parser.add_argument('--beta-min', type=float, default=5., help='lower bound of beta.')
  parser.add_argument('--beta-freeze', type=int, default=0, help='number of batches before beta starts to decay.')
  parser.add_argument('--gamma', type=float, default=0.12, help='beta decay rate.')
  parser.add_argument('--power', type=float, default=1.0, help='beta decay power.')
  parser.add_argument('--scale', type=float, default=0.9993, help='')
  parser.add_argument('--center-alpha', type=float, default=0.5, help='center loss update rate.')
  parser.add_argument('--center-scale', type=float, default=0.003, help='center loss weight.')
  parser.add_argument('--images-per-identity', type=int, default=0, help='images per identity in each batch (metric-learning losses).')
  parser.add_argument('--triplet-bag-size', type=int, default=3600, help='bag size for triplet sampling.')
  parser.add_argument('--triplet-alpha', type=float, default=0.3, help='triplet loss margin.')
  parser.add_argument('--triplet-max-ap', type=float, default=0.0, help='maximum anchor-positive distance when selecting triplets.')
  parser.add_argument('--verbose', type=int, default=2000, help='run verification and checkpointing every verbose batches.')
  parser.add_argument('--loss-type', type=int, default=1, help='loss to use, see the mapping above.')
  parser.add_argument('--incay', type=float, default=0.0, help='feature incay weight.')
  parser.add_argument('--use-deformable', type=int, default=0, help='')
  parser.add_argument('--patch', type=str, default='0_0_96_112_0', help='')
  parser.add_argument('--lr-steps', type=str, default='', help='comma-separated batch numbers at which the learning rate is divided by 10.')
  parser.add_argument('--target', type=str, default='lfw,cfp_ff,cfp_fp,agedb_30', help='comma-separated verification sets to evaluate during training.')
  args = parser.parse_args()
  return args

def get_symbol(args, arg_params, aux_params):
  # Build the embedding network selected by the first letter of args.network,
  # then attach the classification / metric-learning head chosen by args.loss_type.
  data_shape = (args.image_channel,args.image_h,args.image_w)
  image_shape = ",".join([str(x) for x in data_shape])
  if args.network[0]=='d':
    embedding = fdensenet.get_symbol(512, args.num_layers,
        version_se=args.version_se, version_input=args.version_input,
        version_output=args.version_output, version_unit=args.version_unit)
  elif args.network[0]=='m':
    print('init mobilenet', args.num_layers)
    embedding = fmobilenet.get_symbol(512,
        version_se=args.version_se, version_input=args.version_input,
        version_output=args.version_output, version_unit=args.version_unit)
  elif args.network[0]=='i':
    print('init inception-resnet-v2', args.num_layers)
    embedding = finception_resnet_v2.get_symbol(512,
        version_se=args.version_se, version_input=args.version_input,
        version_output=args.version_output, version_unit=args.version_unit)
  elif args.network[0]=='x':
    print('init xception', args.num_layers)
    embedding = fxception.get_symbol(512,
        version_se=args.version_se, version_input=args.version_input,
        version_output=args.version_output, version_unit=args.version_unit)
  elif args.network[0]=='p':
    print('init dpn', args.num_layers)
    embedding = fdpn.get_symbol(512, args.num_layers,
        version_se=args.version_se, version_input=args.version_input,
        version_output=args.version_output, version_unit=args.version_unit)
  elif args.network[0]=='n':
    print('init nasnet', args.num_layers)
    embedding = fnasnet.get_symbol(512)
  else:
    print('init resnet', args.num_layers)
    embedding = fresnet.get_symbol(512, args.num_layers,
        version_se=args.version_se, version_input=args.version_input,
        version_output=args.version_output, version_unit=args.version_unit)
  gt_label = mx.symbol.Variable('softmax_label')
  assert args.loss_type>=0
  extra_loss = None
  if args.loss_type==0: #softmax
    _weight = mx.symbol.Variable('fc7_weight')
    _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
    fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, bias = _bias, num_hidden=args.num_classes, name='fc7')
  elif args.loss_type==1: #sphere
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, 512), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance')
    fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes,
                          weight = _weight,
                          beta=args.beta, margin=args.margin, scale=args.scale,
                          beta_min=args.beta_min, verbose=1000, name='fc7')
  elif args.loss_type==8: #centerloss, TODO
    _weight = mx.symbol.Variable('fc7_weight')
    _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
    fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, bias = _bias, num_hidden=args.num_classes, name='fc7')
    print('center-loss', args.center_alpha, args.center_scale)
    extra_loss = mx.symbol.Custom(data=embedding, label=gt_label, name='center_loss', op_type='centerloss',
        num_class=args.num_classes, alpha=args.center_alpha, scale=args.center_scale, batchsize=args.per_batch_size)
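  # loss_type==2: additive cosine margin on the normalized embedding
  # (CosineFace / AM-Softmax style). With s>0 the logits are s*cos(theta)
  # and s*m is subtracted from the ground-truth class logit via a one-hot
  # mask; with s==0 the margin is scaled by the unnormalized feature norm.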
  elif args.loss_type==2:
    s = args.margin_s
    m = args.margin_m
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, 512), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance')
    if s>0.0:
      nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
      fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
      if m>0.0:
        s_m = s*m
        gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = s_m, off_value = 0.0)
        fc7 = fc7-gt_one_hot
    else:
      fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
      if m>0.0:
        body = embedding*embedding
        body = mx.sym.sum_axis(body, axis=1, keepdims=True)
        body = mx.sym.sqrt(body)
        body = body*m
        gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, body)
        fc7 = fc7-body
  elif args.loss_type==3:
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, 512), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance')
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*22.0
    fc7 = mx.sym.LSoftmax(data=nembedding, label=gt_label, num_hidden=args.num_classes,
                          weight = _weight,
                          beta=args.beta, margin=args.margin, scale=args.scale,
                          beta_min=args.beta_min, verbose=1000, name='fc7')
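  # loss_type==4: additive angular margin (ArcFace style). The ground-truth
  # logit s*cos(theta) is replaced by s*cos(theta + m), expanded as
  # cos(theta+m) = cos(theta)*cos(m) - sin(theta)*sin(m). When theta + m
  # would exceed pi (cond <= 0), a monotonic fallback
  # s*(cos(theta) - m*sin(theta)) is used instead.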
  elif args.loss_type==4:
    s = args.margin_s
    m = args.margin_m
    cos_m = math.cos(m)
    sin_m = math.sin(m)
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, 512), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance')
    assert s>0.0
    assert m>0.0
    assert m<(math.pi/2)
    threshold = 0.0
    threshold = math.cos(math.pi-m)
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
    fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
    zy = mx.sym.pick(fc7, gt_label, axis=1)
    cos_t = zy/s
    if threshold==0.0:
      cond = mx.symbol.Activation(data=cos_t, act_type='relu')
    else:
      cond_v = cos_t - threshold
      cond = mx.symbol.Activation(data=cond_v, act_type='relu')
    #theta = mx.sym.arccos(cos_t)
    #sin_t = mx.sym.sin(theta)
    body = cos_t*cos_t
    body = 1.0-body
    sin_t = mx.sym.sqrt(body)
    new_zy = cos_t*cos_m
    b = sin_t*sin_m
    new_zy = new_zy - b
    new_zy = new_zy*s
    zy_keep = zy
    _zy = sin_t*(-1.0*s*m)
    zy_keep += _zy
    new_zy = mx.sym.where(cond, new_zy, zy_keep)
    diff = new_zy - zy
    diff = mx.sym.expand_dims(diff, 1)
    gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
    body = mx.sym.broadcast_mul(gt_one_hot, diff)
    fc7 = fc7+body
  #elif args.loss_type==4:
  #  s = args.margin_s
  #  m = args.margin_m
  #  assert s>0.0
  #  #assert m>0.0
  #  _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, 512), lr_mult=1.0)
  #  _weight = mx.symbol.L2Normalization(_weight, mode='instance')
  #  nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
  #  fc7 = mx.sym.AmSoftmax(data=nembedding, label=gt_label, num_hidden=args.num_classes,
  #                         weight = _weight, verbose=1000,
  #                         margin = m, s = s, name='fc7')
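  # loss_type==10: marginal loss. Builds the full NxN matrix of squared
  # distances between normalized embeddings, applies the +1/-1 same-identity
  # mask supplied through the extra 'extra' input, and hinges the masked
  # distances around the thresholds in params.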
  elif args.loss_type==10: #marginal loss
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
    params = [1.2, 0.3, 1.0]
    n1 = mx.sym.expand_dims(nembedding, axis=1) #N,1,C
    n2 = mx.sym.expand_dims(nembedding, axis=0) #1,N,C
    body = mx.sym.broadcast_sub(n1, n2) #N,N,C
    body = body * body
    body = mx.sym.sum(body, axis=2) # N,N
    #body = mx.sym.sqrt(body)
    body = body - params[0]
    mask = mx.sym.Variable('extra')
    body = body*mask
    body = body+params[1]
    #body = mx.sym.maximum(body, 0.0)
    body = mx.symbol.Activation(data=body, act_type='relu')
    body = mx.sym.sum(body)
    body = body/(args.per_batch_size*args.per_batch_size-args.per_batch_size)
    extra_loss = mx.symbol.MakeLoss(body, grad_scale=params[2])
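  # loss_type==11: N-pair loss. The batch holds per_identities identities with
  # images_per_identity images each; the embeddings are regrouped so that
  # n1/n2 each contain one image per identity, and the loss is the
  # log-sum-exp over the off-diagonal entries of their cosine similarity
  # matrix relative to the matching entries selected by the 'extra' mask.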
  elif args.loss_type==11: #npair loss
    params = [0.9, 0.2]
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
    nembedding = mx.sym.transpose(nembedding)
    nembedding = mx.symbol.reshape(nembedding, (512, args.per_identities, args.images_per_identity))
    nembedding = mx.sym.transpose(nembedding, axes=(2,1,0)) #2*id*512
    #nembedding = mx.symbol.reshape(nembedding, (512, args.images_per_identity, args.per_identities))
    #nembedding = mx.sym.transpose(nembedding, axes=(1,2,0)) #2*id*512
    n1 = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=1)
    n2 = mx.symbol.slice_axis(nembedding, axis=0, begin=1, end=2)
    #n1 = []
    #n2 = []
    #for i in xrange(args.per_identities):
    #  _n1 = mx.symbol.slice_axis(nembedding, axis=0, begin=2*i, end=2*i+1)
    #  _n2 = mx.symbol.slice_axis(nembedding, axis=0, begin=2*i+1, end=2*i+2)
    #  n1.append(_n1)
    #  n2.append(_n2)
    #n1 = mx.sym.concat(*n1, dim=0)
    #n2 = mx.sym.concat(*n2, dim=0)
    #rembeddings = mx.symbol.reshape(nembedding, (args.images_per_identity, args.per_identities, 512))
    #n1 = mx.symbol.slice_axis(rembeddings, axis=0, begin=0, end=1)
    #n2 = mx.symbol.slice_axis(rembeddings, axis=0, begin=1, end=2)
    n1 = mx.symbol.reshape(n1, (args.per_identities, 512))
    n2 = mx.symbol.reshape(n2, (args.per_identities, 512))
    cosine_matrix = mx.symbol.dot(lhs=n1, rhs=n2, transpose_b = True) #id*id, id=N of N-pair
    data_extra = mx.sym.Variable('extra')
    data_extra = mx.sym.slice_axis(data_extra, axis=0, begin=0, end=args.per_identities)
    mask = cosine_matrix * data_extra
    #body = mx.sym.mean(mask)
    fii = mx.sym.sum_axis(mask, axis=1)
    fij_fii = mx.sym.broadcast_sub(cosine_matrix, fii)
    fij_fii = mx.sym.exp(fij_fii)
    row = mx.sym.sum_axis(fij_fii, axis=1)
    row = mx.sym.log(row)
    body = mx.sym.mean(row)
    extra_loss = mx.sym.MakeLoss(body)
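  # loss_type==12: triplet loss. The iterator packs each batch as
  # [anchors | positives | negatives] in equal thirds; the hinge
  # relu(||a-p||^2 - ||a-n||^2 + triplet_alpha) is averaged over the batch.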
  elif args.loss_type==12: #triplet loss
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
    anchor = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=args.per_batch_size//3)
    positive = mx.symbol.slice_axis(nembedding, axis=0, begin=args.per_batch_size//3, end=2*args.per_batch_size//3)
    negative = mx.symbol.slice_axis(nembedding, axis=0, begin=2*args.per_batch_size//3, end=args.per_batch_size)
    ap = anchor - positive
    an = anchor - negative
    ap = ap*ap
    an = an*an
    ap = mx.symbol.sum(ap, axis=1, keepdims=1) #(T,1)
    an = mx.symbol.sum(an, axis=1, keepdims=1) #(T,1)
    triplet_loss = mx.symbol.Activation(data = (ap-an+args.triplet_alpha), act_type='relu')
    triplet_loss = mx.symbol.mean(triplet_loss)
    #triplet_loss = mx.symbol.sum(triplet_loss)/(args.per_batch_size//3)
    extra_loss = mx.symbol.MakeLoss(triplet_loss)
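  # loss_type==9: COCO-style loss. Class "centroids" are the normalized
  # per-identity means of the current batch embeddings; the logits are the
  # scaled cosine similarities between each embedding and every centroid.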
  elif args.loss_type==9: #coco loss
    centroids = []
    for i in xrange(args.per_identities):
      xs = mx.symbol.slice_axis(embedding, axis=0, begin=i*args.images_per_identity, end=(i+1)*args.images_per_identity)
      mean = mx.symbol.mean(xs, axis=0, keepdims=True)
      mean = mx.symbol.L2Normalization(mean, mode='instance')
      centroids.append(mean)
    centroids = mx.symbol.concat(*centroids, dim=0)
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*args.coco_scale
    fc7 = mx.symbol.dot(nembedding, centroids, transpose_b = True) #(batchsize, per_identities)
    #extra_loss = mx.symbol.softmax_cross_entropy(fc7, gt_label, name='softmax_ce')/args.per_batch_size
    #extra_loss = mx.symbol.BlockGrad(extra_loss)
  else:
    #embedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*float(args.loss_type)
    embedding = embedding * 5
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, 512), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance') * 2
    fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes,
                          weight = _weight,
                          beta=args.beta, margin=args.margin, scale=args.scale,
                          beta_min=args.beta_min, verbose=100, name='fc7')
    #fc7 = mx.sym.Custom(data=embedding, label=gt_label, weight=_weight, num_hidden=args.num_classes,
    #                    beta=args.beta, margin=args.margin, scale=args.scale,
    #                    op_type='ASoftmax', name='fc7')
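  # Optional "feature incay" regularizer for loss_type 0/1: adds a loss term
  # proportional to 1/||embedding||^2 for correctly classified samples,
  # which pushes correct features towards larger norms.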
  if args.loss_type<=1 and args.incay>0.0:
    params = [1.e-10]
    sel = mx.symbol.argmax(data = fc7, axis=1)
    sel = (sel==gt_label)
    norm = embedding*embedding
    norm = mx.symbol.sum(norm, axis=1)
    norm = norm+params[0]
    feature_incay = sel/norm
    feature_incay = mx.symbol.mean(feature_incay) * args.incay
    extra_loss = mx.symbol.MakeLoss(feature_incay)
  #out = softmax
  #l2_embedding = mx.symbol.L2Normalization(embedding)
  #ce = mx.symbol.softmax_cross_entropy(fc7, gt_label, name='softmax_ce')/args.per_batch_size
  #out = mx.symbol.Group([mx.symbol.BlockGrad(embedding), softmax, mx.symbol.BlockGrad(ce)])
  out_list = [mx.symbol.BlockGrad(embedding)]
  softmax = None
  if args.loss_type<10:
    softmax = mx.symbol.SoftmaxOutput(data=fc7, label = gt_label, name='softmax', normalization='valid')
    out_list.append(softmax)
  if softmax is None:
    out_list.append(mx.sym.BlockGrad(gt_label))
  if extra_loss is not None:
    out_list.append(extra_loss)
  out = mx.symbol.Group(out_list)
  return (out, arg_params, aux_params)

def train_net(args):
  ctx = []
  cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
  if len(cvd)>0:
    for i in xrange(len(cvd.split(','))):
      ctx.append(mx.gpu(i))
  if len(ctx)==0:
    ctx = [mx.cpu()]
    print('use cpu')
  else:
    print('gpu num:', len(ctx))
  prefix = args.prefix
  prefix_dir = os.path.dirname(prefix)
  if not os.path.exists(prefix_dir):
    os.makedirs(prefix_dir)
  end_epoch = args.end_epoch
  args.ctx_num = len(ctx)
  args.num_layers = int(args.network[1:])
  print('num_layers', args.num_layers)
  if args.per_batch_size==0:
    args.per_batch_size = 128
    if args.loss_type==10:
      args.per_batch_size = 256
  args.batch_size = args.per_batch_size*args.ctx_num
  args.rescale_threshold = 0
  args.image_channel = 3
  ppatch = [int(x) for x in args.patch.split('_')]
  assert len(ppatch)==5
  os.environ['BETA'] = str(args.beta)
  data_dir_list = args.data_dir.split(',')
  if args.loss_type!=12:
    assert len(data_dir_list)==1
  data_dir = data_dir_list[0]
  args.use_val = False
  path_imgrec = None
  path_imglist = None
  val_rec = None
  prop = face_image.load_property(data_dir)
  args.num_classes = prop.num_classes
  image_size = prop.image_size
  args.image_h = image_size[0]
  args.image_w = image_size[1]
  print('image_size', image_size)

  assert(args.num_classes>0)
  print('num_classes', args.num_classes)
  args.coco_scale = 0.5*math.log(float(args.num_classes-1))+3

  #path_imglist = "/raid5data/dplearn/MS-Celeb-Aligned/lst2"
  path_imgrec = os.path.join(data_dir, "train.rec")
  val_rec = os.path.join(data_dir, "val.rec")
  if os.path.exists(val_rec) and args.loss_type<10:
    args.use_val = True
  else:
    val_rec = None
    args.use_val = False

  #args.num_classes = 10572 #webface
  #args.num_classes = 81017
  #args.num_classes = 82395

  if (args.loss_type>=1 and args.loss_type<=5) and args.num_classes>40000:
    args.beta_freeze = 5000
    args.gamma = 0.06

  if args.loss_type<9:
    assert args.images_per_identity==0
  else:
    if args.images_per_identity==0:
      if args.loss_type==11:
        args.images_per_identity = 2
      elif args.loss_type==10 or args.loss_type==9:
        args.images_per_identity = 16
      elif args.loss_type==12:
        args.images_per_identity = 5
        assert args.per_batch_size%3==0
    assert args.images_per_identity>=2
    args.per_identities = int(args.per_batch_size/args.images_per_identity)

  print('Called with argument:', args)
  data_shape = (args.image_channel,image_size[0],image_size[1])
  mean = None

  if args.use_val:
    val_dataiter = FaceImageIter(
        batch_size = args.batch_size,
        data_shape = data_shape,
        path_imgrec = val_rec,
        #path_imglist = val_path,
        shuffle = False,
        rand_mirror = False,
        mean = mean,
    )
  else:
    val_dataiter = None

  begin_epoch = 0
  base_lr = args.lr
  base_wd = args.wd
  base_mom = args.mom
  if len(args.pretrained)==0:
    arg_params = None
    aux_params = None
    sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
  else:
    vec = args.pretrained.split(',')
    print('loading', vec)
    _, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))
    sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)

  data_extra = None
  hard_mining = False
  triplet_params = None
  coco_mode = False
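  # The 'extra' input feeds a per-batch mask into the marginal / N-pair
  # losses above: for loss_type 10 it is a (batch_size, per_batch_size)
  # block matrix with +1 for same-identity pairs and -1 otherwise; for
  # loss_type 11 it provides the identity matrix used to pick the matching
  # (positive) entries of the similarity matrix.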
  if args.loss_type==10:
    hard_mining = True
    _shape = (args.batch_size, args.per_batch_size)
    data_extra = np.full(_shape, -1.0, dtype=np.float32)
    c = 0
    while c<args.batch_size:
      a = 0
      while a<args.per_batch_size:
        b = a+args.images_per_identity
        data_extra[(c+a):(c+b),a:b] = 1.0
        #print(c+a, c+b, a, b)
        a = b
      c += args.per_batch_size
  elif args.loss_type==11:
    data_extra = np.zeros( (args.batch_size, args.per_identities), dtype=np.float32)
    c = 0
    while c<args.batch_size:
      for i in xrange(args.per_identities):
        data_extra[c+i][i] = 1.0
      c+=args.per_batch_size
  elif args.loss_type==12:
    triplet_params = [args.triplet_bag_size, args.triplet_alpha, args.triplet_max_ap]
  elif args.loss_type==9:
    coco_mode = True

  label_name = 'softmax_label'
  if data_extra is None:
    model = mx.mod.Module(
        context = ctx,
        symbol = sym,
    )
  else:
    data_names = ('data', 'extra')
    #label_name = ''
    model = mx.mod.Module(
        context = ctx,
        symbol = sym,
        data_names = data_names,
        label_names = (label_name,),
    )

  if len(data_dir_list)==1 and args.loss_type!=12:
    train_dataiter = FaceImageIter(
        batch_size = args.batch_size,
        data_shape = data_shape,
        path_imgrec = path_imgrec,
        shuffle = True,
        rand_mirror = True,
        mean = mean,
        ctx_num = args.ctx_num,
        images_per_identity = args.images_per_identity,
        data_extra = data_extra,
        hard_mining = hard_mining,
        triplet_params = triplet_params,
        coco_mode = coco_mode,
        mx_model = model,
        label_name = label_name,
    )
  else:
    iter_list = []
    for _data_dir in data_dir_list:
      _path_imgrec = os.path.join(_data_dir, "train.rec")
      _dataiter = FaceImageIter(
          batch_size = args.batch_size,
          data_shape = data_shape,
          path_imgrec = _path_imgrec,
          shuffle = True,
          rand_mirror = True,
          mean = mean,
          ctx_num = args.ctx_num,
          images_per_identity = args.images_per_identity,
          data_extra = data_extra,
          hard_mining = hard_mining,
          triplet_params = triplet_params,
          coco_mode = coco_mode,
          mx_model = model,
          label_name = label_name,
      )
      iter_list.append(_dataiter)
    train_dataiter = FaceImageIterList(iter_list)

  if args.loss_type<10:
    _metric = AccMetric()
  else:
    _metric = LossValueMetric()
  eval_metrics = [mx.metric.create(_metric)]

  if args.network[0]=='r':
    initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
  elif args.network[0]=='i' or args.network[0]=='x':
    initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2) #inception
  else:
    initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)
  _rescale = 1.0/args.ctx_num
  opt = optimizer.SGD(learning_rate=base_lr, momentum=base_mom, wd=base_wd, rescale_grad=_rescale)
  som = 20
  if args.loss_type==12:
    som = 2
  _cb = mx.callback.Speedometer(args.batch_size, som)

  ver_list = []
  ver_name_list = []
  for name in args.target.split(','):
    path = os.path.join(data_dir,name+".bin")
    if os.path.exists(path):
      data_set = verification.load_bin(path, image_size)
      ver_list.append(data_set)
      ver_name_list.append(name)
      print('ver', name)

  def ver_test(nbatch):
    results = []
    for i in xrange(len(ver_list)):
      acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(ver_list[i], model, args.batch_size, data_extra)
      print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
      #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
      print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc2, std2))
      results.append(acc2)
    return results

  def val_test():
    acc = AccMetric()
    val_metric = mx.metric.create(acc)
    val_metric.reset()
    val_dataiter.reset()
    for i, eval_batch in enumerate(val_dataiter):
      model.forward(eval_batch, is_train=False)
      model.update_metric(val_metric, eval_batch.label)
    acc_value = val_metric.get_name_value()[0][1]
    print('VACC: %f'%(acc_value))

  highest_acc = [0.0, 0.0] #lfw and target
  #for i in xrange(len(ver_list)):
  #  highest_acc.append(0.0)
  global_step = [0]
  save_step = [0]
  if len(args.lr_steps)==0:
    lr_steps = [40000, 60000, 80000]
    if args.loss_type>=1 and args.loss_type<=5:
      lr_steps = [100000, 140000, 160000]
    p = 512.0/args.batch_size
    for l in xrange(len(lr_steps)):
      lr_steps[l] = int(lr_steps[l]*p)
  else:
    lr_steps = [int(x) for x in args.lr_steps.split(',')]
  print('lr_steps', lr_steps)

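  # Per-batch callback: steps the learning rate down by 10x at each entry of
  # lr_steps (offset by beta_freeze), runs the verification sets every
  # args.verbose batches, saves a checkpoint when the accuracy criteria and
  # args.ckpt policy allow it, and anneals the LSoftmax beta through the
  # BETA environment variable.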
  def _batch_callback(param):
    #global global_step
    global_step[0]+=1
    mbatch = global_step[0]
    for _lr in lr_steps:
      if mbatch==args.beta_freeze+_lr:
        opt.lr *= 0.1
        print('lr change to', opt.lr)
        break
    _cb(param)
    if mbatch%1000==0:
      print('lr-batch-epoch:',opt.lr,param.nbatch,param.epoch)
    if mbatch>=0 and mbatch%args.verbose==0:
      acc_list = ver_test(mbatch)
      save_step[0]+=1
      msave = save_step[0]
      do_save = False
      lfw_score = acc_list[0]
      if lfw_score>highest_acc[0]:
        highest_acc[0] = lfw_score
        if lfw_score>=0.998:
          do_save = True
      if acc_list[-1]>=highest_acc[-1]:
        highest_acc[-1] = acc_list[-1]
        if lfw_score>=0.99:
          do_save = True
      if args.ckpt==0:
        do_save = False
      elif args.ckpt>1:
        do_save = True
      #for i in xrange(len(acc_list)):
      #  acc = acc_list[i]
      #  if acc>=highest_acc[i]:
      #    highest_acc[i] = acc
      #    if lfw_score>=0.99:
      #      do_save = True
      #if args.loss_type==1 and mbatch>lr_steps[-1] and mbatch%10000==0:
      #  do_save = True
      if do_save:
        print('saving', msave)
        if val_dataiter is not None:
          val_test()
        arg, aux = model.get_params()
        mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
        #if acc>=highest_acc[0]:
        #  lfw_npy = "%s-lfw-%04d" % (prefix, msave)
        #  X = np.concatenate(embeddings_list, axis=0)
        #  print('saving lfw npy', X.shape)
        #  np.save(lfw_npy, X)
      print('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[-1]))
    if mbatch<=args.beta_freeze:
      _beta = args.beta
    else:
      move = max(0, mbatch-args.beta_freeze)
      _beta = max(args.beta_min, args.beta*math.pow(1+args.gamma*move, -1.0*args.power))
    #print('beta', _beta)
    os.environ['BETA'] = str(_beta)

  #epoch_cb = mx.callback.do_checkpoint(prefix, 1)
  epoch_cb = None
  #def _epoch_callback(epoch, sym, arg_params, aux_params):
  #  print('epoch-end', epoch)

  model.fit(train_dataiter,
      begin_epoch = begin_epoch,
      num_epoch = end_epoch,
      eval_data = val_dataiter,
      eval_metric = eval_metrics,
      kvstore = 'device',
      optimizer = opt,
      #optimizer_params = optimizer_params,
      initializer = initializer,
      arg_params = arg_params,
      aux_params = aux_params,
      allow_missing = True,
      batch_end_callback = _batch_callback,
      epoch_end_callback = epoch_cb )

def main():
  #time.sleep(3600*6.5)
  args = parse_args()
  train_net(args)

if __name__ == '__main__':
  main()
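
# Example invocation (hypothetical paths; adjust the dataset directory and
# model prefix to your setup):
#   CUDA_VISIBLE_DEVICES='0,1,2,3' python -u train_softmax.py --network r50 \
#     --loss-type 4 --margin-m 0.5 --data-dir ../datasets/faces_ms1m_112x112 \
#     --prefix ../model/model-r50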