This commit is contained in:
nttstar
2018-02-25 18:55:28 +08:00
parent 95a4986eaa
commit a2d5fd554c
2 changed files with 43 additions and 59 deletions

View File

@@ -77,20 +77,20 @@ class FaceImageIter(io.DataIter):
self.idx2flag = {}
self.idx2meancos = {}
self.c2c_auto = False
if output_c2c or c2c_threshold>0.0 or c2c_mode>=-5:
path_c2c = os.path.join(os.path.dirname(path_imgrec), 'c2c')
print(path_c2c)
if os.path.exists(path_c2c):
for line in open(path_c2c, 'r'):
vec = line.strip().split(',')
idx = int(vec[0])
self.idx2cos[idx] = float(vec[1])
self.idx2flag[idx] = 1
if len(vec)>2:
self.idx2flag[idx] = int(vec[2])
else:
self.c2c_auto = True
self.c2c_step = 10000
#if output_c2c or c2c_threshold>0.0 or c2c_mode>=-5:
# path_c2c = os.path.join(os.path.dirname(path_imgrec), 'c2c')
# print(path_c2c)
# if os.path.exists(path_c2c):
# for line in open(path_c2c, 'r'):
# vec = line.strip().split(',')
# idx = int(vec[0])
# self.idx2cos[idx] = float(vec[1])
# self.idx2flag[idx] = 1
# if len(vec)>2:
# self.idx2flag[idx] = int(vec[2])
# else:
# self.c2c_auto = True
# self.c2c_step = 10000
if header.flag>0:
print('header0 label', header.label)
self.header0 = (int(header.label[0]), int(header.label[1]))
@@ -166,10 +166,10 @@ class FaceImageIter(io.DataIter):
s = self.imgrec.read_idx(identity)
header, _ = recordio.unpack(s)
a,b = int(header.label[0]), int(header.label[1])
#print('flag', header.flag)
#print(header.label)
#assert(header.flag==2)
self.id2range[identity] = (a,b)
count = b-a
for ii in xrange(a,b):
self.idx2flag[ii] = count
if len(self.idx2cos)>0:
m = 0.0
for ii in xrange(a,b):
@@ -180,7 +180,7 @@ class FaceImageIter(io.DataIter):
#self.idx2meancos[identity] = m
print('id2range', len(self.id2range))
print(len(self.idx2cos), len(self.idx2meancos))
print(len(self.idx2cos), len(self.idx2meancos), len(self.idx2flag))
else:
self.imgidx = list(self.imgrec.keys)
if shuffle:
@@ -743,8 +743,12 @@ class FaceImageIter(io.DataIter):
header, img = recordio.unpack(s)
label = header.label
if self.output_c2c:
meancos = self.idx2meancos[idx]
label = [label, meancos]
#v = self.idx2meancos[idx]
v = 0.5
count = self.idx2flag[idx]
if count>=self.output_c2c:
v = 0.4
label = [label, v]
else:
if not isinstance(label, numbers.Number):
label = label[0]
@@ -869,26 +873,17 @@ class FaceImageIter(io.DataIter):
for ll in xrange(batch_label.shape[1]):
v = label[ll]
if ll>0:
c2c = v
#m = min(0.55, max(0.3,math.log(c2c+1)*4-1.85))
#c2c = v
#_param = [0.5, 0.4, 0.85, 0.75]
#_a = (_param[1]-_param[0])/(_param[3]-_param[2])
#m = _param[1]+_a*(c2c-_param[3])
#m = min(_param[0], max(_param[1],m))
#v = math.cos(m)
#v = v*v
#_param = [0.5, 0.3, 0.85, 0.7]
_param = [0.5, 0.4, 0.85, 0.75]
#_param = [0.55, 0.4, 0.9, 0.75]
_a = (_param[1]-_param[0])/(_param[3]-_param[2])
m = _param[1]+_a*(c2c-_param[3])
m = min(_param[0], max(_param[1],m))
#m = 0.5
#if c2c<0.77:
# m = 0.3
#elif c2c<0.82:
# m = 0.4
#elif c2c>0.88:
# m = 0.55
m = v
v = math.cos(m)
v = v*v
#print('c2c', i,c2c,m,v)
#print('m', i,m,v)
batch_label[i][ll] = v
else: