Files
insightface/detection/retinaface/rcnn/core/callback.py
2021-06-19 23:57:15 +08:00

17 lines
617 B
Python

import mxnet as mx
def do_checkpoint(prefix, means, stds):
    """Build a callback that saves a checkpoint with extra bbox entries.

    When ``arg`` contains ``'bbox_pred_weight'``, the callback temporarily
    adds two derived entries before saving:

    * ``'bbox_pred_weight_test'``: ``(bbox_pred_weight.T * stds).T``
    * ``'bbox_pred_bias_test'``: ``bbox_pred_bias * stds + means``

    The checkpoint is written with ``mx.model.save_checkpoint`` under
    ``prefix`` using ``iter_no + 1`` as the epoch number, after which the
    temporary entries are removed from ``arg`` again.

    NOTE(review): presumably this folds the bbox-target normalization
    (``means``/``stds``) into the regression layer for test-time use --
    confirm against the training pipeline.

    Args:
        prefix: Checkpoint prefix forwarded to ``mx.model.save_checkpoint``.
        means: Values passed to ``mx.nd.array`` and added to the scaled bias.
        stds: Values passed to ``mx.nd.array`` and used to scale the
            weight and bias.

    Returns:
        A callable ``_callback(iter_no, sym, arg, aux)``.
    """
    temp_keys = ('bbox_pred_weight_test', 'bbox_pred_bias_test')

    def _callback(iter_no, sym, arg, aux):
        """Save a checkpoint, adding the derived bbox entries if needed."""
        has_bbox_pred = 'bbox_pred_weight' in arg
        try:
            if has_bbox_pred:
                # Build the stds array once; it scales both weight and bias.
                stds_nd = mx.nd.array(stds)
                arg['bbox_pred_weight_test'] = (
                    arg['bbox_pred_weight'].T * stds_nd).T
                arg['bbox_pred_bias_test'] = (
                    arg['bbox_pred_bias'] * stds_nd + mx.nd.array(means))
            mx.model.save_checkpoint(prefix, iter_no + 1, sym, arg, aux)
        finally:
            # Always strip the temporary entries so the caller's parameter
            # dict is left untouched even when building or saving fails.
            # pop(..., None) avoids masking the original exception if an
            # entry was never added.
            if has_bbox_pred:
                for key in temp_keys:
                    arg.pop(key, None)

    return _callback