add intra/inter loss

This commit is contained in:
nttstar
2018-11-11 23:01:22 +08:00
parent cfb96d21ec
commit fa9d3d175d
2 changed files with 95 additions and 10 deletions

View File

@@ -32,6 +32,7 @@ class FaceImageIter(io.DataIter):
path_imgrec = None,
shuffle=False, aug_list=None, mean = None,
rand_mirror = False, cutoff = 0, color_jittering = 0,
images_filter = 0,
data_name='data', label_name='softmax_label', **kwargs):
super(FaceImageIter, self).__init__()
assert path_imgrec
@@ -46,15 +47,19 @@ class FaceImageIter(io.DataIter):
print('header0 label', header.label)
self.header0 = (int(header.label[0]), int(header.label[1]))
#assert(header.flag==1)
self.imgidx = range(1, int(header.label[0]))
#self.imgidx = range(1, int(header.label[0]))
self.imgidx = []
self.id2range = {}
self.seq_identity = range(int(header.label[0]), int(header.label[1]))
for identity in self.seq_identity:
s = self.imgrec.read_idx(identity)
header, _ = recordio.unpack(s)
a,b = int(header.label[0]), int(header.label[1])
self.id2range[identity] = (a,b)
count = b-a
if count<images_filter:
continue
self.id2range[identity] = (a,b)
self.imgidx += range(a, b)
print('id2range', len(self.id2range))
else:
self.imgidx = list(self.imgrec.keys)

View File

@@ -125,6 +125,7 @@ def parse_args():
parser.add_argument('--rand-mirror', type=int, default=1, help='if do random mirror in training')
parser.add_argument('--cutoff', type=int, default=0, help='cut off aug')
parser.add_argument('--color', type=int, default=0, help='color jittering aug')
parser.add_argument('--images-filter', type=int, default=0, help='minimum images per identity filter')
parser.add_argument('--target', type=str, default='lfw,cfp_fp,agedb_30', help='verification targets')
parser.add_argument('--ce-loss', default=False, action='store_true', help='if output ce loss')
args = parser.parse_args()
@@ -272,9 +273,71 @@ def get_symbol(args, arg_params, aux_params):
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
body = mx.sym.broadcast_mul(gt_one_hot, diff)
fc7 = fc7+body
elif args.loss_type==6:
s = args.margin_s
m = args.margin_m
assert s>0.0
_weight = mx.symbol.L2Normalization(_weight, mode='instance')
nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
zy = mx.sym.pick(fc7, gt_label, axis=1)
cos_t = zy/s
t = mx.sym.arccos(cos_t)
intra_loss = t/np.pi
intra_loss = mx.sym.mean(intra_loss)
#intra_loss = mx.sym.exp(cos_t*-1.0)
intra_loss = mx.sym.MakeLoss(intra_loss, name='intra_loss', grad_scale = 1)
if m>0.0:
t = t+m
body = mx.sym.cos(t)
new_zy = body*s
diff = new_zy - zy
diff = mx.sym.expand_dims(diff, 1)
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
body = mx.sym.broadcast_mul(gt_one_hot, diff)
fc7 = fc7+body
elif args.loss_type==7:
s = args.margin_s
m = args.margin_m
assert s>0.0
_weight = mx.symbol.L2Normalization(_weight, mode='instance')
nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
zy = mx.sym.pick(fc7, gt_label, axis=1)
cos_t = zy/s
t = mx.sym.arccos(cos_t)
intra_loss = t/np.pi
intra_loss = mx.sym.MakeLoss(intra_loss, name='intra_loss', grad_scale = 0.1)
#idx = mx.sym.arange(0, args.num_classes)
#idx = mx.sym.random.shuffle(idx)
#idx = mx.sym.slice(data=idx, begin=0, end=100)
#counter_weight = mx.sym.take(_weight, indices=idx, axis=1)
#counter_weight = mx.sym.pick(_weight, gt_label, axis=1)
counter_weight = mx.sym.take(_weight, gt_label, axis=1)
counter_cos = mx.sym.dot(counter_weight, _weight, transpose_a=True)
counter_angle = mx.sym.arccos(counter_cos)
inter_loss = counter_angle/np.pi #[0,1]
inter_loss = inter_loss*-1.0 # [-1,0]
inter_loss = inter_loss+1.0 # [0,1]
inter_loss = mx.sym.mean(inter_loss)
inter_loss = mx.sym.MakeLoss(inter_loss, name='inter_loss', grad_scale = 0.1)
if m>0.0:
t = t+m
body = mx.sym.cos(t)
new_zy = body*s
diff = new_zy - zy
diff = mx.sym.expand_dims(diff, 1)
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
body = mx.sym.broadcast_mul(gt_one_hot, diff)
fc7 = fc7+body
out_list = [mx.symbol.BlockGrad(embedding)]
softmax = mx.symbol.SoftmaxOutput(data=fc7, label = gt_label, name='softmax', normalization='valid')
out_list.append(softmax)
if args.loss_type==6:
out_list.append(intra_loss)
if args.loss_type==7:
out_list.append(inter_loss)
#out_list.append(intra_loss)
if args.ce_loss:
#ce_loss = mx.symbol.softmax_cross_entropy(data=fc7, label = gt_label, name='ce_loss')/args.per_batch_size
body = mx.symbol.SoftmaxActivation(data=fc7)
@@ -346,6 +409,9 @@ def train_net(args):
arg_params = None
aux_params = None
sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
if args.network[0]=='s':
data_shape_dict = {'data' : (args.per_batch_size,)+data_shape}
spherenet.init_weights(sym, data_shape_dict, args.num_layers)
else:
vec = args.pretrained.split(',')
print('loading', vec)
@@ -369,6 +435,7 @@ def train_net(args):
mean = mean,
cutoff = args.cutoff,
color_jittering = args.color,
images_filter = args.images_filter,
)
metric1 = AccMetric()
@@ -446,20 +513,33 @@ def train_net(args):
save_step[0]+=1
msave = save_step[0]
do_save = False
is_highest = False
if len(acc_list)>0:
lfw_score = acc_list[0]
if lfw_score>highest_acc[0]:
highest_acc[0] = lfw_score
if lfw_score>=0.998:
do_save = True
#lfw_score = acc_list[0]
#if lfw_score>highest_acc[0]:
# highest_acc[0] = lfw_score
# if lfw_score>=0.998:
# do_save = True
score = sum(acc_list)
if acc_list[-1]>=highest_acc[-1]:
if acc_list[-1]>highest_acc[-1]:
is_highest = True
else:
if score>=highest_acc[0]:
is_highest = True
highest_acc[0] = score
highest_acc[-1] = acc_list[-1]
if lfw_score>=0.99:
do_save = True
#if lfw_score>=0.99:
# do_save = True
if is_highest:
do_save = True
if args.ckpt==0:
do_save = False
elif args.ckpt>1:
elif args.ckpt==2:
do_save = True
elif args.ckpt==3:
msave = 1
if do_save:
print('saving', msave)
arg, aux = model.get_params()