From ddf4f885a7bf6f6dd46e129188547fe6169b6858 Mon Sep 17 00:00:00 2001 From: 007gzs <007gzs@gmail.com> Date: Mon, 17 May 2021 10:05:07 +0800 Subject: [PATCH 1/2] add main to convert pytorch model to onnx --- recognition/arcface_torch/torch2onnx.py | 36 +++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/recognition/arcface_torch/torch2onnx.py b/recognition/arcface_torch/torch2onnx.py index 1b7c322..fada84a 100644 --- a/recognition/arcface_torch/torch2onnx.py +++ b/recognition/arcface_torch/torch2onnx.py @@ -19,3 +19,39 @@ def convert_onnx(net, path_module, output, opset=11): graph = model.graph graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None' onnx.save(model, output) + + +if __name__ == '__main__': + import os + import argparse + from backbones import get_model + + parser = argparse.ArgumentParser(description='ArcFace PyTorch to onnx') + parser.add_argument('input', type=str, help='input backbone.pth file or path') + parser.add_argument('--output', type=str, default=None, help='output onnx path') + parser.add_argument('--network', type=str, default=None, help='backbone network') + parser.add_argument('--fp16', type=bool, default=None, help='backbone network') + args = parser.parse_args() + input_file = args.input + if os.path.isdir(input_file): + input_file = os.path.join(input_file, "backbone.pth") + assert os.path.exists(input_file) + model_name = os.path.basename(os.path.dirname(input_file)).lower() + params = model_name.split("_") + if len(params) >= 3 and params[1] in ('arcface', 'cosface'): + if args.network is None: + args.network = params[2] + if args.fp16 is None: + args.fp16 = len(params) > 3 and params[3] == 'fp16' + assert args.network is not None and args.fp16 is not None + print(args) + backbone_onnx = get_model(args.network, dropout=0, fp16=args.fp16) + + output_path = args.output + if output_path is None: + output_path = os.path.join(os.path.dirname(__file__), 'onnx') + if not os.path.exists(output_path): + os.makedirs(output_path) + assert os.path.isdir(output_path) + output_file = os.path.join(output_path, "%s.onnx" % model_name) + convert_onnx(backbone_onnx, input_file, output_file) From f4f91a0dd8b4c6c38da6a216c34088cc902647b6 Mon Sep 17 00:00:00 2001 From: 007gzs <007gzs@gmail.com> Date: Mon, 17 May 2021 13:49:54 +0800 Subject: [PATCH 2/2] remove fp16 param --- recognition/arcface_torch/torch2onnx.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/recognition/arcface_torch/torch2onnx.py b/recognition/arcface_torch/torch2onnx.py index fada84a..f0bcfac 100644 --- a/recognition/arcface_torch/torch2onnx.py +++ b/recognition/arcface_torch/torch2onnx.py @@ -30,7 +30,6 @@ if __name__ == '__main__': parser.add_argument('input', type=str, help='input backbone.pth file or path') parser.add_argument('--output', type=str, default=None, help='output onnx path') parser.add_argument('--network', type=str, default=None, help='backbone network') - parser.add_argument('--fp16', type=bool, default=None, help='backbone network') args = parser.parse_args() input_file = args.input if os.path.isdir(input_file): @@ -41,11 +40,9 @@ if __name__ == '__main__': if len(params) >= 3 and params[1] in ('arcface', 'cosface'): if args.network is None: args.network = params[2] - if args.fp16 is None: - args.fp16 = len(params) > 3 and params[3] == 'fp16' - assert args.network is not None and args.fp16 is not None + assert args.network is not None print(args) - backbone_onnx = get_model(args.network, dropout=0, fp16=args.fp16) + backbone_onnx = get_model(args.network, dropout=0) output_path = args.output if output_path is None: