diff --git a/recognition/ArcFace/train_parall.py b/recognition/ArcFace/train_parall.py index 1f50733..d1ecc82 100644 --- a/recognition/ArcFace/train_parall.py +++ b/recognition/ArcFace/train_parall.py @@ -94,7 +94,7 @@ def get_symbol_arcface(args): if config.loss_m1!=1.0 or config.loss_m2!=0.0 or config.loss_m3!=0.0: gt_one_hot = mx.sym.one_hot(gt_label, depth = args.ctx_num_classes, on_value = 1.0, off_value = 0.0) if config.loss_m1==1.0 and config.loss_m2==0.0: - _one_hot = gt_one_hot*args.margin_b + _one_hot = gt_one_hot*config.loss_m3 fc7 = fc7-_one_hot else: fc7_onehot = fc7 * gt_one_hot