From b129ae9cb338e45debeffae96856b4df4281e8a2 Mon Sep 17 00:00:00 2001 From: 007gzs <007gzs@gmail.com> Date: Wed, 19 May 2021 18:47:30 +0800 Subject: [PATCH] add simplifier in torch2onnx --- recognition/arcface_torch/torch2onnx.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/recognition/arcface_torch/torch2onnx.py b/recognition/arcface_torch/torch2onnx.py index f0bcfac..4e7a68e 100644 --- a/recognition/arcface_torch/torch2onnx.py +++ b/recognition/arcface_torch/torch2onnx.py @@ -3,7 +3,7 @@ import onnx import torch -def convert_onnx(net, path_module, output, opset=11): +def convert_onnx(net, path_module, output, opset=11, simplify=False): assert isinstance(net, torch.nn.Module) img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) img = img.astype(np.float) @@ -18,6 +18,10 @@ def convert_onnx(net, path_module, output, opset=11): model = onnx.load(output) graph = model.graph graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None' + if simplify: + from onnxsim import simplify + model, check = simplify(model) + assert check, "Simplified ONNX model could not be validated" onnx.save(model, output) @@ -30,6 +34,7 @@ if __name__ == '__main__': parser.add_argument('input', type=str, help='input backbone.pth file or path') parser.add_argument('--output', type=str, default=None, help='output onnx path') parser.add_argument('--network', type=str, default=None, help='backbone network') + parser.add_argument('--simplify', type=bool, default=True, help='onnx simplify') args = parser.parse_args() input_file = args.input if os.path.isdir(input_file): @@ -51,4 +56,4 @@ if __name__ == '__main__': os.makedirs(output_path) assert os.path.isdir(output_path) output_file = os.path.join(output_path, "%s.onnx" % model_name) - convert_onnx(backbone_onnx, input_file, output_file) + convert_onnx(backbone_onnx, input_file, output_file, simplify=args.simplify)