diff --git a/src/train_softmax.py b/src/train_softmax.py index a222f29..4f0b92d 100644 --- a/src/train_softmax.py +++ b/src/train_softmax.py @@ -104,6 +104,7 @@ def parse_args(): parser.add_argument('--wd', type=float, default=0.0005, help='weight decay') parser.add_argument('--fc7-wd-mult', type=float, default=1.0, help='weight decay mult for fc7') parser.add_argument('--fc7-lr-mult', type=float, default=1.0, help='lr mult for fc7') + parser.add_argument("--fc7-no-bias", default=False, action="store_true" , help="fc7 no bias flag") parser.add_argument('--bn-mom', type=float, default=0.9, help='bn mom') parser.add_argument('--mom', type=float, default=0.9, help='momentum') parser.add_argument('--emb-size', type=int, default=512, help='embedding length') @@ -178,8 +179,11 @@ def get_symbol(args, arg_params, aux_params): extra_loss = None _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=args.fc7_lr_mult, wd_mult=args.fc7_wd_mult) if args.loss_type==0: #softmax - _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0) - fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, bias = _bias, num_hidden=args.num_classes, name='fc7') + if args.fc7_no_bias: + fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7') + else: + _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0) + fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, bias = _bias, num_hidden=args.num_classes, name='fc7') elif args.loss_type==1: #sphere _weight = mx.symbol.L2Normalization(_weight, mode='instance') fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes,