diff --git a/src/data.py b/src/data.py index e726e15..72e29f6 100644 --- a/src/data.py +++ b/src/data.py @@ -114,6 +114,7 @@ class FaceImageIter(io.DataIter): self.triplet_mode = True self.triplet_oseq_cur = 0 self.triplet_oseq_reset() + self.seq_min_size = self.batch_size*2 self.cur = 0 self.is_init = False #self.reset() @@ -180,74 +181,69 @@ class FaceImageIter(io.DataIter): def select_triplets(self): - self.triplet_index = 0 - self.triplets = [] - embeddings = None - bag_size = self.triplet_bag_size - batch_size = self.batch_size - #data = np.zeros( (bag_size,)+self.data_shape ) - #label = np.zeros( (bag_size,) ) - tag = [] - #idx = np.zeros( (bag_size,) ) - print('eval %d images..'%bag_size, self.triplet_oseq_cur) - if self.triplet_oseq_cur+bag_size>len(self.oseq): - self.triplet_oseq_reset() - print('eval %d images..'%bag_size, self.triplet_oseq_cur) - #print(data.shape) - data = nd.zeros( self.provide_data[0][1] ) - label = nd.zeros( self.provide_label[0][1] ) - ba = 0 - while True: - bb = min(ba+batch_size, bag_size) - if ba>=bb: - break - #_batch = self.data_iter.next() - #_data = _batch.data[0].asnumpy() - #print(_data.shape) - #_label = _batch.label[0].asnumpy() - #data[ba:bb,:,:,:] = _data - #label[ba:bb] = _label - for i in xrange(ba, bb): - _idx = self.oseq[i+self.triplet_oseq_cur] - s = self.imgrec.read_idx(_idx) - header, img = recordio.unpack(s) - img = self.imdecode(img) - data[i-ba][:] = self.postprocess_data(img) - label[i-ba][:] = header.label - tag.append( ( int(header.label), _idx) ) - #idx[i] = _idx + self.seq = [] + while len(self.seq)len(self.oseq): + self.triplet_oseq_reset() + print('eval %d images..'%bag_size, self.triplet_oseq_cur) + #print(data.shape) + data = nd.zeros( self.provide_data[0][1] ) + label = nd.zeros( self.provide_label[0][1] ) + ba = 0 + while True: + bb = min(ba+batch_size, bag_size) + if ba>=bb: + break + #_batch = self.data_iter.next() + #_data = _batch.data[0].asnumpy() + #print(_data.shape) + #_label = _batch.label[0].asnumpy() + #data[ba:bb,:,:,:] = _data + #label[ba:bb] = _label + for i in xrange(ba, bb): + _idx = self.oseq[i+self.triplet_oseq_cur] + s = self.imgrec.read_idx(_idx) + header, img = recordio.unpack(s) + img = self.imdecode(img) + data[i-ba][:] = self.postprocess_data(img) + label[i-ba][:] = header.label + tag.append( ( int(header.label), _idx) ) + #idx[i] = _idx - db = mx.io.DataBatch(data=(data,), label=(label,)) - self.mx_model.forward(db, is_train=False) - net_out = self.mx_model.get_outputs() - #print('eval for selecting triplets',ba,bb) - #print(net_out) - #print(len(net_out)) - #print(net_out[0].asnumpy()) - net_out = net_out[0].asnumpy() - #print(net_out) - #print('net_out', net_out.shape) - if embeddings is None: - embeddings = np.zeros( (bag_size, net_out.shape[1])) - embeddings[ba:bb,:] = net_out - ba = bb - assert len(tag)==bag_size - self.triplet_oseq_cur+=bag_size - embeddings = sklearn.preprocessing.normalize(embeddings) - nrof_images_per_class = [1] - for i in xrange(1, bag_size): - if tag[i][0]==tag[i-1][0]: - nrof_images_per_class[-1]+=1 - else: - nrof_images_per_class.append(1) - - triplets = self.pick_triplets(embeddings, nrof_images_per_class) # shape=(T,3) - if len(triplets)==0: - print('triplets 0, retry...') - self.select_triplets() - else: - print('triplets', len(triplets)) - self.seq = [] + db = mx.io.DataBatch(data=(data,), label=(label,)) + self.mx_model.forward(db, is_train=False) + net_out = self.mx_model.get_outputs() + #print('eval for selecting triplets',ba,bb) + #print(net_out) + #print(len(net_out)) + #print(net_out[0].asnumpy()) + net_out = net_out[0].asnumpy() + #print(net_out) + #print('net_out', net_out.shape) + if embeddings is None: + embeddings = np.zeros( (bag_size, net_out.shape[1])) + embeddings[ba:bb,:] = net_out + ba = bb + assert len(tag)==bag_size + self.triplet_oseq_cur+=bag_size + embeddings = sklearn.preprocessing.normalize(embeddings) + nrof_images_per_class = [1] + for i in xrange(1, bag_size): + if tag[i][0]==tag[i-1][0]: + nrof_images_per_class[-1]+=1 + else: + nrof_images_per_class.append(1) + + triplets = self.pick_triplets(embeddings, nrof_images_per_class) # shape=(T,3) + print('found triplets', len(triplets)) ba = 0 while True: bb = ba+self.per_batch_size//3 diff --git a/src/train_softmax.py b/src/train_softmax.py index 5d5f30c..e158de4 100644 --- a/src/train_softmax.py +++ b/src/train_softmax.py @@ -484,7 +484,7 @@ def train_net(args): label_names = (label_name,), ) - if len(data_dir_list)==1: + if len(data_dir_list)==1 and args.loss_type!=12: train_dataiter = FaceImageIter( batch_size = args.batch_size, data_shape = data_shape,