mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-17 14:26:08 +00:00
tiny
This commit is contained in:
67
src/data.py
67
src/data.py
@@ -77,20 +77,20 @@ class FaceImageIter(io.DataIter):
|
||||
self.idx2flag = {}
|
||||
self.idx2meancos = {}
|
||||
self.c2c_auto = False
|
||||
if output_c2c or c2c_threshold>0.0 or c2c_mode>=-5:
|
||||
path_c2c = os.path.join(os.path.dirname(path_imgrec), 'c2c')
|
||||
print(path_c2c)
|
||||
if os.path.exists(path_c2c):
|
||||
for line in open(path_c2c, 'r'):
|
||||
vec = line.strip().split(',')
|
||||
idx = int(vec[0])
|
||||
self.idx2cos[idx] = float(vec[1])
|
||||
self.idx2flag[idx] = 1
|
||||
if len(vec)>2:
|
||||
self.idx2flag[idx] = int(vec[2])
|
||||
else:
|
||||
self.c2c_auto = True
|
||||
self.c2c_step = 10000
|
||||
#if output_c2c or c2c_threshold>0.0 or c2c_mode>=-5:
|
||||
# path_c2c = os.path.join(os.path.dirname(path_imgrec), 'c2c')
|
||||
# print(path_c2c)
|
||||
# if os.path.exists(path_c2c):
|
||||
# for line in open(path_c2c, 'r'):
|
||||
# vec = line.strip().split(',')
|
||||
# idx = int(vec[0])
|
||||
# self.idx2cos[idx] = float(vec[1])
|
||||
# self.idx2flag[idx] = 1
|
||||
# if len(vec)>2:
|
||||
# self.idx2flag[idx] = int(vec[2])
|
||||
# else:
|
||||
# self.c2c_auto = True
|
||||
# self.c2c_step = 10000
|
||||
if header.flag>0:
|
||||
print('header0 label', header.label)
|
||||
self.header0 = (int(header.label[0]), int(header.label[1]))
|
||||
@@ -166,10 +166,10 @@ class FaceImageIter(io.DataIter):
|
||||
s = self.imgrec.read_idx(identity)
|
||||
header, _ = recordio.unpack(s)
|
||||
a,b = int(header.label[0]), int(header.label[1])
|
||||
#print('flag', header.flag)
|
||||
#print(header.label)
|
||||
#assert(header.flag==2)
|
||||
self.id2range[identity] = (a,b)
|
||||
count = b-a
|
||||
for ii in xrange(a,b):
|
||||
self.idx2flag[ii] = count
|
||||
if len(self.idx2cos)>0:
|
||||
m = 0.0
|
||||
for ii in xrange(a,b):
|
||||
@@ -180,7 +180,7 @@ class FaceImageIter(io.DataIter):
|
||||
#self.idx2meancos[identity] = m
|
||||
|
||||
print('id2range', len(self.id2range))
|
||||
print(len(self.idx2cos), len(self.idx2meancos))
|
||||
print(len(self.idx2cos), len(self.idx2meancos), len(self.idx2flag))
|
||||
else:
|
||||
self.imgidx = list(self.imgrec.keys)
|
||||
if shuffle:
|
||||
@@ -743,8 +743,12 @@ class FaceImageIter(io.DataIter):
|
||||
header, img = recordio.unpack(s)
|
||||
label = header.label
|
||||
if self.output_c2c:
|
||||
meancos = self.idx2meancos[idx]
|
||||
label = [label, meancos]
|
||||
#v = self.idx2meancos[idx]
|
||||
v = 0.5
|
||||
count = self.idx2flag[idx]
|
||||
if count>=self.output_c2c:
|
||||
v = 0.4
|
||||
label = [label, v]
|
||||
else:
|
||||
if not isinstance(label, numbers.Number):
|
||||
label = label[0]
|
||||
@@ -869,26 +873,17 @@ class FaceImageIter(io.DataIter):
|
||||
for ll in xrange(batch_label.shape[1]):
|
||||
v = label[ll]
|
||||
if ll>0:
|
||||
c2c = v
|
||||
#m = min(0.55, max(0.3,math.log(c2c+1)*4-1.85))
|
||||
#c2c = v
|
||||
#_param = [0.5, 0.4, 0.85, 0.75]
|
||||
#_a = (_param[1]-_param[0])/(_param[3]-_param[2])
|
||||
#m = _param[1]+_a*(c2c-_param[3])
|
||||
#m = min(_param[0], max(_param[1],m))
|
||||
#v = math.cos(m)
|
||||
#v = v*v
|
||||
#_param = [0.5, 0.3, 0.85, 0.7]
|
||||
_param = [0.5, 0.4, 0.85, 0.75]
|
||||
#_param = [0.55, 0.4, 0.9, 0.75]
|
||||
_a = (_param[1]-_param[0])/(_param[3]-_param[2])
|
||||
m = _param[1]+_a*(c2c-_param[3])
|
||||
m = min(_param[0], max(_param[1],m))
|
||||
#m = 0.5
|
||||
#if c2c<0.77:
|
||||
# m = 0.3
|
||||
#elif c2c<0.82:
|
||||
# m = 0.4
|
||||
#elif c2c>0.88:
|
||||
# m = 0.55
|
||||
m = v
|
||||
v = math.cos(m)
|
||||
v = v*v
|
||||
#print('c2c', i,c2c,m,v)
|
||||
#print('m', i,m,v)
|
||||
|
||||
batch_label[i][ll] = v
|
||||
else:
|
||||
|
||||
@@ -347,14 +347,14 @@ def get_symbol(args, arg_params, aux_params):
|
||||
fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
|
||||
zy = mx.sym.pick(fc7, gt_label, axis=1)
|
||||
cos_t = zy/s
|
||||
if args.margin_verbose>0:
|
||||
margin_symbols.append(mx.symbol.mean(cos_t))
|
||||
if args.output_c2c==0:
|
||||
cos_m = math.cos(m)
|
||||
sin_m = math.sin(m)
|
||||
mm = math.sin(math.pi-m)*m
|
||||
#threshold = 0.0
|
||||
threshold = math.cos(math.pi-m)
|
||||
if args.margin_verbose>0:
|
||||
margin_symbols.append(mx.symbol.mean(cos_t))
|
||||
if args.easy_margin:
|
||||
cond = mx.symbol.Activation(data=cos_t, act_type='relu')
|
||||
else:
|
||||
@@ -415,23 +415,15 @@ def get_symbol(args, arg_params, aux_params):
|
||||
if args.margin_verbose>0:
|
||||
margin_symbols.append(mx.symbol.mean(cos_t))
|
||||
if m>0.0:
|
||||
#m = m*1.1
|
||||
#m_min = 0.3
|
||||
#var_m = m
|
||||
#cos_ta = mx.symbol.Activation(data=cos_t, act_type='relu')
|
||||
#cos_ta = cos_t + 1.001
|
||||
cos_ta = cos_t - 0.7
|
||||
cos_ta = mx.symbol.Activation(data=cos_ta, act_type='relu')
|
||||
#cos_t_max = mx.symbol.max(cos_ta)
|
||||
#cos_t_min = mx.symbol.min(cos_ta)
|
||||
#cos_t_gap = cos_t_max-cos_t_min
|
||||
#cos_t_max = cos_t_max + 1.0e-6
|
||||
#r = mx.symbol.broadcast_div(cos_ta,cos_t_max)
|
||||
#r = cos_ta / 1.7
|
||||
r = cos_ta+0.7
|
||||
var_m = r*m
|
||||
|
||||
a1 = args.margin_a
|
||||
r1 = ta-a1
|
||||
r1 = mx.symbol.Activation(data=r1, act_type='relu')
|
||||
r1 = r1+a1
|
||||
t = mx.sym.arccos(cos_t)
|
||||
cond = t-1.0
|
||||
cond = mx.symbol.Activation(data=cond, act_type='relu')
|
||||
r = mx.sym.where(cond, r2, r1)
|
||||
t = t+var_m
|
||||
body = mx.sym.cos(t)
|
||||
new_zy = body*s
|
||||
@@ -467,10 +459,7 @@ def get_symbol(args, arg_params, aux_params):
|
||||
r1 = mx.symbol.Activation(data=r1, act_type='relu')
|
||||
r1 = r1+a1
|
||||
|
||||
a2 = 1.0
|
||||
r2 = ta-a2
|
||||
r2 = mx.symbol.Activation(data=r2, act_type='relu')
|
||||
r2 = r2+a2
|
||||
r2 = mx.symbol.zeros(shape=(args.per_batch_size,))
|
||||
|
||||
cond = t-1.0
|
||||
cond = mx.symbol.Activation(data=cond, act_type='relu')
|
||||
@@ -503,8 +492,8 @@ def get_symbol(args, arg_params, aux_params):
|
||||
t = mx.sym.arccos(cos_t)
|
||||
if args.margin_verbose>0:
|
||||
margin_symbols.append(mx.symbol.mean(t))
|
||||
var_m = mx.sym.random.uniform(low=0.4, high=0.5, shape=(1,))
|
||||
t = t+var_m
|
||||
var_m = mx.sym.random.uniform(low=args.margin_a, high=args.margin_m, shape=(1,))
|
||||
t = mx.sym.broadcast_add(t,var_m)
|
||||
body = mx.sym.cos(t)
|
||||
new_zy = body*s
|
||||
if args.margin_verbose>0:
|
||||
|
||||
Reference in New Issue
Block a user