mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-18 06:38:19 +00:00
add margin stat
This commit is contained in:
@@ -47,8 +47,13 @@ class AccMetric(mx.metric.EvalMetric):
|
||||
'acc', axis=self.axis,
|
||||
output_names=None, label_names=None)
|
||||
self.losses = []
|
||||
self.count = 0
|
||||
|
||||
def update(self, labels, preds):
|
||||
if self.count%800==0 and len(preds)==4:
|
||||
a = preds[-2].asnumpy()[0]
|
||||
b = preds[-1].asnumpy()[0]
|
||||
print('[MARGIN]%f,%f'%(a,b))
|
||||
#loss = preds[2].asnumpy()[0]
|
||||
#if len(self.losses)==20:
|
||||
# print('ce loss', sum(self.losses)/len(self.losses))
|
||||
@@ -68,6 +73,7 @@ class AccMetric(mx.metric.EvalMetric):
|
||||
assert label.shape==pred_label.shape
|
||||
self.sum_metric += (pred_label.flat == label.flat).sum()
|
||||
self.num_inst += len(pred_label.flat)
|
||||
self.count+=1
|
||||
|
||||
class LossValueMetric(mx.metric.EvalMetric):
|
||||
def __init__(self):
|
||||
@@ -160,6 +166,7 @@ def parse_args():
|
||||
def get_symbol(args, arg_params, aux_params):
|
||||
data_shape = (args.image_channel,args.image_h,args.image_w)
|
||||
image_shape = ",".join([str(x) for x in data_shape])
|
||||
margin_symbols = []
|
||||
if args.network[0]=='d':
|
||||
embedding = fdensenet.get_symbol(args.emb_size, args.num_layers,
|
||||
version_se=args.version_se, version_input=args.version_input,
|
||||
@@ -225,9 +232,17 @@ def get_symbol(args, arg_params, aux_params):
|
||||
nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
|
||||
fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
|
||||
if m>0.0:
|
||||
zy = mx.sym.pick(fc7, gt_label, axis=1)
|
||||
cos_t = zy/s
|
||||
margin_symbols.append(mx.symbol.mean(cos_t))
|
||||
|
||||
s_m = s*m
|
||||
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = s_m, off_value = 0.0)
|
||||
fc7 = fc7-gt_one_hot
|
||||
|
||||
new_zy = mx.sym.pick(fc7, gt_label, axis=1)
|
||||
new_cos_t = new_zy/s
|
||||
margin_symbols.append(mx.symbol.mean(new_cos_t))
|
||||
else:
|
||||
fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
|
||||
if m>0.0:
|
||||
@@ -242,24 +257,27 @@ def get_symbol(args, arg_params, aux_params):
|
||||
elif args.loss_type==3:
|
||||
s = args.margin_s
|
||||
m = args.margin_m
|
||||
assert m==2.0 or m==4.0
|
||||
assert args.margin==2 or args.margin==4
|
||||
_weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
|
||||
_weight = mx.symbol.L2Normalization(_weight, mode='instance')
|
||||
nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
|
||||
fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
|
||||
zy = mx.sym.pick(fc7, gt_label, axis=1)
|
||||
#threshold = math.cos(math.pi/m)
|
||||
threshold = math.cos(1.0)
|
||||
threshold = math.cos(args.margin_m)
|
||||
cos_t = zy/s
|
||||
margin_symbols.append(mx.symbol.mean(cos_t))
|
||||
cond_v = cos_t - threshold
|
||||
cond = mx.symbol.Activation(data=cond_v, act_type='relu')
|
||||
body = cos_t
|
||||
for i in xrange(int(m/2)):
|
||||
for i in xrange(args.margin/2):
|
||||
body = body*body
|
||||
body = body*2-1
|
||||
new_zy = body*s
|
||||
zy_keep = zy
|
||||
new_zy = mx.sym.where(cond, new_zy, zy_keep)
|
||||
new_cos_t = new_zy/s
|
||||
margin_symbols.append(mx.symbol.mean(new_cos_t))
|
||||
diff = new_zy - zy
|
||||
diff = mx.sym.expand_dims(diff, 1)
|
||||
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
|
||||
@@ -286,6 +304,7 @@ def get_symbol(args, arg_params, aux_params):
|
||||
fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
|
||||
zy = mx.sym.pick(fc7, gt_label, axis=1)
|
||||
cos_t = zy/s
|
||||
margin_symbols.append(mx.symbol.mean(cos_t))
|
||||
if args.easy_margin:
|
||||
cond = mx.symbol.Activation(data=cos_t, act_type='relu')
|
||||
#cond_v = cos_t - 0.4
|
||||
@@ -307,6 +326,9 @@ def get_symbol(args, arg_params, aux_params):
|
||||
else:
|
||||
zy_keep = zy - s*mm
|
||||
new_zy = mx.sym.where(cond, new_zy, zy_keep)
|
||||
new_cos_t = new_zy/s
|
||||
margin_symbols.append(mx.symbol.mean(new_cos_t))
|
||||
|
||||
diff = new_zy - zy
|
||||
diff = mx.sym.expand_dims(diff, 1)
|
||||
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
|
||||
@@ -430,6 +452,9 @@ def get_symbol(args, arg_params, aux_params):
|
||||
out_list.append(mx.sym.BlockGrad(gt_label))
|
||||
if extra_loss is not None:
|
||||
out_list.append(extra_loss)
|
||||
for _sym in margin_symbols:
|
||||
_sym = mx.sym.BlockGrad(_sym)
|
||||
out_list.append(_sym)
|
||||
out = mx.symbol.Group(out_list)
|
||||
return (out, arg_params, aux_params)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user