mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-19 15:41:33 +00:00
202 lines
6.3 KiB
Python
Executable File
202 lines
6.3 KiB
Python
Executable File
import argparse
|
|
import os.path as osp
|
|
|
|
import numpy as np
|
|
import onnx
|
|
import os
|
|
#import onnxruntime as rt
|
|
import torch
|
|
|
|
from mmdet.core import (build_model_from_cfg, generate_inputs_and_wrap_model,
|
|
preprocess_example_input)
|
|
|
|
#from mmdet.models import build
|
|
|
|
def pytorch2onnx(config_path,
                 checkpoint_path,
                 input_img,
                 input_shape,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False,
                 simplify=True,
                 dynamic=True,
                 normalize_cfg=None,
                 dataset='coco',
                 test_img=None):
    """Export a detector checkpoint to an ONNX file.

    The model and its example input are built by
    ``generate_inputs_and_wrap_model`` and exported with
    ``torch.onnx.export``. When ``simplify`` or ``dynamic`` is set, the
    export is first written to an intermediate ``*_ori.onnx`` file next to
    ``output_file``, optionally simplified with ``onnxsim``, saved to
    ``output_file``, and the intermediate file is deleted.

    Args:
        config_path: Model config, forwarded to
            ``generate_inputs_and_wrap_model``.
        checkpoint_path: Checkpoint loaded with ``torch.load``. If it holds
            an ``'optimizer'`` entry, a copy without it is written to
            ``<checkpoint_path>_slim.pth``, used for model construction and
            removed afterwards.
        input_img: Stored as ``'input_path'`` in the input config.
        input_shape: Stored as ``'input_shape'`` in the input config; also
            given to the simplifier as the input shape when ``dynamic``.
        opset_version: Forwarded to ``torch.onnx.export``.
        show: Accepted for interface compatibility; unused here.
        output_file: Destination path of the final ONNX model.
        verify: Accepted for interface compatibility; unused here.
        simplify: Run ``onnxsim.simplify`` on the exported graph.
        dynamic: Export with dynamic axes on the input and all outputs.
        normalize_cfg: Stored as ``'normalize_cfg'`` in the input config.
        dataset: Accepted for interface compatibility; unused here.
        test_img: Accepted for interface compatibility; unused here.

    Raises:
        AssertionError: If the simplifier reports that the simplified model
            could not be validated.
    """
    input_config = {
        'input_shape': input_shape,
        'input_path': input_img,
        'normalize_cfg': normalize_cfg
    }
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    tmp_ckpt_file = None
    # remove optimizer for smaller file size
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']
        tmp_ckpt_file = checkpoint_path + "_slim.pth"
        torch.save(checkpoint, tmp_ckpt_file)
        print('remove optimizer params and save to', tmp_ckpt_file)
        checkpoint_path = tmp_ckpt_file

    try:
        model, tensor_data = generate_inputs_and_wrap_model(
            config_path, checkpoint_path, input_config)
    finally:
        # Always drop the temporary slim checkpoint, even when model
        # construction fails, so no stray *_slim.pth file is left behind.
        if tmp_ckpt_file is not None:
            os.remove(tmp_ckpt_file)

    needs_postprocess = simplify or dynamic
    if needs_postprocess:
        # splitext only strips the final extension; splitting on the first
        # '.' would truncate paths such as '../onnx/model.onnx'.
        ori_output_file = osp.splitext(output_file)[0] + "_ori.onnx"
    else:
        ori_output_file = output_file

    # Define input and outputs names, which are required to properly define
    # dynamic axes
    input_names = ['input.1']
    output_names = [
        'score_8', 'score_16', 'score_32',
        'bbox_8', 'bbox_16', 'bbox_32',
    ]

    # If model graph contains keypoints strides add keypoints to outputs
    if 'stride_kps' in str(model):
        output_names += ['kps_8', 'kps_16', 'kps_32']

    # Define dynamic axes for export
    dynamic_axes = None
    if dynamic:
        dynamic_axes = {out: {0: '?', 1: '?'} for out in output_names}
        dynamic_axes[input_names[0]] = {
            0: '?',
            2: '?',
            3: '?'
        }

    torch.onnx.export(
        model,
        tensor_data,
        ori_output_file,
        keep_initializers_as_inputs=False,
        verbose=False,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=opset_version)

    if needs_postprocess:
        onnx_model = onnx.load(ori_output_file)
        if simplify:
            # Imported under an alias so the boolean ``simplify`` argument
            # is not rebound to the onnxsim function.
            from onnxsim import simplify as onnx_simplify
            if dynamic:
                input_shapes = {
                    onnx_model.graph.input[0].name: list(input_shape)
                }
                onnx_model, check = onnx_simplify(
                    onnx_model,
                    input_shapes=input_shapes,
                    dynamic_input_shape=True)
            else:
                onnx_model, check = onnx_simplify(onnx_model)
            # Explicit raise (same exception type as before) so the check
            # is not stripped when Python runs with -O.
            if not check:
                raise AssertionError(
                    "Simplified ONNX model could not be validated")
        onnx.save(onnx_model, output_file)
        os.remove(ori_output_file)

    print(f'Successfully exported ONNX model: {output_file}')
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the ONNX conversion script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads the process command line.

    Returns:
        argparse.Namespace with the attributes ``config``, ``checkpoint``,
        ``input_img``, ``show``, ``output_file``, ``opset_version``,
        ``test_img``, ``dataset``, ``verify``, ``shape``, ``mean`` and
        ``std``.
    """
    parser = argparse.ArgumentParser(
        description='Convert MMDetection models to ONNX')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--input-img', type=str, help='Images for input')
    parser.add_argument('--show', action='store_true', help='show onnx graph')
    parser.add_argument('--output-file', type=str, default='')
    parser.add_argument('--opset-version', type=int, default=11)
    parser.add_argument(
        '--test-img', type=str, default=None, help='Images for test')
    parser.add_argument(
        '--dataset', type=str, default='coco', help='Dataset name')
    parser.add_argument(
        '--verify',
        action='store_true',
        help='verify the onnx model output against pytorch output')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        # The entry-point script treats a non-positive size as a request
        # for a dynamic-shape export, so -1 is the default.
        default=[-1, -1],
        help='input image size')
    parser.add_argument(
        '--mean',
        type=float,
        nargs='+',
        default=[127.5, 127.5, 127.5],
        help='mean value used for preprocess input data')
    parser.add_argument(
        '--std',
        type=float,
        nargs='+',
        default=[128.0, 128.0, 128.0],
        help='std value used for preprocess input data')
    return parser.parse_args(argv)
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse options, derive the export shape and output
    # path, then run the conversion.
    args = parse_args()

    # Explicit raises instead of asserts so the checks survive `python -O`.
    if args.opset_version != 11:
        raise ValueError('MMDet only support opset 11 now')

    if not args.input_img:
        args.input_img = osp.join(
            osp.dirname(__file__), '../tests/data/t1.jpg')

    # One value means a square input; two values are used as given.
    if len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (1, 3) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    if len(args.mean) != 3:
        raise ValueError('--mean expects exactly 3 values')
    if len(args.std) != 3:
        raise ValueError('--std expects exactly 3 values')

    simplify = True
    dynamic = False
    # A non-positive size requests a dynamic-shape export; a dummy 640x640
    # shape is then used for the example input.
    if input_shape[2] <= 0 or input_shape[3] <= 0:
        input_shape = (1, 3, 640, 640)
        dynamic = True
        print('set to dynamic input with dummy shape:', input_shape)

    normalize_cfg = {'mean': args.mean, 'std': args.std}

    if not args.output_file:
        output_dir = osp.join(osp.dirname(__file__), '../onnx')
        os.makedirs(output_dir, exist_ok=True)
        # basename/splitext handle any path separator and config names
        # without an extension (rfind('.') == -1 used to drop a character).
        cfg_name = osp.splitext(osp.basename(args.config))[0]
        if dynamic:
            args.output_file = osp.join(output_dir, "%s.onnx" % cfg_name)
        else:
            args.output_file = osp.join(
                output_dir,
                "%s_shape%dx%d.onnx" % (cfg_name, input_shape[2],
                                        input_shape[3]))

    # convert model to onnx file
    pytorch2onnx(
        args.config,
        args.checkpoint,
        args.input_img,
        input_shape,
        opset_version=args.opset_version,
        show=args.show,
        output_file=args.output_file,
        verify=args.verify,
        simplify=simplify,
        dynamic=dynamic,
        normalize_cfg=normalize_cfg,
        dataset=args.dataset,
        test_img=args.test_img)
|