mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-19 07:27:52 +00:00
do not reset
This commit is contained in:
26
src/data.py
26
src/data.py
@@ -33,6 +33,7 @@ class FaceImageIter(io.DataIter):
|
||||
path_imgrec = None,
|
||||
shuffle=False, aug_list=None, mean = None,
|
||||
rand_mirror = False,
|
||||
ctx_num = 0, images_per_identity = 0,
|
||||
data_name='data', label_name='softmax_label', **kwargs):
|
||||
super(FaceImageIter, self).__init__()
|
||||
assert path_imgrec
|
||||
@@ -80,18 +81,33 @@ class FaceImageIter(io.DataIter):
|
||||
self.rand_mirror = rand_mirror
|
||||
#self.cast_aug = mx.image.CastAug()
|
||||
#self.color_aug = mx.image.ColorJitterAug(0.4, 0.4, 0.4)
|
||||
|
||||
self.ctx_num = ctx_num
|
||||
self.per_batch_size = int(self.batch_size/self.ctx_num)
|
||||
self.images_per_identity = images_per_identity
|
||||
if self.images_per_identity>0:
|
||||
self.identities = int(self.per_batch_size/self.images_per_identity)
|
||||
print(self.images_per_identity, self.identities)
|
||||
self.cur = 0
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""Resets the iterator to the beginning of the data."""
|
||||
print('call reset()')
|
||||
if self.shuffle:
|
||||
random.shuffle(self.seq)
|
||||
if self.imgrec is not None:
|
||||
self.imgrec.reset()
|
||||
self.cur = 0
|
||||
if self.images_per_identity>0:
|
||||
self.seq = []
|
||||
for _id, _v in self.idx2range.iteritems():
|
||||
_list = range(_v)
|
||||
if self.shuffle:
|
||||
random.shuffle(_list)
|
||||
for i in xrange(self.images_per_identity):
|
||||
_idx = _list[i%len(_list)]
|
||||
self.seq.append(_id)
|
||||
else:
|
||||
if self.shuffle:
|
||||
random.shuffle(self.seq)
|
||||
if self.seq is None and self.imgrec is not None:
|
||||
self.imgrec.reset()
|
||||
|
||||
def num_samples(self):
|
||||
return len(self.seq)
|
||||
|
||||
Reference in New Issue
Block a user