triplet fix

This commit is contained in:
nttstar
2017-12-26 12:30:19 +08:00
parent 8f5461d74d
commit a4d2910742
2 changed files with 64 additions and 68 deletions

View File

@@ -114,6 +114,7 @@ class FaceImageIter(io.DataIter):
self.triplet_mode = True
self.triplet_oseq_cur = 0
self.triplet_oseq_reset()
self.seq_min_size = self.batch_size*2
self.cur = 0
self.is_init = False
#self.reset()
@@ -180,74 +181,69 @@ class FaceImageIter(io.DataIter):
def select_triplets(self):
self.triplet_index = 0
self.triplets = []
embeddings = None
bag_size = self.triplet_bag_size
batch_size = self.batch_size
#data = np.zeros( (bag_size,)+self.data_shape )
#label = np.zeros( (bag_size,) )
tag = []
#idx = np.zeros( (bag_size,) )
print('eval %d images..'%bag_size, self.triplet_oseq_cur)
if self.triplet_oseq_cur+bag_size>len(self.oseq):
self.triplet_oseq_reset()
print('eval %d images..'%bag_size, self.triplet_oseq_cur)
#print(data.shape)
data = nd.zeros( self.provide_data[0][1] )
label = nd.zeros( self.provide_label[0][1] )
ba = 0
while True:
bb = min(ba+batch_size, bag_size)
if ba>=bb:
break
#_batch = self.data_iter.next()
#_data = _batch.data[0].asnumpy()
#print(_data.shape)
#_label = _batch.label[0].asnumpy()
#data[ba:bb,:,:,:] = _data
#label[ba:bb] = _label
for i in xrange(ba, bb):
_idx = self.oseq[i+self.triplet_oseq_cur]
s = self.imgrec.read_idx(_idx)
header, img = recordio.unpack(s)
img = self.imdecode(img)
data[i-ba][:] = self.postprocess_data(img)
label[i-ba][:] = header.label
tag.append( ( int(header.label), _idx) )
#idx[i] = _idx
self.seq = []
while len(self.seq)<self.seq_min_size:
embeddings = None
bag_size = self.triplet_bag_size
batch_size = self.batch_size
#data = np.zeros( (bag_size,)+self.data_shape )
#label = np.zeros( (bag_size,) )
tag = []
#idx = np.zeros( (bag_size,) )
print('eval %d images..'%bag_size, self.triplet_oseq_cur)
if self.triplet_oseq_cur+bag_size>len(self.oseq):
self.triplet_oseq_reset()
print('eval %d images..'%bag_size, self.triplet_oseq_cur)
#print(data.shape)
data = nd.zeros( self.provide_data[0][1] )
label = nd.zeros( self.provide_label[0][1] )
ba = 0
while True:
bb = min(ba+batch_size, bag_size)
if ba>=bb:
break
#_batch = self.data_iter.next()
#_data = _batch.data[0].asnumpy()
#print(_data.shape)
#_label = _batch.label[0].asnumpy()
#data[ba:bb,:,:,:] = _data
#label[ba:bb] = _label
for i in xrange(ba, bb):
_idx = self.oseq[i+self.triplet_oseq_cur]
s = self.imgrec.read_idx(_idx)
header, img = recordio.unpack(s)
img = self.imdecode(img)
data[i-ba][:] = self.postprocess_data(img)
label[i-ba][:] = header.label
tag.append( ( int(header.label), _idx) )
#idx[i] = _idx
db = mx.io.DataBatch(data=(data,), label=(label,))
self.mx_model.forward(db, is_train=False)
net_out = self.mx_model.get_outputs()
#print('eval for selecting triplets',ba,bb)
#print(net_out)
#print(len(net_out))
#print(net_out[0].asnumpy())
net_out = net_out[0].asnumpy()
#print(net_out)
#print('net_out', net_out.shape)
if embeddings is None:
embeddings = np.zeros( (bag_size, net_out.shape[1]))
embeddings[ba:bb,:] = net_out
ba = bb
assert len(tag)==bag_size
self.triplet_oseq_cur+=bag_size
embeddings = sklearn.preprocessing.normalize(embeddings)
nrof_images_per_class = [1]
for i in xrange(1, bag_size):
if tag[i][0]==tag[i-1][0]:
nrof_images_per_class[-1]+=1
else:
nrof_images_per_class.append(1)
triplets = self.pick_triplets(embeddings, nrof_images_per_class) # shape=(T,3)
if len(triplets)==0:
print('triplets 0, retry...')
self.select_triplets()
else:
print('triplets', len(triplets))
self.seq = []
db = mx.io.DataBatch(data=(data,), label=(label,))
self.mx_model.forward(db, is_train=False)
net_out = self.mx_model.get_outputs()
#print('eval for selecting triplets',ba,bb)
#print(net_out)
#print(len(net_out))
#print(net_out[0].asnumpy())
net_out = net_out[0].asnumpy()
#print(net_out)
#print('net_out', net_out.shape)
if embeddings is None:
embeddings = np.zeros( (bag_size, net_out.shape[1]))
embeddings[ba:bb,:] = net_out
ba = bb
assert len(tag)==bag_size
self.triplet_oseq_cur+=bag_size
embeddings = sklearn.preprocessing.normalize(embeddings)
nrof_images_per_class = [1]
for i in xrange(1, bag_size):
if tag[i][0]==tag[i-1][0]:
nrof_images_per_class[-1]+=1
else:
nrof_images_per_class.append(1)
triplets = self.pick_triplets(embeddings, nrof_images_per_class) # shape=(T,3)
print('found triplets', len(triplets))
ba = 0
while True:
bb = ba+self.per_batch_size//3

View File

@@ -484,7 +484,7 @@ def train_net(args):
label_names = (label_name,),
)
if len(data_dir_list)==1:
if len(data_dir_list)==1 and args.loss_type!=12:
train_dataiter = FaceImageIter(
batch_size = args.batch_size,
data_shape = data_shape,