This commit is contained in:
nttstar
2017-12-07 14:13:48 +08:00
parent c8c097771f
commit 97e98f654f

View File

@@ -28,6 +28,7 @@ from __future__ import division
from __future__ import print_function
import os
import argparse
import numpy as np
from scipy import misc
from sklearn.model_selection import KFold
@@ -253,7 +254,7 @@ if __name__ == '__main__':
parser.add_argument('--target', default='lfw,cfp_ff,cfp_fp', help='test targets.')
parser.add_argument('--gpu', default=0, type=int, help='gpu id')
parser.add_argument('--batch-size', default=128, type=int, help='')
args = parse_arguments(sys.argv[1:])
args = parser.parse_args()
image_size = [int(x) for x in args.image_size.split(',')]
ctx = mx.gpu(args.gpu)
@@ -263,6 +264,8 @@ if __name__ == '__main__':
epoch = int(epoch)
print('loading',prefix, epoch)
model = mx.mod.Module.load(prefix, epoch, context = ctx)
model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
#model.init_params()
nets.append(model)
ver_list = []
@@ -270,7 +273,7 @@ if __name__ == '__main__':
for name in args.target.split(','):
path = os.path.join(args.data_dir,name+".bin")
if os.path.exists(path):
print('ver', name)
print('loading.. ', name)
data_set = load_bin(path, image_size)
ver_list.append(data_set)
ver_name_list.append(name)