init spherenet weight in symbol def

This commit is contained in:
Jia Guo
2018-08-17 16:41:56 +08:00
parent 0ba6d3d47e
commit ec9aa4d6de
4 changed files with 30 additions and 97 deletions

View File

@@ -7,20 +7,22 @@ def conv_main(data, units, filters, workspace):
body = data
for i in xrange(len(units)):
f = filters[i]
_weight = mx.symbol.Variable("conv%d_%d_weight"%(i+1, 1), lr_mult=1.0)
_bias = mx.symbol.Variable("conv%d_%d_bias"%(i+1, 1), lr_mult=2.0, wd_mult=0.0)
_weight = mx.symbol.Variable("conv%d_%d_weight"%(i+1, 1), init=mx.init.Normal(0.01))
_bias = mx.symbol.Variable("conv%d_%d_bias"%(i+1, 1), lr_mult=2.0, wd_mult=0.0, init=mx.init.Constant(0.0))
body = mx.sym.Convolution(data=body, weight = _weight, bias = _bias, num_filter=f, kernel=(3, 3), stride=(2,2), pad=(1, 1),
name= "conv%d_%d"%(i+1, 1), workspace=workspace)
body = mx.sym.LeakyReLU(data = body, act_type='prelu', name = "relu%d_%d" % (i+1, 1))
idx = 2
for j in xrange(units[i]):
_body = mx.sym.Convolution(data=body, no_bias=True, num_filter=f, kernel=(3, 3), stride=(1,1), pad=(1, 1),
_weight = mx.symbol.Variable("conv%d_%d_weight"%(i+1, idx), init=mx.init.Normal(0.01))
_body = mx.sym.Convolution(data=body, weight=_weight, no_bias=True, num_filter=f, kernel=(3, 3), stride=(1,1), pad=(1, 1),
name= "conv%d_%d"%(i+1, idx), workspace=workspace)
_body = mx.sym.LeakyReLU(data = _body, act_type='prelu', name = "relu%d_%d" % (i+1, idx))
idx+=1
_body = mx.sym.Convolution(data=_body, no_bias=True, num_filter=f, kernel=(3, 3), stride=(1,1), pad=(1, 1),
_weight = mx.symbol.Variable("conv%d_%d_weight"%(i+1, idx), init=mx.init.Normal(0.01))
_body = mx.sym.Convolution(data=_body, weight=_weight, no_bias=True, num_filter=f, kernel=(3, 3), stride=(1,1), pad=(1, 1),
name= "conv%d_%d"%(i+1, idx), workspace=workspace)
_body = mx.sym.LeakyReLU(data = _body, act_type='prelu', name = "relu%d_%d" % (i+1, idx))
idx+=1
@@ -57,30 +59,4 @@ def get_symbol(num_classes, num_layers, conv_workspace=256, **kwargs):
fc1 = mx.sym.FullyConnected(data=body, weight=_weight, bias=_bias, num_hidden=num_classes, name='fc1')
return fc1
def init_weights(sym, data_shape_dict, num_layers):
arg_name = sym.list_arguments()
aux_name = sym.list_auxiliary_states()
arg_shape, aaa, aux_shape = sym.infer_shape(**data_shape_dict)
#print(data_shape_dict)
#print(arg_name)
#print(arg_shape)
arg_params = {}
aux_params = None
#print(aaa)
#print(aux_shape)
arg_shape_dict = dict(zip(arg_name, arg_shape))
aux_shape_dict = dict(zip(aux_name, aux_shape))
#print(aux_shape)
#print(aux_params)
#print(arg_shape_dict)
for k,v in arg_shape_dict.iteritems():
if k.startswith('conv') and k.endswith('_weight'):
if not k.find('_1_')>=0:
if num_layers<100:
arg_params[k] = mx.random.normal(0, 0.01, shape=v)
print('init', k)
if k.endswith('_bias'):
arg_params[k] = mx.nd.zeros(shape=v)
print('init', k)
return arg_params, aux_params

View File

@@ -174,7 +174,7 @@ def parse_args():
parser.add_argument('--patch', type=str, default='0_0_96_112_0',help='')
parser.add_argument('--lr-steps', type=str, default='', help='')
parser.add_argument('--max-steps', type=int, default=0, help='')
parser.add_argument('--target', type=str, default='lfw,cfp_fp,agedb_30', help='')
parser.add_argument('--target', type=str, default='lfw,cfp_fp,agedb_30,cplfw,calfw', help='')
args = parser.parse_args()
return args
@@ -386,69 +386,35 @@ def get_symbol(args, arg_params, aux_params):
body = mx.sym.broadcast_mul(gt_one_hot, diff)
fc7 = fc7+body
elif args.loss_type==5:
#s = args.margin_s
#m = args.margin_m
#assert s>0.0
#assert m>=0.0
#assert m<(math.pi/2)
#_weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
#_weight = mx.symbol.L2Normalization(_weight, mode='instance')
#nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
#fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
#zy = mx.sym.pick(fc7, gt_label, axis=1)
#cos_t = zy/s
#if args.margin_verbose>0:
# margin_symbols.append(mx.symbol.mean(cos_t))
#if m>0.0:
# a1 = args.margin_a
# r1 = ta-a1
# r1 = mx.symbol.Activation(data=r1, act_type='relu')
# r1 = r1+a1
# t = mx.sym.arccos(cos_t)
# cond = t-1.0
# cond = mx.symbol.Activation(data=cond, act_type='relu')
# r = mx.sym.where(cond, r2, r1)
# t = t+var_m
# body = mx.sym.cos(t)
# new_zy = body*s
# if args.margin_verbose>0:
# new_cos_t = new_zy/s
# margin_symbols.append(mx.symbol.mean(new_cos_t))
# #margin_symbols.append(mx.symbol.mean(var_m))
# diff = new_zy - zy
# diff = mx.sym.expand_dims(diff, 1)
# gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
# body = mx.sym.broadcast_mul(gt_one_hot, diff)
# fc7 = fc7+body
s = args.margin_s
m = args.margin_m
assert s>0.0
#assert m>=0.0
#assert m<(math.pi/2)
_weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
_weight = mx.symbol.L2Normalization(_weight, mode='instance')
nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
zy = mx.sym.pick(fc7, gt_label, axis=1)
cos_t = zy/s
t = mx.sym.arccos(cos_t)
if args.margin_verbose>0:
margin_symbols.append(mx.symbol.mean(t))
if args.margin_a>0.0:
t = t*args.margin_a
if args.margin_m>0.0:
t = t+args.margin_m
body = mx.sym.cos(t)
if args.margin_b>0.0:
body = body - args.margin_b
new_zy = body*s
if args.margin_verbose>0:
margin_symbols.append(mx.symbol.mean(t))
diff = new_zy - zy
diff = mx.sym.expand_dims(diff, 1)
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
body = mx.sym.broadcast_mul(gt_one_hot, diff)
fc7 = fc7+body
if args.margin_a!=1.0 or args.margin_m!=0.0 or args.margin_b!=0.0:
if args.margin_a==1.0 and args.margin_m==0.0:
s_m = s*args.margin_b
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = s_m, off_value = 0.0)
fc7 = fc7-gt_one_hot
else:
zy = mx.sym.pick(fc7, gt_label, axis=1)
cos_t = zy/s
t = mx.sym.arccos(cos_t)
if args.margin_a!=1.0:
t = t*args.margin_a
if args.margin_m>0.0:
t = t+args.margin_m
body = mx.sym.cos(t)
if args.margin_b>0.0:
body = body - args.margin_b
new_zy = body*s
diff = new_zy - zy
diff = mx.sym.expand_dims(diff, 1)
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
body = mx.sym.broadcast_mul(gt_one_hot, diff)
fc7 = fc7+body
elif args.loss_type==6:
s = args.margin_s
m = args.margin_m
@@ -505,7 +471,7 @@ def get_symbol(args, arg_params, aux_params):
t = mx.sym.arccos(cos_t)
if args.margin_verbose>0:
margin_symbols.append(mx.symbol.mean(t))
var_m = mx.sym.random.uniform(low=args.margin_a, high=args.margin_m, shape=(1,))
var_m = mx.sym.random.uniform(low=args.margin_b, high=args.margin_m, shape=(1,))
t = mx.sym.broadcast_add(t,var_m)
body = mx.sym.cos(t)
new_zy = body*s
@@ -775,9 +741,6 @@ def train_net(args):
print('loading', vec)
_, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))
sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
if args.network[0]=='s':
data_shape_dict = {'data' : (args.per_batch_size,)+data_shape}
spherenet.init_weights(sym, data_shape_dict, args.num_layers)
data_extra = None
hard_mining = False

View File

@@ -337,9 +337,6 @@ def train_net(args):
print('loading', vec)
_, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))
sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
if args.network[0]=='s':
data_shape_dict = {'data' : (args.per_batch_size,)+data_shape}
spherenet.init_weights(sym, data_shape_dict, args.num_layers)
#label_name = 'softmax_label'
#label_shape = (args.batch_size,)

View File

@@ -221,9 +221,6 @@ def train_net(args):
arg_params = None
aux_params = None
sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
if args.network[0]=='s':
data_shape_dict = {'data' : (args.per_batch_size,)+data_shape}
spherenet.init_weights(sym, data_shape_dict, args.num_layers)
else:
vec = args.pretrained.split(',')
print('loading', vec)