mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-18 14:55:42 +00:00
triplet fix
This commit is contained in:
130
src/data.py
130
src/data.py
@@ -114,6 +114,7 @@ class FaceImageIter(io.DataIter):
|
||||
self.triplet_mode = True
|
||||
self.triplet_oseq_cur = 0
|
||||
self.triplet_oseq_reset()
|
||||
self.seq_min_size = self.batch_size*2
|
||||
self.cur = 0
|
||||
self.is_init = False
|
||||
#self.reset()
|
||||
@@ -180,74 +181,69 @@ class FaceImageIter(io.DataIter):
|
||||
|
||||
|
||||
def select_triplets(self):
|
||||
self.triplet_index = 0
|
||||
self.triplets = []
|
||||
embeddings = None
|
||||
bag_size = self.triplet_bag_size
|
||||
batch_size = self.batch_size
|
||||
#data = np.zeros( (bag_size,)+self.data_shape )
|
||||
#label = np.zeros( (bag_size,) )
|
||||
tag = []
|
||||
#idx = np.zeros( (bag_size,) )
|
||||
print('eval %d images..'%bag_size, self.triplet_oseq_cur)
|
||||
if self.triplet_oseq_cur+bag_size>len(self.oseq):
|
||||
self.triplet_oseq_reset()
|
||||
print('eval %d images..'%bag_size, self.triplet_oseq_cur)
|
||||
#print(data.shape)
|
||||
data = nd.zeros( self.provide_data[0][1] )
|
||||
label = nd.zeros( self.provide_label[0][1] )
|
||||
ba = 0
|
||||
while True:
|
||||
bb = min(ba+batch_size, bag_size)
|
||||
if ba>=bb:
|
||||
break
|
||||
#_batch = self.data_iter.next()
|
||||
#_data = _batch.data[0].asnumpy()
|
||||
#print(_data.shape)
|
||||
#_label = _batch.label[0].asnumpy()
|
||||
#data[ba:bb,:,:,:] = _data
|
||||
#label[ba:bb] = _label
|
||||
for i in xrange(ba, bb):
|
||||
_idx = self.oseq[i+self.triplet_oseq_cur]
|
||||
s = self.imgrec.read_idx(_idx)
|
||||
header, img = recordio.unpack(s)
|
||||
img = self.imdecode(img)
|
||||
data[i-ba][:] = self.postprocess_data(img)
|
||||
label[i-ba][:] = header.label
|
||||
tag.append( ( int(header.label), _idx) )
|
||||
#idx[i] = _idx
|
||||
self.seq = []
|
||||
while len(self.seq)<self.seq_min_size:
|
||||
embeddings = None
|
||||
bag_size = self.triplet_bag_size
|
||||
batch_size = self.batch_size
|
||||
#data = np.zeros( (bag_size,)+self.data_shape )
|
||||
#label = np.zeros( (bag_size,) )
|
||||
tag = []
|
||||
#idx = np.zeros( (bag_size,) )
|
||||
print('eval %d images..'%bag_size, self.triplet_oseq_cur)
|
||||
if self.triplet_oseq_cur+bag_size>len(self.oseq):
|
||||
self.triplet_oseq_reset()
|
||||
print('eval %d images..'%bag_size, self.triplet_oseq_cur)
|
||||
#print(data.shape)
|
||||
data = nd.zeros( self.provide_data[0][1] )
|
||||
label = nd.zeros( self.provide_label[0][1] )
|
||||
ba = 0
|
||||
while True:
|
||||
bb = min(ba+batch_size, bag_size)
|
||||
if ba>=bb:
|
||||
break
|
||||
#_batch = self.data_iter.next()
|
||||
#_data = _batch.data[0].asnumpy()
|
||||
#print(_data.shape)
|
||||
#_label = _batch.label[0].asnumpy()
|
||||
#data[ba:bb,:,:,:] = _data
|
||||
#label[ba:bb] = _label
|
||||
for i in xrange(ba, bb):
|
||||
_idx = self.oseq[i+self.triplet_oseq_cur]
|
||||
s = self.imgrec.read_idx(_idx)
|
||||
header, img = recordio.unpack(s)
|
||||
img = self.imdecode(img)
|
||||
data[i-ba][:] = self.postprocess_data(img)
|
||||
label[i-ba][:] = header.label
|
||||
tag.append( ( int(header.label), _idx) )
|
||||
#idx[i] = _idx
|
||||
|
||||
db = mx.io.DataBatch(data=(data,), label=(label,))
|
||||
self.mx_model.forward(db, is_train=False)
|
||||
net_out = self.mx_model.get_outputs()
|
||||
#print('eval for selecting triplets',ba,bb)
|
||||
#print(net_out)
|
||||
#print(len(net_out))
|
||||
#print(net_out[0].asnumpy())
|
||||
net_out = net_out[0].asnumpy()
|
||||
#print(net_out)
|
||||
#print('net_out', net_out.shape)
|
||||
if embeddings is None:
|
||||
embeddings = np.zeros( (bag_size, net_out.shape[1]))
|
||||
embeddings[ba:bb,:] = net_out
|
||||
ba = bb
|
||||
assert len(tag)==bag_size
|
||||
self.triplet_oseq_cur+=bag_size
|
||||
embeddings = sklearn.preprocessing.normalize(embeddings)
|
||||
nrof_images_per_class = [1]
|
||||
for i in xrange(1, bag_size):
|
||||
if tag[i][0]==tag[i-1][0]:
|
||||
nrof_images_per_class[-1]+=1
|
||||
else:
|
||||
nrof_images_per_class.append(1)
|
||||
|
||||
triplets = self.pick_triplets(embeddings, nrof_images_per_class) # shape=(T,3)
|
||||
if len(triplets)==0:
|
||||
print('triplets 0, retry...')
|
||||
self.select_triplets()
|
||||
else:
|
||||
print('triplets', len(triplets))
|
||||
self.seq = []
|
||||
db = mx.io.DataBatch(data=(data,), label=(label,))
|
||||
self.mx_model.forward(db, is_train=False)
|
||||
net_out = self.mx_model.get_outputs()
|
||||
#print('eval for selecting triplets',ba,bb)
|
||||
#print(net_out)
|
||||
#print(len(net_out))
|
||||
#print(net_out[0].asnumpy())
|
||||
net_out = net_out[0].asnumpy()
|
||||
#print(net_out)
|
||||
#print('net_out', net_out.shape)
|
||||
if embeddings is None:
|
||||
embeddings = np.zeros( (bag_size, net_out.shape[1]))
|
||||
embeddings[ba:bb,:] = net_out
|
||||
ba = bb
|
||||
assert len(tag)==bag_size
|
||||
self.triplet_oseq_cur+=bag_size
|
||||
embeddings = sklearn.preprocessing.normalize(embeddings)
|
||||
nrof_images_per_class = [1]
|
||||
for i in xrange(1, bag_size):
|
||||
if tag[i][0]==tag[i-1][0]:
|
||||
nrof_images_per_class[-1]+=1
|
||||
else:
|
||||
nrof_images_per_class.append(1)
|
||||
|
||||
triplets = self.pick_triplets(embeddings, nrof_images_per_class) # shape=(T,3)
|
||||
print('found triplets', len(triplets))
|
||||
ba = 0
|
||||
while True:
|
||||
bb = ba+self.per_batch_size//3
|
||||
|
||||
@@ -484,7 +484,7 @@ def train_net(args):
|
||||
label_names = (label_name,),
|
||||
)
|
||||
|
||||
if len(data_dir_list)==1:
|
||||
if len(data_dir_list)==1 and args.loss_type!=12:
|
||||
train_dataiter = FaceImageIter(
|
||||
batch_size = args.batch_size,
|
||||
data_shape = data_shape,
|
||||
|
||||
Reference in New Issue
Block a user