mirror of
https://github.com/deepinsight/insightface.git
synced 2025-12-30 08:02:27 +00:00
add simplifier in torch2onnx
This commit is contained in:
@@ -3,7 +3,7 @@ import onnx
|
||||
import torch
|
||||
|
||||
|
||||
def convert_onnx(net, path_module, output, opset=11):
|
||||
def convert_onnx(net, path_module, output, opset=11, simplify=False):
|
||||
assert isinstance(net, torch.nn.Module)
|
||||
img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
|
||||
img = img.astype(np.float)
|
||||
@@ -18,6 +18,10 @@ def convert_onnx(net, path_module, output, opset=11):
|
||||
model = onnx.load(output)
|
||||
graph = model.graph
|
||||
graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None'
|
||||
if simplify:
|
||||
from onnxsim import simplify
|
||||
model, check = simplify(model)
|
||||
assert check, "Simplified ONNX model could not be validated"
|
||||
onnx.save(model, output)
|
||||
|
||||
|
||||
@@ -30,6 +34,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('input', type=str, help='input backbone.pth file or path')
|
||||
parser.add_argument('--output', type=str, default=None, help='output onnx path')
|
||||
parser.add_argument('--network', type=str, default=None, help='backbone network')
|
||||
parser.add_argument('--simplify', type=bool, default=True, help='onnx simplify')
|
||||
args = parser.parse_args()
|
||||
input_file = args.input
|
||||
if os.path.isdir(input_file):
|
||||
@@ -51,4 +56,4 @@ if __name__ == '__main__':
|
||||
os.makedirs(output_path)
|
||||
assert os.path.isdir(output_path)
|
||||
output_file = os.path.join(output_path, "%s.onnx" % model_name)
|
||||
convert_onnx(backbone_onnx, input_file, output_file)
|
||||
convert_onnx(backbone_onnx, input_file, output_file, simplify=args.simplify)
|
||||
|
||||
Reference in New Issue
Block a user