add data_dir

This commit is contained in:
nttstar
2017-11-16 20:43:43 +08:00
parent 567caeb355
commit 6d6424175a

View File

@@ -67,6 +67,8 @@ class AccMetric(mx.metric.EvalMetric):
def parse_args():
parser = argparse.ArgumentParser(description='Train face network')
# general
parser.add_argument('--data-dir', default='',
help='')
parser.add_argument('--prefix', default='../model/spherefacei',
help='directory to save model.')
parser.add_argument('--pretrained', default='../model/resnet-152',
@@ -274,26 +276,24 @@ def train_net(args):
os.environ['BETA'] = str(args.beta)
args.use_val = False
path_imgrec = None
path_imglist = None
val_rec = None
val_path = None
path_imglist = "/raid5data/dplearn/faceinsight_align_webface.lst.new"
#path_imglist = "/raid5data/dplearn/faceinsight_align_webface.lst.new"
#path_imglist = "/raid5data/dplearn/faceinsight_align_webface_clean.lst.new"
args.num_classes = 10572 #webface
for line in open(os.path.join(args.data_dir, 'property')):
args.num_classes = int(line.strip())
assert(args.num_classes>0)
print('num_classes', args.num_classes)
path_imglist = "/raid5data/dplearn/MS-Celeb-Aligned/lst2"
args.num_classes = 81017
path_imgrec = "/opt/jiaguo/faces_celeb/train.rec"
#path_imglist = "/raid5data/dplearn/MS-Celeb-Aligned/lst2"
path_imgrec = os.path.join(args.data_dir, "train.rec")
val_rec = os.path.join(args.data_dir, "val.rec")
#args.num_classes = 10572 #webface
#args.num_classes = 81017
#args.num_classes = 82395
#path_imglist = "/raid5data/dplearn/MS-Celeb-Aligned/lst3"
#args.num_classes = 81013
path_imglist = "/raid5data/dplearn/faces_normed/train.lst"
args.num_classes = 82395
args.use_val = False
val_path = "/raid5data/dplearn/faces_normed/val.lst"
path_imgrec = "/opt/jiaguo/faces_normed/train.rec"
val_rec = "/opt/jiaguo/faces_normed/val.rec"
if args.loss_type==1 and args.num_classes>40000:
args.beta_freeze = 5000
@@ -456,7 +456,7 @@ def train_net(args):
#opt = optimizer.AdaGrad(learning_rate=base_lr, wd=base_wd, rescale_grad=1.0)
_cb = mx.callback.Speedometer(args.batch_size, 10)
lfw_dir = '/raid5data/dplearn/lfw_mtcnn2'
lfw_dir = os.path.join(args.data_dir,'lfw')
lfw_pairs = lfw.read_pairs(os.path.join(lfw_dir, 'pairs.txt'))
lfw_paths, issame_list = lfw.get_paths(lfw_dir, lfw_pairs, 'jpg')
imgs = []