This commit is contained in:
nttstar
2018-02-08 19:35:55 +08:00
11 changed files with 370 additions and 23 deletions

View File

@@ -59,7 +59,7 @@ class FaceImageIter(io.DataIter):
path_imgrec = None,
shuffle=False, aug_list=None, mean = None,
rand_mirror = False,
c2c_threshold = 0.0, output_c2c = 0,
c2c_threshold = 0.0, output_c2c = 0, c2c_mode = -10,
ctx_num = 0, images_per_identity = 0, data_extra = None, hard_mining = False,
triplet_params = None, coco_mode = False,
mx_model = None,
@@ -74,6 +74,7 @@ class FaceImageIter(io.DataIter):
s = self.imgrec.read_idx(0)
header, _ = recordio.unpack(s)
self.idx2cos = {}
self.idx2flag = {}
self.idx2meancos = {}
self.c2c_auto = False
if output_c2c or c2c_threshold>0.0:
@@ -82,7 +83,11 @@ class FaceImageIter(io.DataIter):
if os.path.exists(path_c2c):
for line in open(path_c2c, 'r'):
vec = line.strip().split(',')
self.idx2cos[int(vec[0])] = float(vec[1])
idx = int(vec[0])
self.idx2cos[idx] = float(vec[1])
self.idx2flag[idx] = 1
if len(vec)>2:
self.idx2flag[idx] = int(vec[2])
else:
self.c2c_auto = True
self.c2c_step = 10000
@@ -91,10 +96,65 @@ class FaceImageIter(io.DataIter):
self.header0 = (int(header.label[0]), int(header.label[1]))
#assert(header.flag==1)
self.imgidx = range(1, int(header.label[0]))
if c2c_threshold>0.0:
if c2c_mode==0:
imgidx2 = []
for idx in self.imgidx:
c = self.idx2cos[idx]
f = self.idx2flag[idx]
if f!=1:
continue
imgidx2.append(idx)
print('idx count', len(self.imgidx), len(imgidx2))
self.imgidx = imgidx2
elif c2c_mode==1:
imgidx2 = []
for idx in self.imgidx:
c = self.idx2cos[idx]
f = self.idx2flag[idx]
if f==2 and c>=0.05:
continue
imgidx2.append(idx)
print('idx count', len(self.imgidx), len(imgidx2))
self.imgidx = imgidx2
elif c2c_mode==2:
imgidx2 = []
for idx in self.imgidx:
c = self.idx2cos[idx]
f = self.idx2flag[idx]
if f==2 and c>=0.1:
continue
imgidx2.append(idx)
print('idx count', len(self.imgidx), len(imgidx2))
self.imgidx = imgidx2
elif c2c_mode==-1:
imgidx2 = []
for idx in self.imgidx:
c = self.idx2cos[idx]
f = self.idx2flag[idx]
if f==2:
continue
if c<0.1:
continue
imgidx2.append(idx)
print('idx count', len(self.imgidx), len(imgidx2))
self.imgidx = imgidx2
elif c2c_mode==-2:
imgidx2 = []
for idx in self.imgidx:
c = self.idx2cos[idx]
f = self.idx2flag[idx]
if f==2:
continue
if c<0.2:
continue
imgidx2.append(idx)
print('idx count', len(self.imgidx), len(imgidx2))
self.imgidx = imgidx2
elif c2c_threshold>0.0:
imgidx2 = []
for idx in self.imgidx:
c = self.idx2cos[idx]
f = self.idx2flag[idx]
if c<c2c_threshold:
continue
imgidx2.append(idx)
@@ -684,6 +744,9 @@ class FaceImageIter(io.DataIter):
if self.output_c2c:
meancos = self.idx2meancos[idx]
label = [label, meancos]
else:
if isinstance(label, list):
label = label[0]
return label, img, None, None
else:
label, fname, bbox, landmark = self.imglist[idx]