From 6d6424175adea0dbc0009e3de3eaef4fceff2967 Mon Sep 17 00:00:00 2001 From: nttstar Date: Thu, 16 Nov 2017 20:43:43 +0800 Subject: [PATCH] add data_dir --- src/train_softmax.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/train_softmax.py b/src/train_softmax.py index 83cbe0a..a96a378 100644 --- a/src/train_softmax.py +++ b/src/train_softmax.py @@ -67,6 +67,8 @@ class AccMetric(mx.metric.EvalMetric): def parse_args(): parser = argparse.ArgumentParser(description='Train face network') # general + parser.add_argument('--data-dir', default='', + help='') parser.add_argument('--prefix', default='../model/spherefacei', help='directory to save model.') parser.add_argument('--pretrained', default='../model/resnet-152', @@ -274,26 +276,24 @@ def train_net(args): os.environ['BETA'] = str(args.beta) args.use_val = False path_imgrec = None + path_imglist = None val_rec = None - val_path = None - path_imglist = "/raid5data/dplearn/faceinsight_align_webface.lst.new" + #path_imglist = "/raid5data/dplearn/faceinsight_align_webface.lst.new" #path_imglist = "/raid5data/dplearn/faceinsight_align_webface_clean.lst.new" - args.num_classes = 10572 #webface + for line in open(os.path.join(args.data_dir, 'property')): + args.num_classes = int(line.strip()) + assert(args.num_classes>0) + print('num_classes', args.num_classes) - path_imglist = "/raid5data/dplearn/MS-Celeb-Aligned/lst2" - args.num_classes = 81017 - path_imgrec = "/opt/jiaguo/faces_celeb/train.rec" + #path_imglist = "/raid5data/dplearn/MS-Celeb-Aligned/lst2" + path_imgrec = os.path.join(args.data_dir, "train.rec") + val_rec = os.path.join(args.data_dir, "val.rec") + #args.num_classes = 10572 #webface + #args.num_classes = 81017 + #args.num_classes = 82395 - #path_imglist = "/raid5data/dplearn/MS-Celeb-Aligned/lst3" - #args.num_classes = 81013 - path_imglist = "/raid5data/dplearn/faces_normed/train.lst" - args.num_classes = 82395 - args.use_val = False - val_path = "/raid5data/dplearn/faces_normed/val.lst" - path_imgrec = "/opt/jiaguo/faces_normed/train.rec" - val_rec = "/opt/jiaguo/faces_normed/val.rec" if args.loss_type==1 and args.num_classes>40000: args.beta_freeze = 5000 @@ -456,7 +456,7 @@ def train_net(args): #opt = optimizer.AdaGrad(learning_rate=base_lr, wd=base_wd, rescale_grad=1.0) _cb = mx.callback.Speedometer(args.batch_size, 10) - lfw_dir = '/raid5data/dplearn/lfw_mtcnn2' + lfw_dir = os.path.join(args.data_dir,'lfw') lfw_pairs = lfw.read_pairs(os.path.join(lfw_dir, 'pairs.txt')) lfw_paths, issame_list = lfw.get_paths(lfw_dir, lfw_pairs, 'jpg') imgs = []