mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-20 00:10:28 +00:00
add data_dir
This commit is contained in:
@@ -67,6 +67,8 @@ class AccMetric(mx.metric.EvalMetric):
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Train face network')
|
||||
# general
|
||||
parser.add_argument('--data-dir', default='',
|
||||
help='')
|
||||
parser.add_argument('--prefix', default='../model/spherefacei',
|
||||
help='directory to save model.')
|
||||
parser.add_argument('--pretrained', default='../model/resnet-152',
|
||||
@@ -274,26 +276,24 @@ def train_net(args):
|
||||
os.environ['BETA'] = str(args.beta)
|
||||
args.use_val = False
|
||||
path_imgrec = None
|
||||
path_imglist = None
|
||||
val_rec = None
|
||||
val_path = None
|
||||
|
||||
path_imglist = "/raid5data/dplearn/faceinsight_align_webface.lst.new"
|
||||
#path_imglist = "/raid5data/dplearn/faceinsight_align_webface.lst.new"
|
||||
#path_imglist = "/raid5data/dplearn/faceinsight_align_webface_clean.lst.new"
|
||||
args.num_classes = 10572 #webface
|
||||
for line in open(os.path.join(args.data_dir, 'property')):
|
||||
args.num_classes = int(line.strip())
|
||||
assert(args.num_classes>0)
|
||||
print('num_classes', args.num_classes)
|
||||
|
||||
path_imglist = "/raid5data/dplearn/MS-Celeb-Aligned/lst2"
|
||||
args.num_classes = 81017
|
||||
path_imgrec = "/opt/jiaguo/faces_celeb/train.rec"
|
||||
#path_imglist = "/raid5data/dplearn/MS-Celeb-Aligned/lst2"
|
||||
path_imgrec = os.path.join(args.data_dir, "train.rec")
|
||||
val_rec = os.path.join(args.data_dir, "val.rec")
|
||||
#args.num_classes = 10572 #webface
|
||||
#args.num_classes = 81017
|
||||
#args.num_classes = 82395
|
||||
|
||||
#path_imglist = "/raid5data/dplearn/MS-Celeb-Aligned/lst3"
|
||||
#args.num_classes = 81013
|
||||
|
||||
path_imglist = "/raid5data/dplearn/faces_normed/train.lst"
|
||||
args.num_classes = 82395
|
||||
args.use_val = False
|
||||
val_path = "/raid5data/dplearn/faces_normed/val.lst"
|
||||
path_imgrec = "/opt/jiaguo/faces_normed/train.rec"
|
||||
val_rec = "/opt/jiaguo/faces_normed/val.rec"
|
||||
|
||||
if args.loss_type==1 and args.num_classes>40000:
|
||||
args.beta_freeze = 5000
|
||||
@@ -456,7 +456,7 @@ def train_net(args):
|
||||
#opt = optimizer.AdaGrad(learning_rate=base_lr, wd=base_wd, rescale_grad=1.0)
|
||||
_cb = mx.callback.Speedometer(args.batch_size, 10)
|
||||
|
||||
lfw_dir = '/raid5data/dplearn/lfw_mtcnn2'
|
||||
lfw_dir = os.path.join(args.data_dir,'lfw')
|
||||
lfw_pairs = lfw.read_pairs(os.path.join(lfw_dir, 'pairs.txt'))
|
||||
lfw_paths, issame_list = lfw.get_paths(lfw_dir, lfw_pairs, 'jpg')
|
||||
imgs = []
|
||||
|
||||
Reference in New Issue
Block a user