mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-17 22:27:54 +00:00
add fc7-no-bias option
This commit is contained in:
@@ -104,6 +104,7 @@ def parse_args():
|
||||
parser.add_argument('--wd', type=float, default=0.0005, help='weight decay')
|
||||
parser.add_argument('--fc7-wd-mult', type=float, default=1.0, help='weight decay mult for fc7')
|
||||
parser.add_argument('--fc7-lr-mult', type=float, default=1.0, help='lr mult for fc7')
|
||||
parser.add_argument("--fc7-no-bias", default=False, action="store_true" , help="fc7 no bias flag")
|
||||
parser.add_argument('--bn-mom', type=float, default=0.9, help='bn mom')
|
||||
parser.add_argument('--mom', type=float, default=0.9, help='momentum')
|
||||
parser.add_argument('--emb-size', type=int, default=512, help='embedding length')
|
||||
@@ -178,8 +179,11 @@ def get_symbol(args, arg_params, aux_params):
|
||||
extra_loss = None
|
||||
_weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=args.fc7_lr_mult, wd_mult=args.fc7_wd_mult)
|
||||
if args.loss_type==0: #softmax
|
||||
_bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
|
||||
fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, bias = _bias, num_hidden=args.num_classes, name='fc7')
|
||||
if args.fc7_no_bias:
|
||||
fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
|
||||
else:
|
||||
_bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
|
||||
fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, bias = _bias, num_hidden=args.num_classes, name='fc7')
|
||||
elif args.loss_type==1: #sphere
|
||||
_weight = mx.symbol.L2Normalization(_weight, mode='instance')
|
||||
fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes,
|
||||
|
||||
Reference in New Issue
Block a user