mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-15 21:23:52 +00:00
tiny fix
This commit is contained in:
@@ -53,18 +53,18 @@ class AccMetric(mx.metric.EvalMetric):
|
||||
|
||||
def update(self, labels, preds):
|
||||
self.count+=1
|
||||
preds = [preds[1]] #use softmax output
|
||||
for label, pred_label in zip(labels, preds):
|
||||
if pred_label.shape != label.shape:
|
||||
pred_label = mx.ndarray.argmax(pred_label, axis=self.axis)
|
||||
pred_label = pred_label.asnumpy().astype('int32').flatten()
|
||||
label = label.asnumpy()
|
||||
if label.ndim==2:
|
||||
label = label[:,0]
|
||||
label = label.astype('int32').flatten()
|
||||
assert label.shape==pred_label.shape
|
||||
self.sum_metric += (pred_label.flat == label.flat).sum()
|
||||
self.num_inst += len(pred_label.flat)
|
||||
label = labels[0]
|
||||
pred_label = preds[1]
|
||||
if pred_label.shape != label.shape:
|
||||
pred_label = mx.ndarray.argmax(pred_label, axis=self.axis)
|
||||
pred_label = pred_label.asnumpy().astype('int32').flatten()
|
||||
label = label.asnumpy()
|
||||
if label.ndim==2:
|
||||
label = label[:,0]
|
||||
label = label.astype('int32').flatten()
|
||||
assert label.shape==pred_label.shape
|
||||
self.sum_metric += (pred_label.flat == label.flat).sum()
|
||||
self.num_inst += len(pred_label.flat)
|
||||
|
||||
class LossValueMetric(mx.metric.EvalMetric):
|
||||
def __init__(self):
|
||||
@@ -277,6 +277,7 @@ def get_symbol(args, arg_params, aux_params):
|
||||
s = args.margin_s
|
||||
m = args.margin_m
|
||||
assert s>0.0
|
||||
assert args.margin_b>0.0
|
||||
_weight = mx.symbol.L2Normalization(_weight, mode='instance')
|
||||
nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
|
||||
fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
|
||||
@@ -286,7 +287,7 @@ def get_symbol(args, arg_params, aux_params):
|
||||
intra_loss = t/np.pi
|
||||
intra_loss = mx.sym.mean(intra_loss)
|
||||
#intra_loss = mx.sym.exp(cos_t*-1.0)
|
||||
intra_loss = mx.sym.MakeLoss(intra_loss, name='intra_loss', grad_scale = 1)
|
||||
intra_loss = mx.sym.MakeLoss(intra_loss, name='intra_loss', grad_scale = args.margin_b)
|
||||
if m>0.0:
|
||||
t = t+m
|
||||
body = mx.sym.cos(t)
|
||||
@@ -300,27 +301,36 @@ def get_symbol(args, arg_params, aux_params):
|
||||
s = args.margin_s
|
||||
m = args.margin_m
|
||||
assert s>0.0
|
||||
assert args.margin_b>0.0
|
||||
assert args.margin_a>0.0
|
||||
_weight = mx.symbol.L2Normalization(_weight, mode='instance')
|
||||
nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
|
||||
fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
|
||||
zy = mx.sym.pick(fc7, gt_label, axis=1)
|
||||
cos_t = zy/s
|
||||
t = mx.sym.arccos(cos_t)
|
||||
intra_loss = t/np.pi
|
||||
intra_loss = mx.sym.MakeLoss(intra_loss, name='intra_loss', grad_scale = 0.1)
|
||||
#idx = mx.sym.arange(0, args.num_classes)
|
||||
#idx = mx.sym.random.shuffle(idx)
|
||||
#idx = mx.sym.slice(data=idx, begin=0, end=100)
|
||||
#counter_weight = mx.sym.take(_weight, indices=idx, axis=1)
|
||||
#counter_weight = mx.sym.pick(_weight, gt_label, axis=1)
|
||||
counter_weight = mx.sym.take(_weight, gt_label, axis=1)
|
||||
counter_cos = mx.sym.dot(counter_weight, _weight, transpose_a=True)
|
||||
counter_angle = mx.sym.arccos(counter_cos)
|
||||
inter_loss = counter_angle/np.pi #[0,1]
|
||||
inter_loss = inter_loss*-1.0 # [-1,0]
|
||||
inter_loss = inter_loss+1.0 # [0,1]
|
||||
|
||||
#counter_weight = mx.sym.take(_weight, gt_label, axis=1)
|
||||
#counter_cos = mx.sym.dot(counter_weight, _weight, transpose_a=True)
|
||||
counter_weight = mx.sym.take(_weight, gt_label, axis=0)
|
||||
counter_cos = mx.sym.dot(counter_weight, _weight, transpose_b=True)
|
||||
#counter_cos = mx.sym.minimum(counter_cos, 1.0)
|
||||
#counter_angle = mx.sym.arccos(counter_cos)
|
||||
#counter_angle = counter_angle * -1.0
|
||||
#counter_angle = counter_angle/np.pi #[0,1]
|
||||
#inter_loss = mx.sym.exp(counter_angle)
|
||||
|
||||
#counter_cos = mx.sym.dot(_weight, _weight, transpose_b=True)
|
||||
#counter_cos = mx.sym.minimum(counter_cos, 1.0)
|
||||
#counter_angle = mx.sym.arccos(counter_cos)
|
||||
#counter_angle = mx.sym.sort(counter_angle, axis=1)
|
||||
#counter_angle = mx.sym.slice_axis(counter_angle, axis=1, begin=0,end=int(args.margin_a))
|
||||
|
||||
#inter_loss = counter_angle*-1.0 # [-1,0]
|
||||
#inter_loss = inter_loss+1.0 # [0,1]
|
||||
inter_loss = counter_cos
|
||||
inter_loss = mx.sym.mean(inter_loss)
|
||||
inter_loss = mx.sym.MakeLoss(inter_loss, name='inter_loss', grad_scale = 0.1)
|
||||
inter_loss = mx.sym.MakeLoss(inter_loss, name='inter_loss', grad_scale = args.margin_b)
|
||||
if m>0.0:
|
||||
t = t+m
|
||||
body = mx.sym.cos(t)
|
||||
@@ -337,6 +347,7 @@ def get_symbol(args, arg_params, aux_params):
|
||||
out_list.append(intra_loss)
|
||||
if args.loss_type==7:
|
||||
out_list.append(inter_loss)
|
||||
#out_list.append(mx.sym.BlockGrad(counter_weight))
|
||||
#out_list.append(intra_loss)
|
||||
if args.ce_loss:
|
||||
#ce_loss = mx.symbol.softmax_cross_entropy(data=fc7, label = gt_label, name='ce_loss')/args.per_batch_size
|
||||
|
||||
Reference in New Issue
Block a user