mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-14 12:17:55 +00:00
162 lines
6.1 KiB
Python
162 lines
6.1 KiB
Python
import argparse
|
|
import pprint
|
|
import mxnet as mx
|
|
|
|
from ..logger import logger
|
|
from ..config import config, default, generate_config
|
|
from ..symbol import *
|
|
from ..dataset import *
|
|
from ..core.loader import TestLoader
|
|
from ..core.tester import Predictor, pred_eval
|
|
from ..utils.load_model import load_param
|
|
|
|
|
|
def _lookup_global(name):
    """Return the module-level object called *name*.

    Used instead of ``eval(name)`` to resolve symbol-builder functions and
    imdb classes made available by the ``from ..symbol import *`` and
    ``from ..dataset import *`` star imports. The names come from the
    command line, and a plain dictionary lookup cannot execute an arbitrary
    expression the way ``eval`` could.

    Raises:
        NameError: if no module-level object called *name* exists (the same
            exception type ``eval`` raised for an unknown name).
    """
    try:
        return globals()[name]
    except KeyError:
        raise NameError('name %r is not defined' % name)


def _load_symbol_and_roidb(network, dataset, image_set, root_path,
                           dataset_path, has_rpn, proposal):
    """Build the test symbol and load the imdb/roidb to evaluate on.

    Returns:
        ``(sym, imdb, roidb)``. With *has_rpn* the roidb is the ground-truth
        roidb. Without it, the roidb comes from ``imdb.<proposal>_roidb``
        applied to the ground-truth roidb.
    """
    if has_rpn:
        sym = _lookup_global('get_' + network + '_test')(
            num_classes=config.NUM_CLASSES, num_anchors=config.NUM_ANCHORS)
        imdb = _lookup_global(dataset)(image_set, root_path, dataset_path)
        roidb = imdb.gt_roidb()
    else:
        sym = _lookup_global('get_' + network + '_rcnn_test')(
            num_classes=config.NUM_CLASSES)
        imdb = _lookup_global(dataset)(image_set, root_path, dataset_path)
        gt_roidb = imdb.gt_roidb()
        # getattr replaces eval('imdb.' + proposal + '_roidb'): same lookup and
        # same AttributeError on a bad name, but no code execution from the CLI.
        roidb = getattr(imdb, proposal + '_roidb')(gt_roidb)
    return sym, imdb, roidb


def _check_params(sym, data_shape_dict, arg_params, aux_params):
    """Assert that the loaded params cover *sym* with the inferred shapes.

    Shapes are inferred from *data_shape_dict*. Arguments that are data
    inputs, or whose name contains ``'label'``, are not checked.

    Raises:
        AssertionError: on a missing parameter or a shape mismatch. These are
            plain ``assert`` statements, so they are skipped under ``python -O``.
    """
    arg_shape, _, aux_shape = sym.infer_shape(**data_shape_dict)
    arg_shape_dict = dict(zip(sym.list_arguments(), arg_shape))
    aux_shape_dict = dict(zip(sym.list_auxiliary_states(), aux_shape))

    for k in sym.list_arguments():
        if k in data_shape_dict or 'label' in k:
            continue
        assert k in arg_params, k + ' not initialized'
        assert arg_params[k].shape == arg_shape_dict[k], \
            'shape inconsistent for ' + k + ' inferred ' + str(arg_shape_dict[k]) + ' provided ' + str(arg_params[k].shape)
    for k in sym.list_auxiliary_states():
        assert k in aux_params, k + ' not initialized'
        assert aux_params[k].shape == aux_shape_dict[k], \
            'shape inconsistent for ' + k + ' inferred ' + str(aux_shape_dict[k]) + ' provided ' + str(aux_params[k].shape)


def test_rcnn(network, dataset, image_set, root_path, dataset_path, ctx,
              prefix, epoch, vis, shuffle, has_rpn, proposal, thresh):
    """Evaluate a trained R-CNN checkpoint on one split of a dataset.

    The function builds the test symbol for *network* and loads the
    ``image_set`` split of *dataset*. It restores the parameters saved under
    *prefix*/*epoch* and checks them against the symbol's inferred shapes.
    It then runs ``pred_eval`` over the split.

    Args:
        network: network name. ``get_<network>_test`` (with RPN) or
            ``get_<network>_rcnn_test`` (without RPN) must be a module-level
            name, e.g. star-imported from ``..symbol``.
        dataset: name of a module-level imdb class, e.g. from ``..dataset``.
        image_set, root_path, dataset_path: forwarded to the imdb constructor.
        ctx: MXNet context used for loading params and for the predictor.
        prefix, epoch: checkpoint to load via ``load_param``.
        vis: forwarded to ``pred_eval`` (the CLI describes it as
            "turn on visualization").
        shuffle: forwarded to ``TestLoader``.
        has_rpn: if true, ``config.TEST.HAS_RPN`` is switched on and the
            ground-truth roidb is used. Otherwise proposals are loaded via
            ``imdb.<proposal>_roidb`` and a ``rois`` input is added.
        proposal: proposal source used when *has_rpn* is false (the CLI
            mentions ``'ss'`` or ``'rpn'``).
        thresh: detection threshold forwarded to ``pred_eval``.
    """
    # set config. This is only ever switched on here; when has_rpn is false
    # the value already in config is left as-is.
    if has_rpn:
        config.TEST.HAS_RPN = True

    # print config
    logger.info(pprint.pformat(config))

    # load symbol and testing data
    sym, imdb, roidb = _load_symbol_and_roidb(network, dataset, image_set,
                                              root_path, dataset_path,
                                              has_rpn, proposal)

    # get test data iter
    test_data = TestLoader(roidb,
                           batch_size=1,
                           shuffle=shuffle,
                           has_rpn=has_rpn)

    # load model
    arg_params, aux_params = load_param(prefix,
                                        epoch,
                                        convert=True,
                                        ctx=ctx,
                                        process=True)

    # infer shapes and check that every loaded parameter matches the symbol
    data_shape_dict = dict(test_data.provide_data)
    _check_params(sym, data_shape_dict, arg_params, aux_params)

    # decide maximum shape: the largest value of each SCALES component
    data_names = [k[0] for k in test_data.provide_data]
    label_names = None
    max_data_shape = [('data', (1, 3, max([v[0] for v in config.SCALES]),
                                max([v[1] for v in config.SCALES])))]
    if not has_rpn:
        # NOTE(review): the extra 30 rows beyond PROPOSAL_POST_NMS_TOP_N are
        # unexplained here, presumably headroom for extra rois. Confirm.
        max_data_shape.append(
            ('rois', (1, config.TEST.PROPOSAL_POST_NMS_TOP_N + 30, 5)))

    # create predictor
    predictor = Predictor(sym,
                          data_names,
                          label_names,
                          context=ctx,
                          max_data_shapes=max_data_shape,
                          provide_data=test_data.provide_data,
                          provide_label=test_data.provide_label,
                          arg_params=arg_params,
                          aux_params=aux_params)

    # start detection
    pred_eval(predictor, test_data, imdb, vis=vis, thresh=thresh)
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options for R-CNN testing.

    Parsing happens in two passes. First only ``--network`` and
    ``--dataset`` are read, and ``generate_config`` is called with them.
    The remaining options are registered afterwards, so their defaults are
    read from ``default`` only after that call. Presumably this is so
    ``generate_config`` can adjust the defaults per network/dataset; confirm
    against ``..config``.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            means ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        argparse.Namespace holding every option.
    """
    parser = argparse.ArgumentParser(description='Test a Fast R-CNN network')
    # general
    parser.add_argument('--network',
                        help='network name',
                        default=default.network,
                        type=str)
    parser.add_argument('--dataset',
                        help='dataset name',
                        default=default.dataset,
                        type=str)
    # First pass: only network/dataset are needed; other flags are ignored here.
    args, _ = parser.parse_known_args(argv)
    generate_config(args.network, args.dataset)
    parser.add_argument('--image_set',
                        help='image_set name',
                        default=default.test_image_set,
                        type=str)
    parser.add_argument('--root_path',
                        help='output data folder',
                        default=default.root_path,
                        type=str)
    parser.add_argument('--dataset_path',
                        help='dataset path',
                        default=default.dataset_path,
                        type=str)
    # testing
    parser.add_argument('--prefix',
                        help='model to test with',
                        default=default.rcnn_prefix,
                        type=str)
    parser.add_argument('--epoch',
                        help='epoch of model to test with',
                        default=default.rcnn_epoch,
                        type=int)
    parser.add_argument('--gpu',
                        help='GPU device to test with',
                        default=0,
                        type=int)
    # rcnn
    parser.add_argument('--vis',
                        help='turn on visualization',
                        action='store_true')
    parser.add_argument('--thresh',
                        help='valid detection threshold',
                        default=1e-3,
                        type=float)
    parser.add_argument('--shuffle',
                        help='shuffle data on visualization',
                        action='store_true')
    parser.add_argument('--has_rpn',
                        help='generate proposals on the fly',
                        action='store_true')
    parser.add_argument('--proposal',
                        help='can be ss for selective search or rpn',
                        default='rpn',
                        type=str)
    # Second pass: strict parse of the full option set.
    args = parser.parse_args(argv)
    return args
|
|
|
|
|
|
def main():
    """Script entry point: read the CLI options and evaluate on one GPU.

    The parsed arguments are logged. The evaluation runs on
    ``mx.gpu(args.gpu)``, with every other option forwarded to ``test_rcnn``.
    """
    args = parse_args()
    message = 'Called with argument: %s' % args
    logger.info(message)

    device = mx.gpu(args.gpu)
    test_rcnn(network=args.network,
              dataset=args.dataset,
              image_set=args.image_set,
              root_path=args.root_path,
              dataset_path=args.dataset_path,
              ctx=device,
              prefix=args.prefix,
              epoch=args.epoch,
              vis=args.vis,
              shuffle=args.shuffle,
              has_rpn=args.has_rpn,
              proposal=args.proposal,
              thresh=args.thresh)
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the evaluation only when executed as a script, never on import.
    main()
|