diff --git a/src/eval/verification.py b/src/eval/verification.py index e9f3ea4..7ae554b 100644 --- a/src/eval/verification.py +++ b/src/eval/verification.py @@ -28,6 +28,7 @@ from __future__ import division from __future__ import print_function import os +import argparse import numpy as np from scipy import misc from sklearn.model_selection import KFold @@ -253,7 +254,7 @@ if __name__ == '__main__': parser.add_argument('--target', default='lfw,cfp_ff,cfp_fp', help='test targets.') parser.add_argument('--gpu', default=0, type=int, help='gpu id') parser.add_argument('--batch-size', default=128, type=int, help='') - args = parse_arguments(sys.argv[1:]) + args = parser.parse_args() image_size = [int(x) for x in args.image_size.split(',')] ctx = mx.gpu(args.gpu) @@ -263,6 +264,8 @@ if __name__ == '__main__': epoch = int(epoch) print('loading',prefix, epoch) model = mx.mod.Module.load(prefix, epoch, context = ctx) + model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))]) + #model.init_params() nets.append(model) ver_list = [] @@ -270,7 +273,7 @@ if __name__ == '__main__': for name in args.target.split(','): path = os.path.join(args.data_dir,name+".bin") if os.path.exists(path): - print('ver', name) + print('loading.. ', name) data_set = load_bin(path, image_size) ver_list.append(data_set) ver_name_list.append(name)