This commit is contained in:
Jia Guo
2018-03-12 16:04:17 +08:00
2 changed files with 8 additions and 1 deletions

View File

@@ -59,7 +59,7 @@ class FaceImageIter(io.DataIter):
path_imgrec = None,
shuffle=False, aug_list=None, mean = None,
rand_mirror = False, cutoff = 0,
c2c_threshold = 0.0, output_c2c = 0, c2c_mode = -10,
c2c_threshold = 0.0, output_c2c = 0, c2c_mode = -10, limit = 0,
ctx_num = 0, images_per_identity = 0, data_extra = None, hard_mining = False,
triplet_params = None, coco_mode = False,
mx_model = None,
@@ -187,6 +187,9 @@ class FaceImageIter(io.DataIter):
print('id2range', len(self.id2range))
print(len(self.idx2cos), len(self.idx2meancos), len(self.idx2flag))
print('c2c_stat', c2c_stat)
if limit>0 and limit<len(self.imgidx):
random.shuffle(self.imgidx)
self.imgidx = self.imgidx[0:limit]
else:
self.imgidx = list(self.imgrec.keys)
if shuffle:

View File

@@ -152,6 +152,8 @@ def parse_args():
help='')
parser.add_argument('--output-c2c', type=int, default=0,
help='')
parser.add_argument('--train-limit', type=int, default=0,
help='')
parser.add_argument('--margin', type=int, default=4,
help='')
parser.add_argument('--beta', type=float, default=1000.,
@@ -858,6 +860,7 @@ def train_net(args):
c2c_threshold = args.c2c_threshold,
output_c2c = args.output_c2c,
c2c_mode = args.c2c_mode,
limit = args.train_limit,
ctx_num = args.ctx_num,
images_per_identity = args.images_per_identity,
data_extra = data_extra,
@@ -882,6 +885,7 @@ def train_net(args):
c2c_threshold = args.c2c_threshold,
output_c2c = args.output_c2c,
c2c_mode = args.c2c_mode,
limit = args.train_limit,
ctx_num = args.ctx_num,
images_per_identity = args.images_per_identity,
data_extra = data_extra,