mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-17 14:26:08 +00:00
fix
This commit is contained in:
@@ -28,6 +28,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
from scipy import misc
|
||||
from sklearn.model_selection import KFold
|
||||
@@ -253,7 +254,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--target', default='lfw,cfp_ff,cfp_fp', help='test targets.')
|
||||
parser.add_argument('--gpu', default=0, type=int, help='gpu id')
|
||||
parser.add_argument('--batch-size', default=128, type=int, help='')
|
||||
args = parse_arguments(sys.argv[1:])
|
||||
args = parser.parse_args()
|
||||
|
||||
image_size = [int(x) for x in args.image_size.split(',')]
|
||||
ctx = mx.gpu(args.gpu)
|
||||
@@ -263,6 +264,8 @@ if __name__ == '__main__':
|
||||
epoch = int(epoch)
|
||||
print('loading',prefix, epoch)
|
||||
model = mx.mod.Module.load(prefix, epoch, context = ctx)
|
||||
model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
|
||||
#model.init_params()
|
||||
nets.append(model)
|
||||
|
||||
ver_list = []
|
||||
@@ -270,7 +273,7 @@ if __name__ == '__main__':
|
||||
for name in args.target.split(','):
|
||||
path = os.path.join(args.data_dir,name+".bin")
|
||||
if os.path.exists(path):
|
||||
print('ver', name)
|
||||
print('loading.. ', name)
|
||||
data_set = load_bin(path, image_size)
|
||||
ver_list.append(data_set)
|
||||
ver_name_list.append(name)
|
||||
|
||||
Reference in New Issue
Block a user