mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-19 07:27:52 +00:00
tiny
This commit is contained in:
@@ -347,14 +347,14 @@ def get_symbol(args, arg_params, aux_params):
|
||||
fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
|
||||
zy = mx.sym.pick(fc7, gt_label, axis=1)
|
||||
cos_t = zy/s
|
||||
if args.margin_verbose>0:
|
||||
margin_symbols.append(mx.symbol.mean(cos_t))
|
||||
if args.output_c2c==0:
|
||||
cos_m = math.cos(m)
|
||||
sin_m = math.sin(m)
|
||||
mm = math.sin(math.pi-m)*m
|
||||
#threshold = 0.0
|
||||
threshold = math.cos(math.pi-m)
|
||||
if args.margin_verbose>0:
|
||||
margin_symbols.append(mx.symbol.mean(cos_t))
|
||||
if args.easy_margin:
|
||||
cond = mx.symbol.Activation(data=cos_t, act_type='relu')
|
||||
else:
|
||||
@@ -415,23 +415,15 @@ def get_symbol(args, arg_params, aux_params):
|
||||
if args.margin_verbose>0:
|
||||
margin_symbols.append(mx.symbol.mean(cos_t))
|
||||
if m>0.0:
|
||||
#m = m*1.1
|
||||
#m_min = 0.3
|
||||
#var_m = m
|
||||
#cos_ta = mx.symbol.Activation(data=cos_t, act_type='relu')
|
||||
#cos_ta = cos_t + 1.001
|
||||
cos_ta = cos_t - 0.7
|
||||
cos_ta = mx.symbol.Activation(data=cos_ta, act_type='relu')
|
||||
#cos_t_max = mx.symbol.max(cos_ta)
|
||||
#cos_t_min = mx.symbol.min(cos_ta)
|
||||
#cos_t_gap = cos_t_max-cos_t_min
|
||||
#cos_t_max = cos_t_max + 1.0e-6
|
||||
#r = mx.symbol.broadcast_div(cos_ta,cos_t_max)
|
||||
#r = cos_ta / 1.7
|
||||
r = cos_ta+0.7
|
||||
var_m = r*m
|
||||
|
||||
a1 = args.margin_a
|
||||
r1 = ta-a1
|
||||
r1 = mx.symbol.Activation(data=r1, act_type='relu')
|
||||
r1 = r1+a1
|
||||
t = mx.sym.arccos(cos_t)
|
||||
cond = t-1.0
|
||||
cond = mx.symbol.Activation(data=cond, act_type='relu')
|
||||
r = mx.sym.where(cond, r2, r1)
|
||||
t = t+var_m
|
||||
body = mx.sym.cos(t)
|
||||
new_zy = body*s
|
||||
@@ -467,10 +459,7 @@ def get_symbol(args, arg_params, aux_params):
|
||||
r1 = mx.symbol.Activation(data=r1, act_type='relu')
|
||||
r1 = r1+a1
|
||||
|
||||
a2 = 1.0
|
||||
r2 = ta-a2
|
||||
r2 = mx.symbol.Activation(data=r2, act_type='relu')
|
||||
r2 = r2+a2
|
||||
r2 = mx.symbol.zeros(shape=(args.per_batch_size,))
|
||||
|
||||
cond = t-1.0
|
||||
cond = mx.symbol.Activation(data=cond, act_type='relu')
|
||||
@@ -503,8 +492,8 @@ def get_symbol(args, arg_params, aux_params):
|
||||
t = mx.sym.arccos(cos_t)
|
||||
if args.margin_verbose>0:
|
||||
margin_symbols.append(mx.symbol.mean(t))
|
||||
var_m = mx.sym.random.uniform(low=0.4, high=0.5, shape=(1,))
|
||||
t = t+var_m
|
||||
var_m = mx.sym.random.uniform(low=args.margin_a, high=args.margin_m, shape=(1,))
|
||||
t = mx.sym.broadcast_add(t,var_m)
|
||||
body = mx.sym.cos(t)
|
||||
new_zy = body*s
|
||||
if args.margin_verbose>0:
|
||||
|
||||
Reference in New Issue
Block a user