This commit is contained in:
nttstar
2018-02-25 18:55:28 +08:00
parent 95a4986eaa
commit a2d5fd554c
2 changed files with 43 additions and 59 deletions

View File

@@ -347,14 +347,14 @@ def get_symbol(args, arg_params, aux_params):
fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
zy = mx.sym.pick(fc7, gt_label, axis=1)
cos_t = zy/s
if args.margin_verbose>0:
margin_symbols.append(mx.symbol.mean(cos_t))
if args.output_c2c==0:
cos_m = math.cos(m)
sin_m = math.sin(m)
mm = math.sin(math.pi-m)*m
#threshold = 0.0
threshold = math.cos(math.pi-m)
if args.margin_verbose>0:
margin_symbols.append(mx.symbol.mean(cos_t))
if args.easy_margin:
cond = mx.symbol.Activation(data=cos_t, act_type='relu')
else:
@@ -415,23 +415,15 @@ def get_symbol(args, arg_params, aux_params):
if args.margin_verbose>0:
margin_symbols.append(mx.symbol.mean(cos_t))
if m>0.0:
#m = m*1.1
#m_min = 0.3
#var_m = m
#cos_ta = mx.symbol.Activation(data=cos_t, act_type='relu')
#cos_ta = cos_t + 1.001
cos_ta = cos_t - 0.7
cos_ta = mx.symbol.Activation(data=cos_ta, act_type='relu')
#cos_t_max = mx.symbol.max(cos_ta)
#cos_t_min = mx.symbol.min(cos_ta)
#cos_t_gap = cos_t_max-cos_t_min
#cos_t_max = cos_t_max + 1.0e-6
#r = mx.symbol.broadcast_div(cos_ta,cos_t_max)
#r = cos_ta / 1.7
r = cos_ta+0.7
var_m = r*m
a1 = args.margin_a
r1 = ta-a1
r1 = mx.symbol.Activation(data=r1, act_type='relu')
r1 = r1+a1
t = mx.sym.arccos(cos_t)
cond = t-1.0
cond = mx.symbol.Activation(data=cond, act_type='relu')
r = mx.sym.where(cond, r2, r1)
t = t+var_m
body = mx.sym.cos(t)
new_zy = body*s
@@ -467,10 +459,7 @@ def get_symbol(args, arg_params, aux_params):
r1 = mx.symbol.Activation(data=r1, act_type='relu')
r1 = r1+a1
a2 = 1.0
r2 = ta-a2
r2 = mx.symbol.Activation(data=r2, act_type='relu')
r2 = r2+a2
r2 = mx.symbol.zeros(shape=(args.per_batch_size,))
cond = t-1.0
cond = mx.symbol.Activation(data=cond, act_type='relu')
@@ -503,8 +492,8 @@ def get_symbol(args, arg_params, aux_params):
t = mx.sym.arccos(cos_t)
if args.margin_verbose>0:
margin_symbols.append(mx.symbol.mean(t))
var_m = mx.sym.random.uniform(low=0.4, high=0.5, shape=(1,))
t = t+var_m
var_m = mx.sym.random.uniform(low=args.margin_a, high=args.margin_m, shape=(1,))
t = mx.sym.broadcast_add(t,var_m)
body = mx.sym.cos(t)
new_zy = body*s
if args.margin_verbose>0: