mirror of
https://github.com/deepinsight/insightface.git
synced 2025-12-30 08:02:27 +00:00
17 lines
617 B
Python
17 lines
617 B
Python
import mxnet as mx
|
|
|
|
|
|
def do_checkpoint(prefix, means, stds):
    """Create a checkpoint callback that also saves test-time bbox regression params.

    The returned callable has the signature ``(iter_no, sym, arg, aux)``,
    presumably so it can be passed as an MXNet epoch-end callback (TODO
    confirm against the caller).

    On each call it saves ``sym``/``arg``/``aux`` via
    ``mx.model.save_checkpoint`` under index ``iter_no + 1``. If ``arg``
    contains ``'bbox_pred_weight'``, two extra parameters are saved alongside
    the originals:

    * ``bbox_pred_weight_test``: ``bbox_pred_weight`` with row ``i`` scaled
      by ``stds[i]``.
    * ``bbox_pred_bias_test``: ``bbox_pred_bias * stds + means``.

    With these, the layer outputs ``stds * y + means`` instead of ``y``.
    Presumably the regression targets were normalized as
    ``(t - means) / stds`` during training; verify against the target code.

    Args:
        prefix: Checkpoint file prefix forwarded to ``mx.model.save_checkpoint``.
        means: Per-output offsets added to the bias. Anything
            ``mx.nd.array`` accepts.
        stds: Per-output scales applied to the weight rows and to the bias.
            Anything ``mx.nd.array`` accepts.

    Returns:
        The ``_callback(iter_no, sym, arg, aux)`` closure.
    """

    def _callback(iter_no, sym, arg, aux):
        # Save a shallow copy so the caller's ``arg`` dict is never modified.
        # Previously the extra keys were inserted and then popped. That left
        # stale keys behind if saving raised, and it deleted the caller's own
        # ``*_test`` entries if they already existed.
        save_arg = dict(arg)
        if 'bbox_pred_weight' in arg:
            stds_nd = mx.nd.array(stds)
            means_nd = mx.nd.array(means)
            # Transpose so ``stds`` broadcasts across the rows of the weight.
            # Assumes a 2-D weight whose first dim equals len(stds) — TODO
            # confirm.
            save_arg['bbox_pred_weight_test'] = (
                arg['bbox_pred_weight'].T * stds_nd).T
            save_arg['bbox_pred_bias_test'] = (
                arg['bbox_pred_bias'] * stds_nd + means_nd)
        mx.model.save_checkpoint(prefix, iter_no + 1, sym, save_arg, aux)

    return _callback
|