diff --git a/src/symbols/spherenet.py b/src/symbols/spherenet.py index 7405613..29e1fc7 100644 --- a/src/symbols/spherenet.py +++ b/src/symbols/spherenet.py @@ -7,20 +7,22 @@ def conv_main(data, units, filters, workspace): body = data for i in xrange(len(units)): f = filters[i] - _weight = mx.symbol.Variable("conv%d_%d_weight"%(i+1, 1), lr_mult=1.0) - _bias = mx.symbol.Variable("conv%d_%d_bias"%(i+1, 1), lr_mult=2.0, wd_mult=0.0) + _weight = mx.symbol.Variable("conv%d_%d_weight"%(i+1, 1), init=mx.init.Normal(0.01)) + _bias = mx.symbol.Variable("conv%d_%d_bias"%(i+1, 1), lr_mult=2.0, wd_mult=0.0, init=mx.init.Constant(0.0)) body = mx.sym.Convolution(data=body, weight = _weight, bias = _bias, num_filter=f, kernel=(3, 3), stride=(2,2), pad=(1, 1), name= "conv%d_%d"%(i+1, 1), workspace=workspace) body = mx.sym.LeakyReLU(data = body, act_type='prelu', name = "relu%d_%d" % (i+1, 1)) idx = 2 for j in xrange(units[i]): - _body = mx.sym.Convolution(data=body, no_bias=True, num_filter=f, kernel=(3, 3), stride=(1,1), pad=(1, 1), + _weight = mx.symbol.Variable("conv%d_%d_weight"%(i+1, idx), init=mx.init.Normal(0.01)) + _body = mx.sym.Convolution(data=body, weight=_weight, no_bias=True, num_filter=f, kernel=(3, 3), stride=(1,1), pad=(1, 1), name= "conv%d_%d"%(i+1, idx), workspace=workspace) _body = mx.sym.LeakyReLU(data = _body, act_type='prelu', name = "relu%d_%d" % (i+1, idx)) idx+=1 - _body = mx.sym.Convolution(data=_body, no_bias=True, num_filter=f, kernel=(3, 3), stride=(1,1), pad=(1, 1), + _weight = mx.symbol.Variable("conv%d_%d_weight"%(i+1, idx), init=mx.init.Normal(0.01)) + _body = mx.sym.Convolution(data=_body, weight=_weight, no_bias=True, num_filter=f, kernel=(3, 3), stride=(1,1), pad=(1, 1), name= "conv%d_%d"%(i+1, idx), workspace=workspace) _body = mx.sym.LeakyReLU(data = _body, act_type='prelu', name = "relu%d_%d" % (i+1, idx)) idx+=1 @@ -57,30 +59,4 @@ def get_symbol(num_classes, num_layers, conv_workspace=256, **kwargs): fc1 = mx.sym.FullyConnected(data=body, weight=_weight, bias=_bias, num_hidden=num_classes, name='fc1') return fc1 -def init_weights(sym, data_shape_dict, num_layers): - arg_name = sym.list_arguments() - aux_name = sym.list_auxiliary_states() - arg_shape, aaa, aux_shape = sym.infer_shape(**data_shape_dict) - #print(data_shape_dict) - #print(arg_name) - #print(arg_shape) - arg_params = {} - aux_params = None - #print(aaa) - #print(aux_shape) - arg_shape_dict = dict(zip(arg_name, arg_shape)) - aux_shape_dict = dict(zip(aux_name, aux_shape)) - #print(aux_shape) - #print(aux_params) - #print(arg_shape_dict) - for k,v in arg_shape_dict.iteritems(): - if k.startswith('conv') and k.endswith('_weight'): - if not k.find('_1_')>=0: - if num_layers<100: - arg_params[k] = mx.random.normal(0, 0.01, shape=v) - print('init', k) - if k.endswith('_bias'): - arg_params[k] = mx.nd.zeros(shape=v) - print('init', k) - return arg_params, aux_params diff --git a/src/train.py b/src/train.py index 9b627b7..da87b29 100644 --- a/src/train.py +++ b/src/train.py @@ -174,7 +174,7 @@ def parse_args(): parser.add_argument('--patch', type=str, default='0_0_96_112_0',help='') parser.add_argument('--lr-steps', type=str, default='', help='') parser.add_argument('--max-steps', type=int, default=0, help='') - parser.add_argument('--target', type=str, default='lfw,cfp_fp,agedb_30', help='') + parser.add_argument('--target', type=str, default='lfw,cfp_fp,agedb_30,cplfw,calfw', help='') args = parser.parse_args() return args @@ -386,69 +386,35 @@ def get_symbol(args, arg_params, aux_params): body = mx.sym.broadcast_mul(gt_one_hot, diff) fc7 = fc7+body elif args.loss_type==5: - #s = args.margin_s - #m = args.margin_m - #assert s>0.0 - #assert m>=0.0 - #assert m<(math.pi/2) - #_weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0) - #_weight = mx.symbol.L2Normalization(_weight, mode='instance') - #nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s - #fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7') - #zy = mx.sym.pick(fc7, gt_label, axis=1) - #cos_t = zy/s - #if args.margin_verbose>0: - # margin_symbols.append(mx.symbol.mean(cos_t)) - #if m>0.0: - # a1 = args.margin_a - # r1 = ta-a1 - # r1 = mx.symbol.Activation(data=r1, act_type='relu') - # r1 = r1+a1 - # t = mx.sym.arccos(cos_t) - # cond = t-1.0 - # cond = mx.symbol.Activation(data=cond, act_type='relu') - # r = mx.sym.where(cond, r2, r1) - # t = t+var_m - # body = mx.sym.cos(t) - # new_zy = body*s - # if args.margin_verbose>0: - # new_cos_t = new_zy/s - # margin_symbols.append(mx.symbol.mean(new_cos_t)) - # #margin_symbols.append(mx.symbol.mean(var_m)) - # diff = new_zy - zy - # diff = mx.sym.expand_dims(diff, 1) - # gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0) - # body = mx.sym.broadcast_mul(gt_one_hot, diff) - # fc7 = fc7+body s = args.margin_s m = args.margin_m assert s>0.0 - #assert m>=0.0 - #assert m<(math.pi/2) _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0) _weight = mx.symbol.L2Normalization(_weight, mode='instance') nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7') - zy = mx.sym.pick(fc7, gt_label, axis=1) - cos_t = zy/s - t = mx.sym.arccos(cos_t) - if args.margin_verbose>0: - margin_symbols.append(mx.symbol.mean(t)) - if args.margin_a>0.0: - t = t*args.margin_a - if args.margin_m>0.0: - t = t+args.margin_m - body = mx.sym.cos(t) - if args.margin_b>0.0: - body = body - args.margin_b - new_zy = body*s - if args.margin_verbose>0: - margin_symbols.append(mx.symbol.mean(t)) - diff = new_zy - zy - diff = mx.sym.expand_dims(diff, 1) - gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0) - body = mx.sym.broadcast_mul(gt_one_hot, diff) - fc7 = fc7+body + if args.margin_a!=1.0 or args.margin_m!=0.0 or args.margin_b!=0.0: + if args.margin_a==1.0 and args.margin_m==0.0: + s_m = s*args.margin_b + gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = s_m, off_value = 0.0) + fc7 = fc7-gt_one_hot + else: + zy = mx.sym.pick(fc7, gt_label, axis=1) + cos_t = zy/s + t = mx.sym.arccos(cos_t) + if args.margin_a!=1.0: + t = t*args.margin_a + if args.margin_m>0.0: + t = t+args.margin_m + body = mx.sym.cos(t) + if args.margin_b>0.0: + body = body - args.margin_b + new_zy = body*s + diff = new_zy - zy + diff = mx.sym.expand_dims(diff, 1) + gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0) + body = mx.sym.broadcast_mul(gt_one_hot, diff) + fc7 = fc7+body elif args.loss_type==6: s = args.margin_s m = args.margin_m @@ -505,7 +471,7 @@ def get_symbol(args, arg_params, aux_params): t = mx.sym.arccos(cos_t) if args.margin_verbose>0: margin_symbols.append(mx.symbol.mean(t)) - var_m = mx.sym.random.uniform(low=args.margin_a, high=args.margin_m, shape=(1,)) + var_m = mx.sym.random.uniform(low=args.margin_b, high=args.margin_m, shape=(1,)) t = mx.sym.broadcast_add(t,var_m) body = mx.sym.cos(t) new_zy = body*s @@ -775,9 +741,6 @@ def train_net(args): print('loading', vec) _, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1])) sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params) - if args.network[0]=='s': - data_shape_dict = {'data' : (args.per_batch_size,)+data_shape} - spherenet.init_weights(sym, data_shape_dict, args.num_layers) data_extra = None hard_mining = False diff --git a/src/train_softmax.py b/src/train_softmax.py index 4f0b92d..8d281ef 100644 --- a/src/train_softmax.py +++ b/src/train_softmax.py @@ -337,9 +337,6 @@ def train_net(args): print('loading', vec) _, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1])) sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params) - if args.network[0]=='s': - data_shape_dict = {'data' : (args.per_batch_size,)+data_shape} - spherenet.init_weights(sym, data_shape_dict, args.num_layers) #label_name = 'softmax_label' #label_shape = (args.batch_size,) diff --git a/src/train_triplet.py b/src/train_triplet.py index bfac082..ac87750 100644 --- a/src/train_triplet.py +++ b/src/train_triplet.py @@ -221,9 +221,6 @@ def train_net(args): arg_params = None aux_params = None sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params) - if args.network[0]=='s': - data_shape_dict = {'data' : (args.per_batch_size,)+data_shape} - spherenet.init_weights(sym, data_shape_dict, args.num_layers) else: vec = args.pretrained.split(',') print('loading', vec)