From b3216fc7ecfbd2b73c3de059b277286dc751a5cd Mon Sep 17 00:00:00 2001 From: Jia Guo Date: Sun, 21 Jan 2018 15:30:39 +0800 Subject: [PATCH] add margin stat --- src/train_softmax.py | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/src/train_softmax.py b/src/train_softmax.py index b7463e8..8220ef0 100644 --- a/src/train_softmax.py +++ b/src/train_softmax.py @@ -47,8 +47,13 @@ class AccMetric(mx.metric.EvalMetric): 'acc', axis=self.axis, output_names=None, label_names=None) self.losses = [] + self.count = 0 def update(self, labels, preds): + if self.count%800==0 and len(preds)==4: + a = preds[-2].asnumpy()[0] + b = preds[-1].asnumpy()[0] + print('[MARGIN]%f,%f'%(a,b)) #loss = preds[2].asnumpy()[0] #if len(self.losses)==20: # print('ce loss', sum(self.losses)/len(self.losses)) @@ -68,6 +73,7 @@ class AccMetric(mx.metric.EvalMetric): assert label.shape==pred_label.shape self.sum_metric += (pred_label.flat == label.flat).sum() self.num_inst += len(pred_label.flat) + self.count+=1 class LossValueMetric(mx.metric.EvalMetric): def __init__(self): @@ -160,6 +166,7 @@ def parse_args(): def get_symbol(args, arg_params, aux_params): data_shape = (args.image_channel,args.image_h,args.image_w) image_shape = ",".join([str(x) for x in data_shape]) + margin_symbols = [] if args.network[0]=='d': embedding = fdensenet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, @@ -225,9 +232,17 @@ def get_symbol(args, arg_params, aux_params): nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7') if m>0.0: + zy = mx.sym.pick(fc7, gt_label, axis=1) + cos_t = zy/s + margin_symbols.append(mx.symbol.mean(cos_t)) + s_m = s*m gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = s_m, off_value = 0.0) fc7 = fc7-gt_one_hot + + new_zy = mx.sym.pick(fc7, gt_label, axis=1) + new_cos_t = new_zy/s + margin_symbols.append(mx.symbol.mean(new_cos_t)) else: fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7') if m>0.0: @@ -242,24 +257,27 @@ def get_symbol(args, arg_params, aux_params): elif args.loss_type==3: s = args.margin_s m = args.margin_m - assert m==2.0 or m==4.0 + assert args.margin==2 or args.margin==4 _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0) _weight = mx.symbol.L2Normalization(_weight, mode='instance') nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7') zy = mx.sym.pick(fc7, gt_label, axis=1) #threshold = math.cos(math.pi/m) - threshold = math.cos(1.0) + threshold = math.cos(args.margin_m) cos_t = zy/s + margin_symbols.append(mx.symbol.mean(cos_t)) cond_v = cos_t - threshold cond = mx.symbol.Activation(data=cond_v, act_type='relu') body = cos_t - for i in xrange(int(m/2)): + for i in xrange(args.margin/2): body = body*body body = body*2-1 new_zy = body*s zy_keep = zy new_zy = mx.sym.where(cond, new_zy, zy_keep) + new_cos_t = new_zy/s + margin_symbols.append(mx.symbol.mean(new_cos_t)) diff = new_zy - zy diff = mx.sym.expand_dims(diff, 1) gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0) @@ -286,6 +304,7 @@ def get_symbol(args, arg_params, aux_params): fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7') zy = mx.sym.pick(fc7, gt_label, axis=1) cos_t = zy/s + margin_symbols.append(mx.symbol.mean(cos_t)) if args.easy_margin: cond = mx.symbol.Activation(data=cos_t, act_type='relu') #cond_v = cos_t - 0.4 @@ -307,6 +326,9 @@ def get_symbol(args, arg_params, aux_params): else: zy_keep = zy - s*mm new_zy = mx.sym.where(cond, new_zy, zy_keep) + new_cos_t = new_zy/s + margin_symbols.append(mx.symbol.mean(new_cos_t)) + diff = new_zy - zy diff = mx.sym.expand_dims(diff, 1) gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0) @@ -430,6 +452,9 @@ def get_symbol(args, arg_params, aux_params): out_list.append(mx.sym.BlockGrad(gt_label)) if extra_loss is not None: out_list.append(extra_loss) + for _sym in margin_symbols: + _sym = mx.sym.BlockGrad(_sym) + out_list.append(_sym) out = mx.symbol.Group(out_list) return (out, arg_params, aux_params)