add fc7-no-bias option

This commit is contained in:
Jia Guo
2018-08-08 10:07:16 +08:00
parent a31f67187f
commit eeffa48eb1

View File

@@ -104,6 +104,7 @@ def parse_args():
parser.add_argument('--wd', type=float, default=0.0005, help='weight decay')
parser.add_argument('--fc7-wd-mult', type=float, default=1.0, help='weight decay mult for fc7')
parser.add_argument('--fc7-lr-mult', type=float, default=1.0, help='lr mult for fc7')
parser.add_argument("--fc7-no-bias", default=False, action="store_true" , help="fc7 no bias flag")
parser.add_argument('--bn-mom', type=float, default=0.9, help='bn mom')
parser.add_argument('--mom', type=float, default=0.9, help='momentum')
parser.add_argument('--emb-size', type=int, default=512, help='embedding length')
@@ -178,8 +179,11 @@ def get_symbol(args, arg_params, aux_params):
extra_loss = None
_weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=args.fc7_lr_mult, wd_mult=args.fc7_wd_mult)
if args.loss_type==0: #softmax
_bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, bias = _bias, num_hidden=args.num_classes, name='fc7')
if args.fc7_no_bias:
fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
else:
_bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, bias = _bias, num_hidden=args.num_classes, name='fc7')
elif args.loss_type==1: #sphere
_weight = mx.symbol.L2Normalization(_weight, mode='instance')
fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes,