From 7d53fc82a4c5f20ab2ca49d6080929d465b8aec2 Mon Sep 17 00:00:00 2001 From: nttstar Date: Fri, 16 Nov 2018 14:49:26 +0800 Subject: [PATCH] tiny fix --- src/train_softmax.py | 65 ++++++++++++++++++++++++++------------------ 1 file changed, 38 insertions(+), 27 deletions(-) diff --git a/src/train_softmax.py b/src/train_softmax.py index 5716ae7..3e34591 100644 --- a/src/train_softmax.py +++ b/src/train_softmax.py @@ -53,18 +53,18 @@ class AccMetric(mx.metric.EvalMetric): def update(self, labels, preds): self.count+=1 - preds = [preds[1]] #use softmax output - for label, pred_label in zip(labels, preds): - if pred_label.shape != label.shape: - pred_label = mx.ndarray.argmax(pred_label, axis=self.axis) - pred_label = pred_label.asnumpy().astype('int32').flatten() - label = label.asnumpy() - if label.ndim==2: - label = label[:,0] - label = label.astype('int32').flatten() - assert label.shape==pred_label.shape - self.sum_metric += (pred_label.flat == label.flat).sum() - self.num_inst += len(pred_label.flat) + label = labels[0] + pred_label = preds[1] + if pred_label.shape != label.shape: + pred_label = mx.ndarray.argmax(pred_label, axis=self.axis) + pred_label = pred_label.asnumpy().astype('int32').flatten() + label = label.asnumpy() + if label.ndim==2: + label = label[:,0] + label = label.astype('int32').flatten() + assert label.shape==pred_label.shape + self.sum_metric += (pred_label.flat == label.flat).sum() + self.num_inst += len(pred_label.flat) class LossValueMetric(mx.metric.EvalMetric): def __init__(self): @@ -277,6 +277,7 @@ def get_symbol(args, arg_params, aux_params): s = args.margin_s m = args.margin_m assert s>0.0 + assert args.margin_b>0.0 _weight = mx.symbol.L2Normalization(_weight, mode='instance') nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7') @@ -286,7 +287,7 @@ def get_symbol(args, arg_params, aux_params): intra_loss = t/np.pi intra_loss = mx.sym.mean(intra_loss) #intra_loss = mx.sym.exp(cos_t*-1.0) - intra_loss = mx.sym.MakeLoss(intra_loss, name='intra_loss', grad_scale = 1) + intra_loss = mx.sym.MakeLoss(intra_loss, name='intra_loss', grad_scale = args.margin_b) if m>0.0: t = t+m body = mx.sym.cos(t) @@ -300,27 +301,36 @@ def get_symbol(args, arg_params, aux_params): s = args.margin_s m = args.margin_m assert s>0.0 + assert args.margin_b>0.0 + assert args.margin_a>0.0 _weight = mx.symbol.L2Normalization(_weight, mode='instance') nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7') zy = mx.sym.pick(fc7, gt_label, axis=1) cos_t = zy/s t = mx.sym.arccos(cos_t) - intra_loss = t/np.pi - intra_loss = mx.sym.MakeLoss(intra_loss, name='intra_loss', grad_scale = 0.1) - #idx = mx.sym.arange(0, args.num_classes) - #idx = mx.sym.random.shuffle(idx) - #idx = mx.sym.slice(data=idx, begin=0, end=100) - #counter_weight = mx.sym.take(_weight, indices=idx, axis=1) - #counter_weight = mx.sym.pick(_weight, gt_label, axis=1) - counter_weight = mx.sym.take(_weight, gt_label, axis=1) - counter_cos = mx.sym.dot(counter_weight, _weight, transpose_a=True) - counter_angle = mx.sym.arccos(counter_cos) - inter_loss = counter_angle/np.pi #[0,1] - inter_loss = inter_loss*-1.0 # [-1,0] - inter_loss = inter_loss+1.0 # [0,1] + + #counter_weight = mx.sym.take(_weight, gt_label, axis=1) + #counter_cos = mx.sym.dot(counter_weight, _weight, transpose_a=True) + counter_weight = mx.sym.take(_weight, gt_label, axis=0) + counter_cos = mx.sym.dot(counter_weight, _weight, transpose_b=True) + #counter_cos = mx.sym.minimum(counter_cos, 1.0) + #counter_angle = mx.sym.arccos(counter_cos) + #counter_angle = counter_angle * -1.0 + #counter_angle = counter_angle/np.pi #[0,1] + #inter_loss = mx.sym.exp(counter_angle) + + #counter_cos = mx.sym.dot(_weight, _weight, transpose_b=True) + #counter_cos = mx.sym.minimum(counter_cos, 1.0) + #counter_angle = mx.sym.arccos(counter_cos) + #counter_angle = mx.sym.sort(counter_angle, axis=1) + #counter_angle = mx.sym.slice_axis(counter_angle, axis=1, begin=0,end=int(args.margin_a)) + + #inter_loss = counter_angle*-1.0 # [-1,0] + #inter_loss = inter_loss+1.0 # [0,1] + inter_loss = counter_cos inter_loss = mx.sym.mean(inter_loss) - inter_loss = mx.sym.MakeLoss(inter_loss, name='inter_loss', grad_scale = 0.1) + inter_loss = mx.sym.MakeLoss(inter_loss, name='inter_loss', grad_scale = args.margin_b) if m>0.0: t = t+m body = mx.sym.cos(t) @@ -337,6 +347,7 @@ def get_symbol(args, arg_params, aux_params): out_list.append(intra_loss) if args.loss_type==7: out_list.append(inter_loss) + #out_list.append(mx.sym.BlockGrad(counter_weight)) #out_list.append(intra_loss) if args.ce_loss: #ce_loss = mx.symbol.softmax_cross_entropy(data=fc7, label = gt_label, name='ce_loss')/args.per_batch_size