This commit is contained in:
nttstar
2018-11-16 14:49:26 +08:00
parent a0db189133
commit 7d53fc82a4

View File

@@ -53,18 +53,18 @@ class AccMetric(mx.metric.EvalMetric):
def update(self, labels, preds):
self.count+=1
preds = [preds[1]] #use softmax output
for label, pred_label in zip(labels, preds):
if pred_label.shape != label.shape:
pred_label = mx.ndarray.argmax(pred_label, axis=self.axis)
pred_label = pred_label.asnumpy().astype('int32').flatten()
label = label.asnumpy()
if label.ndim==2:
label = label[:,0]
label = label.astype('int32').flatten()
assert label.shape==pred_label.shape
self.sum_metric += (pred_label.flat == label.flat).sum()
self.num_inst += len(pred_label.flat)
label = labels[0]
pred_label = preds[1]
if pred_label.shape != label.shape:
pred_label = mx.ndarray.argmax(pred_label, axis=self.axis)
pred_label = pred_label.asnumpy().astype('int32').flatten()
label = label.asnumpy()
if label.ndim==2:
label = label[:,0]
label = label.astype('int32').flatten()
assert label.shape==pred_label.shape
self.sum_metric += (pred_label.flat == label.flat).sum()
self.num_inst += len(pred_label.flat)
class LossValueMetric(mx.metric.EvalMetric):
def __init__(self):
@@ -277,6 +277,7 @@ def get_symbol(args, arg_params, aux_params):
s = args.margin_s
m = args.margin_m
assert s>0.0
assert args.margin_b>0.0
_weight = mx.symbol.L2Normalization(_weight, mode='instance')
nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
@@ -286,7 +287,7 @@ def get_symbol(args, arg_params, aux_params):
intra_loss = t/np.pi
intra_loss = mx.sym.mean(intra_loss)
#intra_loss = mx.sym.exp(cos_t*-1.0)
intra_loss = mx.sym.MakeLoss(intra_loss, name='intra_loss', grad_scale = 1)
intra_loss = mx.sym.MakeLoss(intra_loss, name='intra_loss', grad_scale = args.margin_b)
if m>0.0:
t = t+m
body = mx.sym.cos(t)
@@ -300,27 +301,36 @@ def get_symbol(args, arg_params, aux_params):
s = args.margin_s
m = args.margin_m
assert s>0.0
assert args.margin_b>0.0
assert args.margin_a>0.0
_weight = mx.symbol.L2Normalization(_weight, mode='instance')
nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
zy = mx.sym.pick(fc7, gt_label, axis=1)
cos_t = zy/s
t = mx.sym.arccos(cos_t)
intra_loss = t/np.pi
intra_loss = mx.sym.MakeLoss(intra_loss, name='intra_loss', grad_scale = 0.1)
#idx = mx.sym.arange(0, args.num_classes)
#idx = mx.sym.random.shuffle(idx)
#idx = mx.sym.slice(data=idx, begin=0, end=100)
#counter_weight = mx.sym.take(_weight, indices=idx, axis=1)
#counter_weight = mx.sym.pick(_weight, gt_label, axis=1)
counter_weight = mx.sym.take(_weight, gt_label, axis=1)
counter_cos = mx.sym.dot(counter_weight, _weight, transpose_a=True)
counter_angle = mx.sym.arccos(counter_cos)
inter_loss = counter_angle/np.pi #[0,1]
inter_loss = inter_loss*-1.0 # [-1,0]
inter_loss = inter_loss+1.0 # [0,1]
#counter_weight = mx.sym.take(_weight, gt_label, axis=1)
#counter_cos = mx.sym.dot(counter_weight, _weight, transpose_a=True)
counter_weight = mx.sym.take(_weight, gt_label, axis=0)
counter_cos = mx.sym.dot(counter_weight, _weight, transpose_b=True)
#counter_cos = mx.sym.minimum(counter_cos, 1.0)
#counter_angle = mx.sym.arccos(counter_cos)
#counter_angle = counter_angle * -1.0
#counter_angle = counter_angle/np.pi #[0,1]
#inter_loss = mx.sym.exp(counter_angle)
#counter_cos = mx.sym.dot(_weight, _weight, transpose_b=True)
#counter_cos = mx.sym.minimum(counter_cos, 1.0)
#counter_angle = mx.sym.arccos(counter_cos)
#counter_angle = mx.sym.sort(counter_angle, axis=1)
#counter_angle = mx.sym.slice_axis(counter_angle, axis=1, begin=0,end=int(args.margin_a))
#inter_loss = counter_angle*-1.0 # [-1,0]
#inter_loss = inter_loss+1.0 # [0,1]
inter_loss = counter_cos
inter_loss = mx.sym.mean(inter_loss)
inter_loss = mx.sym.MakeLoss(inter_loss, name='inter_loss', grad_scale = 0.1)
inter_loss = mx.sym.MakeLoss(inter_loss, name='inter_loss', grad_scale = args.margin_b)
if m>0.0:
t = t+m
body = mx.sym.cos(t)
@@ -337,6 +347,7 @@ def get_symbol(args, arg_params, aux_params):
out_list.append(intra_loss)
if args.loss_type==7:
out_list.append(inter_loss)
#out_list.append(mx.sym.BlockGrad(counter_weight))
#out_list.append(intra_loss)
if args.ce_loss:
#ce_loss = mx.symbol.softmax_cross_entropy(data=fc7, label = gt_label, name='ce_loss')/args.per_batch_size