Update train_parall.py

This commit is contained in:
JiankangDeng
2020-08-10 12:00:12 +01:00
committed by GitHub
parent bcc420c605
commit 6d5c538eea

View File

@@ -94,7 +94,7 @@ def get_symbol_arcface(args):
if config.loss_m1!=1.0 or config.loss_m2!=0.0 or config.loss_m3!=0.0:
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.ctx_num_classes, on_value = 1.0, off_value = 0.0)
if config.loss_m1==1.0 and config.loss_m2==0.0:
_one_hot = gt_one_hot*args.margin_b
_one_hot = gt_one_hot*config.loss_m3
fc7 = fc7-_one_hot
else:
fc7_onehot = fc7 * gt_one_hot