mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-22 09:37:48 +00:00
init spherenet weight in symbol def
This commit is contained in:
@@ -7,20 +7,22 @@ def conv_main(data, units, filters, workspace):
|
||||
body = data
|
||||
for i in xrange(len(units)):
|
||||
f = filters[i]
|
||||
_weight = mx.symbol.Variable("conv%d_%d_weight"%(i+1, 1), lr_mult=1.0)
|
||||
_bias = mx.symbol.Variable("conv%d_%d_bias"%(i+1, 1), lr_mult=2.0, wd_mult=0.0)
|
||||
_weight = mx.symbol.Variable("conv%d_%d_weight"%(i+1, 1), init=mx.init.Normal(0.01))
|
||||
_bias = mx.symbol.Variable("conv%d_%d_bias"%(i+1, 1), lr_mult=2.0, wd_mult=0.0, init=mx.init.Constant(0.0))
|
||||
body = mx.sym.Convolution(data=body, weight = _weight, bias = _bias, num_filter=f, kernel=(3, 3), stride=(2,2), pad=(1, 1),
|
||||
name= "conv%d_%d"%(i+1, 1), workspace=workspace)
|
||||
|
||||
body = mx.sym.LeakyReLU(data = body, act_type='prelu', name = "relu%d_%d" % (i+1, 1))
|
||||
idx = 2
|
||||
for j in xrange(units[i]):
|
||||
_body = mx.sym.Convolution(data=body, no_bias=True, num_filter=f, kernel=(3, 3), stride=(1,1), pad=(1, 1),
|
||||
_weight = mx.symbol.Variable("conv%d_%d_weight"%(i+1, idx), init=mx.init.Normal(0.01))
|
||||
_body = mx.sym.Convolution(data=body, weight=_weight, no_bias=True, num_filter=f, kernel=(3, 3), stride=(1,1), pad=(1, 1),
|
||||
name= "conv%d_%d"%(i+1, idx), workspace=workspace)
|
||||
|
||||
_body = mx.sym.LeakyReLU(data = _body, act_type='prelu', name = "relu%d_%d" % (i+1, idx))
|
||||
idx+=1
|
||||
_body = mx.sym.Convolution(data=_body, no_bias=True, num_filter=f, kernel=(3, 3), stride=(1,1), pad=(1, 1),
|
||||
_weight = mx.symbol.Variable("conv%d_%d_weight"%(i+1, idx), init=mx.init.Normal(0.01))
|
||||
_body = mx.sym.Convolution(data=_body, weight=_weight, no_bias=True, num_filter=f, kernel=(3, 3), stride=(1,1), pad=(1, 1),
|
||||
name= "conv%d_%d"%(i+1, idx), workspace=workspace)
|
||||
_body = mx.sym.LeakyReLU(data = _body, act_type='prelu', name = "relu%d_%d" % (i+1, idx))
|
||||
idx+=1
|
||||
@@ -57,30 +59,4 @@ def get_symbol(num_classes, num_layers, conv_workspace=256, **kwargs):
|
||||
fc1 = mx.sym.FullyConnected(data=body, weight=_weight, bias=_bias, num_hidden=num_classes, name='fc1')
|
||||
return fc1
|
||||
|
||||
def init_weights(sym, data_shape_dict, num_layers):
|
||||
arg_name = sym.list_arguments()
|
||||
aux_name = sym.list_auxiliary_states()
|
||||
arg_shape, aaa, aux_shape = sym.infer_shape(**data_shape_dict)
|
||||
#print(data_shape_dict)
|
||||
#print(arg_name)
|
||||
#print(arg_shape)
|
||||
arg_params = {}
|
||||
aux_params = None
|
||||
#print(aaa)
|
||||
#print(aux_shape)
|
||||
arg_shape_dict = dict(zip(arg_name, arg_shape))
|
||||
aux_shape_dict = dict(zip(aux_name, aux_shape))
|
||||
#print(aux_shape)
|
||||
#print(aux_params)
|
||||
#print(arg_shape_dict)
|
||||
for k,v in arg_shape_dict.iteritems():
|
||||
if k.startswith('conv') and k.endswith('_weight'):
|
||||
if not k.find('_1_')>=0:
|
||||
if num_layers<100:
|
||||
arg_params[k] = mx.random.normal(0, 0.01, shape=v)
|
||||
print('init', k)
|
||||
if k.endswith('_bias'):
|
||||
arg_params[k] = mx.nd.zeros(shape=v)
|
||||
print('init', k)
|
||||
return arg_params, aux_params
|
||||
|
||||
|
||||
85
src/train.py
85
src/train.py
@@ -174,7 +174,7 @@ def parse_args():
|
||||
parser.add_argument('--patch', type=str, default='0_0_96_112_0',help='')
|
||||
parser.add_argument('--lr-steps', type=str, default='', help='')
|
||||
parser.add_argument('--max-steps', type=int, default=0, help='')
|
||||
parser.add_argument('--target', type=str, default='lfw,cfp_fp,agedb_30', help='')
|
||||
parser.add_argument('--target', type=str, default='lfw,cfp_fp,agedb_30,cplfw,calfw', help='')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
@@ -386,69 +386,35 @@ def get_symbol(args, arg_params, aux_params):
|
||||
body = mx.sym.broadcast_mul(gt_one_hot, diff)
|
||||
fc7 = fc7+body
|
||||
elif args.loss_type==5:
|
||||
#s = args.margin_s
|
||||
#m = args.margin_m
|
||||
#assert s>0.0
|
||||
#assert m>=0.0
|
||||
#assert m<(math.pi/2)
|
||||
#_weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
|
||||
#_weight = mx.symbol.L2Normalization(_weight, mode='instance')
|
||||
#nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
|
||||
#fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
|
||||
#zy = mx.sym.pick(fc7, gt_label, axis=1)
|
||||
#cos_t = zy/s
|
||||
#if args.margin_verbose>0:
|
||||
# margin_symbols.append(mx.symbol.mean(cos_t))
|
||||
#if m>0.0:
|
||||
# a1 = args.margin_a
|
||||
# r1 = ta-a1
|
||||
# r1 = mx.symbol.Activation(data=r1, act_type='relu')
|
||||
# r1 = r1+a1
|
||||
# t = mx.sym.arccos(cos_t)
|
||||
# cond = t-1.0
|
||||
# cond = mx.symbol.Activation(data=cond, act_type='relu')
|
||||
# r = mx.sym.where(cond, r2, r1)
|
||||
# t = t+var_m
|
||||
# body = mx.sym.cos(t)
|
||||
# new_zy = body*s
|
||||
# if args.margin_verbose>0:
|
||||
# new_cos_t = new_zy/s
|
||||
# margin_symbols.append(mx.symbol.mean(new_cos_t))
|
||||
# #margin_symbols.append(mx.symbol.mean(var_m))
|
||||
# diff = new_zy - zy
|
||||
# diff = mx.sym.expand_dims(diff, 1)
|
||||
# gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
|
||||
# body = mx.sym.broadcast_mul(gt_one_hot, diff)
|
||||
# fc7 = fc7+body
|
||||
s = args.margin_s
|
||||
m = args.margin_m
|
||||
assert s>0.0
|
||||
#assert m>=0.0
|
||||
#assert m<(math.pi/2)
|
||||
_weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
|
||||
_weight = mx.symbol.L2Normalization(_weight, mode='instance')
|
||||
nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
|
||||
fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
|
||||
zy = mx.sym.pick(fc7, gt_label, axis=1)
|
||||
cos_t = zy/s
|
||||
t = mx.sym.arccos(cos_t)
|
||||
if args.margin_verbose>0:
|
||||
margin_symbols.append(mx.symbol.mean(t))
|
||||
if args.margin_a>0.0:
|
||||
t = t*args.margin_a
|
||||
if args.margin_m>0.0:
|
||||
t = t+args.margin_m
|
||||
body = mx.sym.cos(t)
|
||||
if args.margin_b>0.0:
|
||||
body = body - args.margin_b
|
||||
new_zy = body*s
|
||||
if args.margin_verbose>0:
|
||||
margin_symbols.append(mx.symbol.mean(t))
|
||||
diff = new_zy - zy
|
||||
diff = mx.sym.expand_dims(diff, 1)
|
||||
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
|
||||
body = mx.sym.broadcast_mul(gt_one_hot, diff)
|
||||
fc7 = fc7+body
|
||||
if args.margin_a!=1.0 or args.margin_m!=0.0 or args.margin_b!=0.0:
|
||||
if args.margin_a==1.0 and args.margin_m==0.0:
|
||||
s_m = s*args.margin_b
|
||||
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = s_m, off_value = 0.0)
|
||||
fc7 = fc7-gt_one_hot
|
||||
else:
|
||||
zy = mx.sym.pick(fc7, gt_label, axis=1)
|
||||
cos_t = zy/s
|
||||
t = mx.sym.arccos(cos_t)
|
||||
if args.margin_a!=1.0:
|
||||
t = t*args.margin_a
|
||||
if args.margin_m>0.0:
|
||||
t = t+args.margin_m
|
||||
body = mx.sym.cos(t)
|
||||
if args.margin_b>0.0:
|
||||
body = body - args.margin_b
|
||||
new_zy = body*s
|
||||
diff = new_zy - zy
|
||||
diff = mx.sym.expand_dims(diff, 1)
|
||||
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
|
||||
body = mx.sym.broadcast_mul(gt_one_hot, diff)
|
||||
fc7 = fc7+body
|
||||
elif args.loss_type==6:
|
||||
s = args.margin_s
|
||||
m = args.margin_m
|
||||
@@ -505,7 +471,7 @@ def get_symbol(args, arg_params, aux_params):
|
||||
t = mx.sym.arccos(cos_t)
|
||||
if args.margin_verbose>0:
|
||||
margin_symbols.append(mx.symbol.mean(t))
|
||||
var_m = mx.sym.random.uniform(low=args.margin_a, high=args.margin_m, shape=(1,))
|
||||
var_m = mx.sym.random.uniform(low=args.margin_b, high=args.margin_m, shape=(1,))
|
||||
t = mx.sym.broadcast_add(t,var_m)
|
||||
body = mx.sym.cos(t)
|
||||
new_zy = body*s
|
||||
@@ -775,9 +741,6 @@ def train_net(args):
|
||||
print('loading', vec)
|
||||
_, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))
|
||||
sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
|
||||
if args.network[0]=='s':
|
||||
data_shape_dict = {'data' : (args.per_batch_size,)+data_shape}
|
||||
spherenet.init_weights(sym, data_shape_dict, args.num_layers)
|
||||
|
||||
data_extra = None
|
||||
hard_mining = False
|
||||
|
||||
@@ -337,9 +337,6 @@ def train_net(args):
|
||||
print('loading', vec)
|
||||
_, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))
|
||||
sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
|
||||
if args.network[0]=='s':
|
||||
data_shape_dict = {'data' : (args.per_batch_size,)+data_shape}
|
||||
spherenet.init_weights(sym, data_shape_dict, args.num_layers)
|
||||
|
||||
#label_name = 'softmax_label'
|
||||
#label_shape = (args.batch_size,)
|
||||
|
||||
@@ -221,9 +221,6 @@ def train_net(args):
|
||||
arg_params = None
|
||||
aux_params = None
|
||||
sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
|
||||
if args.network[0]=='s':
|
||||
data_shape_dict = {'data' : (args.per_batch_size,)+data_shape}
|
||||
spherenet.init_weights(sym, data_shape_dict, args.num_layers)
|
||||
else:
|
||||
vec = args.pretrained.split(',')
|
||||
print('loading', vec)
|
||||
|
||||
Reference in New Issue
Block a user