do not reset

This commit is contained in:
Jia Guo
2017-12-12 22:55:18 +08:00
parent 46526534e2
commit 6f93b2f7b2
2 changed files with 62 additions and 49 deletions

View File

@@ -33,6 +33,7 @@ class FaceImageIter(io.DataIter):
path_imgrec = None,
shuffle=False, aug_list=None, mean = None,
rand_mirror = False,
ctx_num = 0, images_per_identity = 0,
data_name='data', label_name='softmax_label', **kwargs):
super(FaceImageIter, self).__init__()
assert path_imgrec
@@ -80,18 +81,33 @@ class FaceImageIter(io.DataIter):
self.rand_mirror = rand_mirror
#self.cast_aug = mx.image.CastAug()
#self.color_aug = mx.image.ColorJitterAug(0.4, 0.4, 0.4)
self.ctx_num = ctx_num
self.per_batch_size = int(self.batch_size/self.ctx_num)
self.images_per_identity = images_per_identity
if self.images_per_identity>0:
self.identities = int(self.per_batch_size/self.images_per_identity)
print(self.images_per_identity, self.identities)
self.cur = 0
self.reset()
def reset(self):
"""Resets the iterator to the beginning of the data."""
print('call reset()')
if self.shuffle:
random.shuffle(self.seq)
if self.imgrec is not None:
self.imgrec.reset()
self.cur = 0
if self.images_per_identity>0:
self.seq = []
for _id, _v in self.idx2range.iteritems():
_list = range(_v)
if self.shuffle:
random.shuffle(_list)
for i in xrange(self.images_per_identity):
_idx = _list[i%len(_list)]
self.seq.append(_id)
else:
if self.shuffle:
random.shuffle(self.seq)
if self.seq is None and self.imgrec is not None:
self.imgrec.reset()
def num_samples(self):
return len(self.seq)