diff --git a/src/train_softmax.py b/src/train_softmax.py index 18ae4c5..e0edb23 100644 --- a/src/train_softmax.py +++ b/src/train_softmax.py @@ -249,7 +249,7 @@ def get_symbol(args, arg_params, aux_params): zy = mx.sym.pick(fc7, gt_label, axis=1) cos_t = zy/s t = mx.sym.arccos(cos_t) - if args.margin_a>0.0: + if args.margin_a!=1.0: t = t*args.margin_a if args.margin_m>0.0: t = t+args.margin_m