# Files
# insightface/detection/scrfd/tools/scrfd2onnx.py
# 2021-10-10 13:33:50 +03:00
#
# 202 lines
# 6.3 KiB
# Python
# Executable File

import argparse
import os.path as osp
import numpy as np
import onnx
import os
#import onnxruntime as rt
import torch
from mmdet.core import (build_model_from_cfg, generate_inputs_and_wrap_model,
preprocess_example_input)
#from mmdet.models import build
def pytorch2onnx(config_path,
                 checkpoint_path,
                 input_img,
                 input_shape,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False,
                 simplify = True,
                 dynamic = True,
                 normalize_cfg=None,
                 dataset='coco',
                 test_img=None):
    """Export a detector checkpoint to an ONNX file.

    The checkpoint is loaded, stripped of its ``'optimizer'`` entry (if any),
    wrapped via ``generate_inputs_and_wrap_model`` and exported with
    ``torch.onnx.export``. When ``simplify`` or ``dynamic`` is set, the export
    first goes to an intermediate ``*_ori.onnx`` file which is reloaded,
    optionally simplified with ``onnxsim`` and saved to ``output_file``; the
    intermediate file is removed afterwards.

    Args:
        config_path: Model config path, forwarded to
            ``generate_inputs_and_wrap_model``.
        checkpoint_path: Checkpoint file readable by ``torch.load``.
        input_img: Path of the example image (``'input_path'`` of the input
            config).
        input_shape: Shape of the example input (``'input_shape'`` of the
            input config); also handed to the simplifier as the concrete
            input shape when ``dynamic`` is set.
        opset_version: ONNX opset passed to ``torch.onnx.export``.
        show: Accepted for interface compatibility; unused here.
        output_file: Destination path of the final ONNX model.
        verify: Accepted for interface compatibility; unused here.
        simplify: Run ``onnxsim.simplify`` on the exported graph.
        dynamic: Mark axes 0, 2 and 3 of the input (presumably batch, height
            and width -- confirm against the model) and axes 0 and 1 of every
            output as dynamic.
        normalize_cfg: Normalisation settings (``'normalize_cfg'`` of the
            input config).
        dataset: Accepted for interface compatibility; unused here.
        test_img: Accepted for interface compatibility; unused here.

    Raises:
        AssertionError: If the simplifier reports that the simplified model
            could not be validated.
    """
    input_config = {
        'input_shape': input_shape,
        'input_path': input_img,
        'normalize_cfg': normalize_cfg
    }

    # NOTE(review): torch.load unpickles the file -- only feed it checkpoints
    # from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    tmp_ckpt_file = None
    try:
        # remove optimizer for smaller file size
        if 'optimizer' in checkpoint:
            del checkpoint['optimizer']
            tmp_ckpt_file = checkpoint_path + "_slim.pth"
            torch.save(checkpoint, tmp_ckpt_file)
            print('remove optimizer params and save to', tmp_ckpt_file)
            checkpoint_path = tmp_ckpt_file

        model, tensor_data = generate_inputs_and_wrap_model(
            config_path, checkpoint_path, input_config)
    finally:
        # Never leave the slimmed copy behind, even if model building fails.
        if tmp_ckpt_file is not None and osp.exists(tmp_ckpt_file):
            os.remove(tmp_ckpt_file)

    needs_postprocess = bool(simplify or dynamic)
    if needs_postprocess:
        # splitext only strips the final extension, so dots in directory
        # names (e.g. "./tools/../onnx/x.onnx") no longer truncate the path.
        ori_output_file = osp.splitext(output_file)[0] + "_ori.onnx"
    else:
        ori_output_file = output_file

    # Define input and outputs names, which are required to properly define
    # dynamic axes
    input_names = ['input.1']
    output_names = ['score_8', 'score_16', 'score_32',
                    'bbox_8', 'bbox_16', 'bbox_32',
                    ]

    # If model graph contains keypoints strides add keypoints to outputs
    if 'stride_kps' in str(model):
        output_names += ['kps_8', 'kps_16', 'kps_32']

    # Define dynamic axes for export
    dynamic_axes = None
    if dynamic:
        dynamic_axes = {out: {0: '?', 1: '?'} for out in output_names}
        dynamic_axes[input_names[0]] = {
            0: '?',
            2: '?',
            3: '?'
        }

    try:
        torch.onnx.export(
            model,
            tensor_data,
            ori_output_file,
            keep_initializers_as_inputs=False,
            verbose=False,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            opset_version=opset_version)

        if needs_postprocess:
            onnx_model = onnx.load(ori_output_file)
            if simplify:
                # Imported lazily (onnxsim is only needed on this path) and
                # under an alias so the `simplify` flag is not rebound.
                from onnxsim import simplify as onnx_simplify
                if dynamic:
                    input_shapes = {
                        onnx_model.graph.input[0].name: list(input_shape)
                    }
                    onnx_model, check = onnx_simplify(
                        onnx_model,
                        input_shapes=input_shapes,
                        dynamic_input_shape=True)
                else:
                    onnx_model, check = onnx_simplify(onnx_model)
                # Explicit raise (same exception type as the former assert)
                # so the check is not stripped under `python -O`.
                if not check:
                    raise AssertionError(
                        "Simplified ONNX model could not be validated")
            onnx.save(onnx_model, output_file)
    finally:
        # The intermediate export is scratch data; drop it on success and on
        # failure alike.
        if needs_postprocess and osp.exists(ori_output_file):
            os.remove(ori_output_file)

    print(f'Successfully exported ONNX model: {output_file}')
def parse_args(argv=None):
    """Parse the command-line options of the ONNX export script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            ``argparse`` reads ``sys.argv[1:]``, which is the behaviour of a
            plain ``parse_args()`` call.

    Returns:
        argparse.Namespace: The parsed options (``config``, ``checkpoint``,
        ``input_img``, ``show``, ``output_file``, ``opset_version``,
        ``test_img``, ``dataset``, ``verify``, ``shape``, ``mean``, ``std``).
    """
    parser = argparse.ArgumentParser(
        description='Convert MMDetection models to ONNX')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--input-img', type=str, help='Images for input')
    parser.add_argument('--show', action='store_true', help='show onnx graph')
    parser.add_argument('--output-file', type=str, default='')
    parser.add_argument('--opset-version', type=int, default=11)
    parser.add_argument(
        '--test-img', type=str, default=None, help='Images for test')
    parser.add_argument(
        '--dataset', type=str, default='coco', help='Dataset name')
    parser.add_argument(
        '--verify',
        action='store_true',
        help='verify the onnx model output against pytorch output')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        # Non-positive sizes are treated by the script entry point as a
        # request for a dynamic-shape export.
        default=[-1, -1],
        help='input image size')
    parser.add_argument(
        '--mean',
        type=float,
        nargs='+',
        default=[127.5, 127.5, 127.5],
        help='mean value used for preprocess input data')
    parser.add_argument(
        '--std',
        type=float,
        nargs='+',
        default=[128.0, 128.0, 128.0],
        help='std value used for preprocess input data')
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: validate the CLI options, derive the input shape
    # and output path, then run the export.
    args = parse_args()

    # Explicit raises instead of asserts so the checks survive `python -O`.
    if args.opset_version != 11:
        raise ValueError('MMDet only support opset 11 now')

    if not args.input_img:
        # Fall back to the sample image located relative to this script.
        args.input_img = osp.join(
            osp.dirname(__file__), '../tests/data/t1.jpg')

    # One value -> square input, two values -> used as the last two dims.
    if len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (1, 3) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    if len(args.mean) != 3:
        raise ValueError('--mean expects exactly 3 values')
    if len(args.std) != 3:
        raise ValueError('--std expects exactly 3 values')

    simplify = True
    dynamic = False
    # A non-positive size requests a dynamic-shape export; a dummy 640x640
    # input is then used for the export itself.
    if input_shape[2] <= 0 or input_shape[3] <= 0:
        input_shape = (1, 3, 640, 640)
        dynamic = True
        print('set to dynamic input with dummy shape:', input_shape)

    normalize_cfg = {'mean': args.mean, 'std': args.std}

    if not args.output_file:
        # Default destination: ../onnx next to this script, named after the
        # config file.
        output_dir = osp.join(osp.dirname(__file__), '../onnx')
        os.makedirs(output_dir, exist_ok=True)
        # basename/splitext handle config names without an extension
        # (rfind('.') == -1 used to chop off the last character).
        cfg_name = osp.splitext(osp.basename(args.config))[0]
        if dynamic:
            args.output_file = osp.join(output_dir, f'{cfg_name}.onnx')
        else:
            args.output_file = osp.join(
                output_dir,
                f'{cfg_name}_shape{input_shape[2]}x{input_shape[3]}.onnx')

    # convert model to onnx file
    pytorch2onnx(
        args.config,
        args.checkpoint,
        args.input_img,
        input_shape,
        opset_version=args.opset_version,
        show=args.show,
        output_file=args.output_file,
        verify=args.verify,
        simplify=simplify,
        dynamic=dynamic,
        normalize_cfg=normalize_cfg,
        dataset=args.dataset,
        test_img=args.test_img)