import sys import os import argparse import onnx import mxnet as mx from onnx import helper from onnx import TensorProto from onnx import numpy_helper print('mxnet version:', mx.__version__) print('onnx version:', onnx.__version__) assert mx.__version__ >= '1.8', 'mxnet version should >= 1.8' assert onnx.__version__ >= '1.2.1', 'onnx version should >= 1.2.1' import numpy as np from mxnet.contrib import onnx as onnx_mxnet def create_map(graph_member_list): member_map={} for n in graph_member_list: member_map[n.name]=n return member_map parser = argparse.ArgumentParser(description='convert arcface models to onnx') # general parser.add_argument('params', default='./r100a/model-0000.params', help='mxnet params to load.') parser.add_argument('output', default='./r100a.onnx', help='path to write onnx model.') parser.add_argument('--eps', default=1.0e-8, type=float, help='eps for weights.') parser.add_argument('--input-shape', default='3,112,112', help='input shape.') args = parser.parse_args() input_shape = (1,) + tuple( [int(x) for x in args.input_shape.split(',')] ) params_file = args.params pos = params_file.rfind('-') prefix = params_file[:pos] epoch = int(params_file[pos+1:pos+5]) sym_file = prefix + "-symbol.json" assert os.path.exists(sym_file) assert os.path.exists(params_file) sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) eps = args.eps arg = {} aux = {} invalid = 0 ac = 0 for k in arg_params: v = arg_params[k] nv = v.asnumpy() #print(k, nv.dtype) nv = nv.astype(np.float32) ac += nv.size invalid += np.count_nonzero(np.abs(nv)