add margin stat

This commit is contained in:
Jia Guo
2018-01-21 15:30:39 +08:00
parent c33e285aad
commit b3216fc7ec

View File

@@ -47,8 +47,13 @@ class AccMetric(mx.metric.EvalMetric):
'acc', axis=self.axis,
output_names=None, label_names=None)
self.losses = []
self.count = 0
def update(self, labels, preds):
if self.count%800==0 and len(preds)==4:
a = preds[-2].asnumpy()[0]
b = preds[-1].asnumpy()[0]
print('[MARGIN]%f,%f'%(a,b))
#loss = preds[2].asnumpy()[0]
#if len(self.losses)==20:
# print('ce loss', sum(self.losses)/len(self.losses))
@@ -68,6 +73,7 @@ class AccMetric(mx.metric.EvalMetric):
assert label.shape==pred_label.shape
self.sum_metric += (pred_label.flat == label.flat).sum()
self.num_inst += len(pred_label.flat)
self.count+=1
class LossValueMetric(mx.metric.EvalMetric):
def __init__(self):
@@ -160,6 +166,7 @@ def parse_args():
def get_symbol(args, arg_params, aux_params):
data_shape = (args.image_channel,args.image_h,args.image_w)
image_shape = ",".join([str(x) for x in data_shape])
margin_symbols = []
if args.network[0]=='d':
embedding = fdensenet.get_symbol(args.emb_size, args.num_layers,
version_se=args.version_se, version_input=args.version_input,
@@ -225,9 +232,17 @@ def get_symbol(args, arg_params, aux_params):
nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
if m>0.0:
zy = mx.sym.pick(fc7, gt_label, axis=1)
cos_t = zy/s
margin_symbols.append(mx.symbol.mean(cos_t))
s_m = s*m
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = s_m, off_value = 0.0)
fc7 = fc7-gt_one_hot
new_zy = mx.sym.pick(fc7, gt_label, axis=1)
new_cos_t = new_zy/s
margin_symbols.append(mx.symbol.mean(new_cos_t))
else:
fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
if m>0.0:
@@ -242,24 +257,27 @@ def get_symbol(args, arg_params, aux_params):
elif args.loss_type==3:
s = args.margin_s
m = args.margin_m
assert m==2.0 or m==4.0
assert args.margin==2 or args.margin==4
_weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
_weight = mx.symbol.L2Normalization(_weight, mode='instance')
nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
zy = mx.sym.pick(fc7, gt_label, axis=1)
#threshold = math.cos(math.pi/m)
threshold = math.cos(1.0)
threshold = math.cos(args.margin_m)
cos_t = zy/s
margin_symbols.append(mx.symbol.mean(cos_t))
cond_v = cos_t - threshold
cond = mx.symbol.Activation(data=cond_v, act_type='relu')
body = cos_t
for i in xrange(int(m/2)):
for i in xrange(args.margin/2):
body = body*body
body = body*2-1
new_zy = body*s
zy_keep = zy
new_zy = mx.sym.where(cond, new_zy, zy_keep)
new_cos_t = new_zy/s
margin_symbols.append(mx.symbol.mean(new_cos_t))
diff = new_zy - zy
diff = mx.sym.expand_dims(diff, 1)
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
@@ -286,6 +304,7 @@ def get_symbol(args, arg_params, aux_params):
fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
zy = mx.sym.pick(fc7, gt_label, axis=1)
cos_t = zy/s
margin_symbols.append(mx.symbol.mean(cos_t))
if args.easy_margin:
cond = mx.symbol.Activation(data=cos_t, act_type='relu')
#cond_v = cos_t - 0.4
@@ -307,6 +326,9 @@ def get_symbol(args, arg_params, aux_params):
else:
zy_keep = zy - s*mm
new_zy = mx.sym.where(cond, new_zy, zy_keep)
new_cos_t = new_zy/s
margin_symbols.append(mx.symbol.mean(new_cos_t))
diff = new_zy - zy
diff = mx.sym.expand_dims(diff, 1)
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
@@ -430,6 +452,9 @@ def get_symbol(args, arg_params, aux_params):
out_list.append(mx.sym.BlockGrad(gt_label))
if extra_loss is not None:
out_list.append(extra_loss)
for _sym in margin_symbols:
_sym = mx.sym.BlockGrad(_sym)
out_list.append(_sym)
out = mx.symbol.Group(out_list)
return (out, arg_params, aux_params)