"""Convert an MXNet checkpoint to ONNX (setup section).

This section:
- parses the command line,
- checks the mxnet/onnx versions,
- loads the checkpoint ``<prefix>-symbol.json`` + ``<prefix>-<epoch>.params``,
- collects the names of BatchNorm nodes whose symbol attrs set
  ``fix_gamma`` to true.
"""
import argparse
import json
import os
import re
import sys

import numpy as np
import onnx
import mxnet as mx
from onnx import helper
from onnx import TensorProto
from onnx import numpy_helper
import onnxruntime
import cv2


def _version_tuple(version):
    """Return the leading numeric components of *version* as a tuple of ints.

    Tuples compare numerically, unlike raw version strings: as strings,
    '1.10.0' sorts below '1.2.1'. Parsing stops at the first dotted
    component that does not start with a digit. A component such as '0rc1'
    contributes only its numeric prefix and ends parsing.
    """
    numbers = []
    for piece in str(version).split('.'):
        match = re.match(r'\d+', piece)
        if match is None:
            break
        numbers.append(int(match.group()))
        if match.end() != len(piece):
            break
    return tuple(numbers)


print('mxnet version:', mx.__version__)
print('onnx version:', onnx.__version__)
# Explicit checks (not asserts) so they still run under `python -O`.
if _version_tuple(mx.__version__) < (1, 8):
    sys.exit('mxnet version should >= 1.8')
if _version_tuple(onnx.__version__) < (1, 2, 1):
    sys.exit('onnx version should >= 1.2.1')

# Imported only after the version checks above have passed.
from mxnet.contrib import onnx as onnx_mxnet


def create_map(graph_member_list):
    """Return a dict mapping each member's ``name`` attribute to the member.

    If several members share a name, the last one wins.
    """
    return {member.name: member for member in graph_member_list}


parser = argparse.ArgumentParser(description='convert mxnet model to onnx')
# general
# nargs='?' lets the declared defaults take effect when an argument is
# omitted. Without it, argparse makes positionals mandatory and ignores
# `default`.
parser.add_argument('params', nargs='?', default='./r100a/model-0000.params', help='mxnet params to load.')
parser.add_argument('output', nargs='?', default='./r100a.onnx', help='path to write onnx model.')
parser.add_argument('--eps', default=1.0e-8, type=float, help='eps for weights.')
parser.add_argument('--input-shape', default='3,112,112', help='input shape.')
parser.add_argument('--check', action='store_true')
parser.add_argument('--input-mean', default=0.0, type=float, help='input mean for checking.')
parser.add_argument('--input-std', default=1.0, type=float, help='input std for checking.')
args = parser.parse_args()

# Prepend a batch dimension of 1 to the user-supplied comma-separated shape.
input_shape = (1,) + tuple(int(x) for x in args.input_shape.split(','))

params_file = args.params
# Expected file name: '<prefix>-<epoch>.<ext>'; split on the last '-'.
pos = params_file.rfind('-')
if pos < 0:
    parser.error("params file should be named '<prefix>-<epoch>.params': %s" % params_file)
prefix = params_file[:pos]
# Take everything up to the extension, so epochs of any digit count parse.
epoch_text = os.path.splitext(params_file[pos + 1:])[0]
try:
    epoch = int(epoch_text)
except ValueError:
    parser.error("cannot parse epoch from params file name: %s" % params_file)
sym_file = prefix + "-symbol.json"
for required_file in (sym_file, params_file):
    if not os.path.exists(required_file):
        parser.error('file not found: %s' % required_file)

sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
nodes = json.loads(sym.tojson())['nodes']

# Names of BatchNorm nodes whose symbol JSON attrs set fix_gamma to true.
bn_fixgamma_list = []
for node in nodes:
    if node['op'] != 'BatchNorm':
        continue
    # A node may carry no 'attrs' entry at all. Treat that like an explicit
    # fix_gamma=False, as the original None check intended.
    # NOTE(review): MXNet documents BatchNorm's fix_gamma default as True;
    # confirm whether omitted attrs should really be treated as False here.
    attr = node.get('attrs')
    fix_gamma = attr is not None and str(attr.get('fix_gamma', False)).lower() == 'true'
    if fix_gamma:
        bn_fixgamma_list.append(node['name'])
print('fixgamma list:', bn_fixgamma_list) layer = None #layer = 'conv_2_dw_relu' #for debug if layer is not None: all_layers = sym.get_internals() sym = all_layers[layer + '_output'] eps = args.eps arg = {} aux = {} invalid = 0 ac = 0 for k in arg_params: v = arg_params[k] nv = v.asnumpy() nv = nv.astype(np.float32) #print(k, nv.shape) if k.endswith('_gamma'): bnname = k[:-6] if bnname in bn_fixgamma_list: nv[:] = 1.0 ac += nv.size invalid += np.count_nonzero(np.abs(nv)