add simplifier in torch2onnx

This commit is contained in:
007gzs
2021-05-19 18:47:30 +08:00
committed by GitHub
parent f30826a70f
commit b129ae9cb3

View File

@@ -3,7 +3,7 @@ import onnx
import torch
def convert_onnx(net, path_module, output, opset=11):
def convert_onnx(net, path_module, output, opset=11, simplify=False):
assert isinstance(net, torch.nn.Module)
img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
img = img.astype(np.float)
@@ -18,6 +18,10 @@ def convert_onnx(net, path_module, output, opset=11):
model = onnx.load(output)
graph = model.graph
graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None'
if simplify:
from onnxsim import simplify
model, check = simplify(model)
assert check, "Simplified ONNX model could not be validated"
onnx.save(model, output)
@@ -30,6 +34,7 @@ if __name__ == '__main__':
parser.add_argument('input', type=str, help='input backbone.pth file or path')
parser.add_argument('--output', type=str, default=None, help='output onnx path')
parser.add_argument('--network', type=str, default=None, help='backbone network')
parser.add_argument('--simplify', type=bool, default=True, help='onnx simplify')
args = parser.parse_args()
input_file = args.input
if os.path.isdir(input_file):
@@ -51,4 +56,4 @@ if __name__ == '__main__':
os.makedirs(output_path)
assert os.path.isdir(output_path)
output_file = os.path.join(output_path, "%s.onnx" % model_name)
convert_onnx(backbone_onnx, input_file, output_file)
convert_onnx(backbone_onnx, input_file, output_file, simplify=args.simplify)