Files
insightface/detection/RetinaFace/rcnn/utils/load_model.py
2020-11-06 13:59:21 +08:00

60 lines
2.0 KiB
Python

import mxnet as mx
def load_checkpoint(prefix, epoch):
"""
Load model checkpoint from file.
:param prefix: Prefix of model name.
:param epoch: Epoch number of model we would like to load.
:return: (arg_params, aux_params)
arg_params : dict of str to NDArray
Model parameter, dict of name to NDArray of net's weights.
aux_params : dict of str to NDArray
Model parameter, dict of name to NDArray of net's auxiliary states.
"""
save_dict = mx.nd.load('%s-%04d.params' % (prefix, epoch))
arg_params = {}
aux_params = {}
for k, v in save_dict.items():
tp, name = k.split(':', 1)
if tp == 'arg':
arg_params[name] = v
if tp == 'aux':
aux_params[name] = v
return arg_params, aux_params
def convert_context(params, ctx):
"""
:param params: dict of str to NDArray
:param ctx: the context to convert to
:return: dict of str of NDArray with context ctx
"""
new_params = dict()
for k, v in params.items():
new_params[k] = v.as_in_context(ctx)
return new_params
def load_param(prefix, epoch, convert=False, ctx=None, process=False):
"""
wrapper for load checkpoint
:param prefix: Prefix of model name.
:param epoch: Epoch number of model we would like to load.
:param convert: reference model should be converted to GPU NDArray first
:param ctx: if convert then ctx must be designated.
:param process: model should drop any test
:return: (arg_params, aux_params)
"""
arg_params, aux_params = load_checkpoint(prefix, epoch)
if convert:
if ctx is None:
ctx = mx.cpu()
arg_params = convert_context(arg_params, ctx)
aux_params = convert_context(aux_params, ctx)
if process:
tests = [k for k in arg_params.keys() if '_test' in k]
for test in tests:
arg_params[test.replace('_test', '')] = arg_params.pop(test)
return arg_params, aux_params