From 97e98f654fc36ebaf1bb3a43e2bfde7b43c068bf Mon Sep 17 00:00:00 2001 From: nttstar Date: Thu, 7 Dec 2017 14:13:48 +0800 Subject: [PATCH] fix --- src/eval/verification.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/eval/verification.py b/src/eval/verification.py index e9f3ea4..7ae554b 100644 --- a/src/eval/verification.py +++ b/src/eval/verification.py @@ -28,6 +28,7 @@ from __future__ import division from __future__ import print_function import os +import argparse import numpy as np from scipy import misc from sklearn.model_selection import KFold @@ -253,7 +254,7 @@ if __name__ == '__main__': parser.add_argument('--target', default='lfw,cfp_ff,cfp_fp', help='test targets.') parser.add_argument('--gpu', default=0, type=int, help='gpu id') parser.add_argument('--batch-size', default=128, type=int, help='') - args = parse_arguments(sys.argv[1:]) + args = parser.parse_args() image_size = [int(x) for x in args.image_size.split(',')] ctx = mx.gpu(args.gpu) @@ -263,6 +264,8 @@ if __name__ == '__main__': epoch = int(epoch) print('loading',prefix, epoch) model = mx.mod.Module.load(prefix, epoch, context = ctx) + model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))]) + #model.init_params() nets.append(model) ver_list = [] @@ -270,7 +273,7 @@ if __name__ == '__main__': for name in args.target.split(','): path = os.path.join(args.data_dir,name+".bin") if os.path.exists(path): - print('ver', name) + print('loading.. ', name) data_set = load_bin(path, image_size) ver_list.append(data_set) ver_name_list.append(name)