mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-14 12:17:55 +00:00
60 lines
2.0 KiB
Python
60 lines
2.0 KiB
Python
import mxnet as mx
|
|
|
|
|
|
def load_checkpoint(prefix, epoch):
    """
    Load model checkpoint from file.

    Reads ``'<prefix>-<epoch as 4 digits>.params'`` with ``mx.nd.load`` and
    splits the saved entries by the type prefix of their keys.

    :param prefix: Prefix of model name.
    :param epoch: Epoch number of model we would like to load.
    :return: (arg_params, aux_params)
    arg_params : dict of str to NDArray
        Model parameter, dict of name to NDArray of net's weights
        (entries saved under ``'arg:<name>'``).
    aux_params : dict of str to NDArray
        Model parameter, dict of name to NDArray of net's auxiliary states
        (entries saved under ``'aux:<name>'``).
    :raises ValueError: if a saved key has no ``'<type>:'`` prefix.

    Entries whose prefix is neither ``'arg'`` nor ``'aux'`` are ignored.
    """
    save_dict = mx.nd.load('%s-%04d.params' % (prefix, epoch))
    arg_params = {}
    aux_params = {}
    for k, v in save_dict.items():
        # Split only on the first ':' so parameter names may themselves
        # contain colons.
        tp, sep, name = k.partition(':')
        if not sep:
            # Report malformed keys explicitly instead of failing with an
            # opaque tuple-unpacking error.
            raise ValueError(
                "unexpected parameter key %r in checkpoint '%s-%04d.params': "
                "expected 'arg:<name>' or 'aux:<name>'" % (k, prefix, epoch))
        if tp == 'arg':
            arg_params[name] = v
        elif tp == 'aux':
            aux_params[name] = v
    return arg_params, aux_params
|
|
|
|
|
|
def convert_context(params, ctx):
    """
    Copy every array of a parameter dict onto another context.

    :param params: dict of str to NDArray
    :param ctx: the context to convert to
    :return: a new dict with the same keys (in the same order), each value
        being ``value.as_in_context(ctx)``; ``params`` itself is not modified.
    """
    return {name: arr.as_in_context(ctx) for name, arr in params.items()}
|
|
|
|
|
|
def load_param(prefix, epoch, convert=False, ctx=None, process=False):
    """
    Wrapper around ``load_checkpoint`` with optional context transfer and
    renaming of ``_test`` parameters.

    :param prefix: Prefix of model name.
    :param epoch: Epoch number of model we would like to load.
    :param convert: if True, copy both parameter dicts onto ``ctx`` using
        ``convert_context``.
    :param ctx: target context used when ``convert`` is True; falls back to
        ``mx.cpu()`` when None. Ignored when ``convert`` is False.
    :param process: if True, every arg param whose name contains ``'_test'``
        is re-keyed with all ``'_test'`` occurrences removed, overwriting any
        existing entry under the new name. aux params are left untouched.
    :return: (arg_params, aux_params)
    """
    arg_params, aux_params = load_checkpoint(prefix, epoch)

    if convert:
        target_ctx = mx.cpu() if ctx is None else ctx
        arg_params = convert_context(arg_params, target_ctx)
        aux_params = convert_context(aux_params, target_ctx)

    if process:
        # Snapshot the matching keys first, because the dict is mutated while
        # renaming.
        for old_name in [name for name in arg_params if '_test' in name]:
            arg_params[old_name.replace('_test', '')] = arg_params.pop(old_name)

    return arg_params, aux_params
|