mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-17 06:12:02 +00:00
Update train_parall.py
This commit is contained in:
@@ -94,7 +94,7 @@ def get_symbol_arcface(args):
|
||||
if config.loss_m1!=1.0 or config.loss_m2!=0.0 or config.loss_m3!=0.0:
|
||||
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.ctx_num_classes, on_value = 1.0, off_value = 0.0)
|
||||
if config.loss_m1==1.0 and config.loss_m2==0.0:
|
||||
_one_hot = gt_one_hot*args.margin_b
|
||||
_one_hot = gt_one_hot*config.loss_m3
|
||||
fc7 = fc7-_one_hot
|
||||
else:
|
||||
fc7_onehot = fc7 * gt_one_hot
|
||||
|
||||
Reference in New Issue
Block a user