From cdeca2032ca47da8ae45928fff9e95db63bb8bbb Mon Sep 17 00:00:00 2001
From: nttstar
Date: Wed, 30 Dec 2020 18:07:50 +0800
Subject: [PATCH] add onnx2caffe tool, for detection models

---
 tools/onnx2caffe/LICENSE                     |  21 +
 tools/onnx2caffe/README.md                   |  37 ++
 tools/onnx2caffe/convertCaffe.py             | 109 ++++
 tools/onnx2caffe/onnx2caffe/__init__.py      |   0
 tools/onnx2caffe/onnx2caffe/_error_utils.py  |  64 +++
 tools/onnx2caffe/onnx2caffe/_graph.py        | 225 ++++++++
 tools/onnx2caffe/onnx2caffe/_operators.py    | 463 +++++++++++++++++
 tools/onnx2caffe/onnx2caffe/_transformers.py | 520 +++++++++++++++++++
 tools/onnx2caffe/onnx2caffe/_weightloader.py | 155 ++++++
 9 files changed, 1594 insertions(+)
 create mode 100644 tools/onnx2caffe/LICENSE
 create mode 100644 tools/onnx2caffe/README.md
 create mode 100644 tools/onnx2caffe/convertCaffe.py
 create mode 100644 tools/onnx2caffe/onnx2caffe/__init__.py
 create mode 100644 tools/onnx2caffe/onnx2caffe/_error_utils.py
 create mode 100644 tools/onnx2caffe/onnx2caffe/_graph.py
 create mode 100644 tools/onnx2caffe/onnx2caffe/_operators.py
 create mode 100644 tools/onnx2caffe/onnx2caffe/_transformers.py
 create mode 100644 tools/onnx2caffe/onnx2caffe/_weightloader.py

diff --git a/tools/onnx2caffe/LICENSE b/tools/onnx2caffe/LICENSE
new file mode 100644
index 0000000..879e87b
--- /dev/null
+++ b/tools/onnx2caffe/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2019 MTlab, Meitu Inc.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.

diff --git a/tools/onnx2caffe/README.md b/tools/onnx2caffe/README.md
new file mode 100644
index 0000000..02b8238
--- /dev/null
+++ b/tools/onnx2caffe/README.md
@@ -0,0 +1,37 @@
+# Convert ONNX to Caffe
+
+This tool is modified from [onnx2caffe](https://github.com/MTlab/onnx2caffe) by MTlab.
+
+We added some OPs to support one-stage mmdetection models.
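+
+Besides the command-line entry point shown below, the converter can also be
+driven from Python. This is only a sketch, reusing the example paths from the
+usage section below:
+
+```
+from convertCaffe import getGraph, convertToCaffe
+
+# getGraph() loads the ONNX file, runs shape inference and the
+# Constant/Conv+Add graph transformers; convertToCaffe() then writes
+# the prototxt and caffemodel files.
+graph = getGraph("./model/mmdet.onnx")
+convertToCaffe(graph, "./model/a.prototxt", "./model/a.caffemodel")
+```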
+
+### Dependencies
+* pycaffe (with built-in Upsample and Permute layers)
+* onnx
+
+
+### How to use
+To convert an ONNX model to Caffe:
+```
+python convertCaffe.py ./model/mmdet.onnx ./model/a.prototxt ./model/a.caffemodel
+```
+
+### Currently supported operations
+* Conv
+* ConvTranspose
+* BatchNormalization
+* MaxPool
+* AveragePool
+* Relu
+* Sigmoid
+* Dropout
+* Gemm (InnerProduct only)
+* Add
+* Mul
+* Reshape
+* Upsample
+* Concat
+* Flatten
+* **Resize**
+* **Permute**
+* **Scale**
+
diff --git a/tools/onnx2caffe/convertCaffe.py b/tools/onnx2caffe/convertCaffe.py
new file mode 100644
index 0000000..76a6111
--- /dev/null
+++ b/tools/onnx2caffe/convertCaffe.py
@@ -0,0 +1,109 @@
+#from __future__ import print_function
+import sys
+import caffe
+import onnx
+import numpy as np
+from caffe.proto import caffe_pb2
+caffe.set_mode_cpu()
+from onnx2caffe._transformers import ConvAddFuser, ConstantsToInitializers
+from onnx2caffe._graph import Graph
+
+import onnx2caffe._operators as cvt
+import onnx2caffe._weightloader as wlr
+from onnx2caffe._error_utils import ErrorHandling
+from collections import OrderedDict
+from onnx import shape_inference
+import importlib
+
+transformers = [
+    ConstantsToInitializers(),
+    ConvAddFuser(),
+]
+
+def convertToCaffe(graph, prototxt_save_path, caffe_model_save_path):
+
+    exist_edges = []
+    layers = []
+    exist_nodes = []
+    err = ErrorHandling()
+    for i in graph.inputs:
+        edge_name = i[0]
+        input_layer = cvt.make_input(i)
+        layers.append(input_layer)
+        exist_edges.append(i[0])
+        graph.channel_dims[edge_name] = graph.shape_dict[edge_name][1]
+
+    for id, node in enumerate(graph.nodes):
+        node_name = node.name
+        op_type = node.op_type
+        inputs = node.inputs
+        inputs_tensor = node.input_tensors
+        input_non_exist_flag = False
+
+        for inp in inputs:
+            if inp not in exist_edges and inp not in inputs_tensor:
+                input_non_exist_flag = True
+                break
+        if input_non_exist_flag:
+            continue
+
+        if op_type not in cvt._ONNX_NODE_REGISTRY:
+            err.unsupported_op(node)
+            continue
+        converter_fn = cvt._ONNX_NODE_REGISTRY[op_type]
+        layer = converter_fn(node, graph, err)
+        if type(layer) == tuple:
+            for l in layer:
+                layers.append(l)
+        else:
+            layers.append(layer)
+        outs = node.outputs
+        for out in outs:
+            exist_edges.append(out)
+
+    net = caffe_pb2.NetParameter()
+    for id, layer in enumerate(layers):
+        layers[id] = layer._to_proto()
+    net.layer.extend(layers)
+
+    with open(prototxt_save_path, 'w') as f:
+        print(net, file=f)
+
+    caffe.set_mode_cpu()
+    deploy = prototxt_save_path
+    net = caffe.Net(deploy,
+                    caffe.TEST)
+
+    for id, node in enumerate(graph.nodes):
+        node_name = node.name
+        op_type = node.op_type
+        inputs = node.inputs
+        inputs_tensor = node.input_tensors
+        input_non_exist_flag = False
+        if op_type not in wlr._ONNX_NODE_REGISTRY:
+            err.unsupported_op(node)
+            continue
+        converter_fn = wlr._ONNX_NODE_REGISTRY[op_type]
+        converter_fn(net, node, graph, err)
+
+    net.save(caffe_model_save_path)
+    return net
+
+def getGraph(onnx_path):
+    model = onnx.load(onnx_path)
+    model = shape_inference.infer_shapes(model)
+    model_graph = model.graph
+    graph = Graph.from_onnx(model_graph)
+    graph = graph.transformed(transformers)
+    graph.channel_dims = {}
+
+    return graph
+
+if __name__ == "__main__":
+    onnx_path = sys.argv[1]
+    prototxt_path = sys.argv[2]
+    caffemodel_path = sys.argv[3]
+    graph = getGraph(onnx_path)
+    convertToCaffe(graph, prototxt_path, caffemodel_path)
+
diff --git a/tools/onnx2caffe/onnx2caffe/__init__.py b/tools/onnx2caffe/onnx2caffe/__init__.py
new file mode 100644
index
0000000..e69de29 diff --git a/tools/onnx2caffe/onnx2caffe/_error_utils.py b/tools/onnx2caffe/onnx2caffe/_error_utils.py new file mode 100644 index 0000000..46f34c0 --- /dev/null +++ b/tools/onnx2caffe/onnx2caffe/_error_utils.py @@ -0,0 +1,64 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from typing import Dict, Text, Any, Callable +from ._graph import Node, Graph + +class ErrorHandling(object): + ''' + To handle errors and addition of custom layers + ''' + + def __init__(self, + add_custom_layers = False, # type: bool + custom_conversion_functions = dict(), # type: Dict[Text, Any] + custom_layer_nodes = [], # type : List[Node] + ): + # type: (...) -> None + self.add_custom_layers = add_custom_layers + self.custom_conversion_functions = custom_conversion_functions + self.custom_layer_nodes = custom_layer_nodes + + + def unsupported_op(self, + node, # type: Node + ): + # type: (...) -> Callable[[Any, Node, Graph, ErrorHandling], None] + ''' + Either raise an error for an unsupported op type or return custom layer add function + ''' + if self.add_custom_layers: + from ._operators import _convert_custom + return _convert_custom + else: + raise TypeError( + "ONNX node of type {} is not supported.\n".format(node.op_type,) + ) + + + def unsupported_op_configuration(self, + node, # type: Node + err_message, # type: Text + ): + raise TypeError( + "Error while converting op of type: {}. Error message: {}\n".format(node.op_type, err_message, ) + ) + + + def missing_initializer(self, + node, # type: Node + err_message, # type: Text + ): + # type: (...) -> None + ''' + Missing initializer error + ''' + raise ValueError( + "Missing initializer error in op of type {}, with input name = {}, " + "output name = {}. Error message: {}\n". + format(node.op_type, node.inputs[0], node.outputs[0], err_message) + ) + + + diff --git a/tools/onnx2caffe/onnx2caffe/_graph.py b/tools/onnx2caffe/onnx2caffe/_graph.py new file mode 100644 index 0000000..89e81db --- /dev/null +++ b/tools/onnx2caffe/onnx2caffe/_graph.py @@ -0,0 +1,225 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from onnx import numpy_helper, ValueInfoProto, AttributeProto, GraphProto, NodeProto, TensorProto, TensorShapeProto +from typing import Any, Text, Iterable, List, Dict, Sequence, Optional, Tuple, Union +from typing_extensions import Protocol +import numpy as np + + +class Transformer(Protocol): + def __call__(self, graph): # type: (Graph) -> Graph + pass + + +EdgeInfo = Tuple[Text, Any, TensorShapeProto] +AttributeValue = Any # TODO Union[Sequence[float], Sequence[int], Sequence[Text], Sequence[TensorProto], Sequence[GraphProto]] + +def _input_from_onnx_input(input): # type: (ValueInfoProto) -> EdgeInfo + name = input.name + type = input.type.tensor_type.elem_type + shape = tuple([d.dim_value for d in input.type.tensor_type.shape.dim]) + return (name, type, shape) + + +def _convertAttributeProto(onnx_arg): # type: (AttributeProto) -> AttributeValue + """ + Convert an ONNX AttributeProto into an appropriate Python object + for the type. 
+ NB: Tensor attribute gets returned as numpy array + """ + if onnx_arg.HasField('f'): + return onnx_arg.f + elif onnx_arg.HasField('i'): + return onnx_arg.i + elif onnx_arg.HasField('s'): + return onnx_arg.s + elif onnx_arg.HasField('t'): + return numpy_helper.to_array(onnx_arg.t) + elif len(onnx_arg.floats): + return list(onnx_arg.floats) + elif len(onnx_arg.ints): + return list(onnx_arg.ints) + elif len(onnx_arg.strings): + return list(onnx_arg.strings) + else: + raise ValueError("Unsupported ONNX attribute: {}".format(onnx_arg)) + + +class Attributes(Dict[Text, Any]): + @staticmethod + def from_onnx(args): # type: (Iterable[AttributeProto]) -> Attributes + d = Attributes() + for arg in args: + d[arg.name] = _convertAttributeProto(arg) + return d + + +class Node(object): + def __init__(self, + name, # type: Optional[Text] + op_type, # type: Text + attrs, # type: Dict[Text, AttributeValue] + inputs, # type: List[Text] + outputs, # type: List[Text] + ): + # type: (...) -> None + self.name = name + self.op_type = op_type + self.attrs = attrs + self.inputs = inputs + self.outputs = outputs + self.input_tensors = {} # type: Dict[Text, np._ArrayLike[Any]] + self.parents = [] # type: List[Node] + self.children = [] # type: List[Node] + self.metadata = {} # type: Dict[Any, Any] + + def add_parent(self, parent_node): # type: (Node) -> None + assert parent_node not in self.parents + self.parents.append(parent_node) + if self not in parent_node.children: + parent_node.children.append(self) + + def add_child(self, child_node): # type: (Node) -> None + assert child_node not in self.children + self.children.append(child_node) + if self not in child_node.parents: + child_node.parents.append(self) + + def get_only_parent(self): # type: () -> Node + if len(self.parents) != 1: + raise ValueError('Node ({}) expected to have 1 parent. Found {}.' + .format(self, len(self.parents))) + return self.parents[0] + + @staticmethod + def from_onnx(node): # type: (NodeProto) -> Node + attrs = Attributes.from_onnx(node.attribute) + name = Text(node.name) + if len(name) == 0: + name = "_".join(node.output) + return Node( + name, node.op_type, attrs, list(node.input), list(node.output) + ) + + +class Graph(object): + def __init__(self, + nodes, # type: List[Node] + inputs, # type: List[EdgeInfo] + outputs, # type: List[EdgeInfo] + shape_dict, # type: Dict[Text,Tuple[int,...]] + ): + # type: (...) 
-> None + self.nodes = nodes + self.inputs = inputs + self.outputs = outputs + self.shape_dict = shape_dict # data blob name to its shape + + # data blob name to the list of op types it feeds into + self.blob_to_op_type = {} # type: Dict[Text, List[Text]] + # data blob name to the op_type that generates it + self.blob_from_op_type = {} # type: Dict[Text, Text] + + for node_ in nodes: + for input_ in node_.inputs: + if input_ in self.blob_to_op_type: + self.blob_to_op_type[input_].append(node_.op_type) + else: + self.blob_to_op_type[input_] = [node_.op_type] + for output_ in node_.outputs: + if output_ in self.blob_from_op_type: + raise ValueError("Data blob: %s, is generated by more than 1 op" %(output_)) + self.blob_from_op_type[output_] = node_.op_type + + + def transformed(self, transformers): # type: (Iterable[Transformer]) -> Graph + graph = self + for transformer in transformers: + graph = transformer(graph) + return graph + + def has_edge_name(self, name): # type: (Text) -> bool + ''' + Check if name is already used for graph inputs/outputs or for nodes + inputs/outputs + ''' + names = set() + for input in self.inputs: + names.add(input[0]) + for output in self.outputs: + names.add(output[0]) + for node in self.nodes: + names.update(node.inputs) + names.update(node.outputs) + return name in names + + def get_unique_edge_name(self, name): # type: (Text) -> Text + n_ = name + i = 0 + while self.has_edge_name(n_): + n_ = "{}_{}".format(name, i) + i += 1 + return n_ + + @staticmethod + def from_onnx(graph): # type: (GraphProto) -> Graph + input_tensors = { + t.name: numpy_helper.to_array(t) for t in graph.initializer + } + nodes_ = [] + nodes_by_input = {} # type: Dict[Text, List[Node]] + nodes_by_output = {} + for node in graph.node: + node_ = Node.from_onnx(node) + for input_ in node_.inputs: + if input_ in input_tensors: + node_.input_tensors[input_] = input_tensors[input_] + else: + if input_ in nodes_by_input: + input_nodes = nodes_by_input[input_] + else: + input_nodes = [] + nodes_by_input[input_] = input_nodes + input_nodes.append(node_) + for output_ in node_.outputs: + nodes_by_output[output_] = node_ + nodes_.append(node_) + + inputs = [] + for i in graph.input: + if i.name not in input_tensors: + inputs.append(_input_from_onnx_input(i)) + + outputs = [] + for o in graph.output: + outputs.append(_input_from_onnx_input(o)) + + for node_ in nodes_: + for input_ in node_.inputs: + if input_ in nodes_by_output: + node_.parents.append(nodes_by_output[input_]) + for output_ in node_.outputs: + if output_ in nodes_by_input: + node_.children.extend(nodes_by_input[output_]) + + # Dictionary to hold the "value_info" field from ONNX graph + shape_dict = {} # type: Dict[Text,Tuple[int,...]] + + def extract_value_info(shape_dict, # type: Dict[Text,Tuple[int,...]] + value_info, # type: ValueInfoProto[...] + ): + # type: (...) 
-> None + shape_dict[value_info.name] = tuple([int(dim.dim_value) for dim in value_info.type.tensor_type.shape.dim]) + + for value_info in graph.value_info: + extract_value_info(shape_dict, value_info) + for value_info in graph.input: + extract_value_info(shape_dict, value_info) + for value_info in graph.output: + extract_value_info(shape_dict, value_info) + + + return Graph(nodes_, inputs, outputs, shape_dict) diff --git a/tools/onnx2caffe/onnx2caffe/_operators.py b/tools/onnx2caffe/onnx2caffe/_operators.py new file mode 100644 index 0000000..ea53b30 --- /dev/null +++ b/tools/onnx2caffe/onnx2caffe/_operators.py @@ -0,0 +1,463 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals +from caffe import params as P +import math +import numpy as np +from ._graph import Node, Graph +from MyCaffe import Function as myf + +def _compare(a, b, encoding="utf8"): #type: (Text, Text, Text) -> bool + if isinstance(a, bytes): + a = a.decode(encoding) + if isinstance(b, bytes): + b = b.decode(encoding) + return a == b + +def make_input(input): + name = input[0] + output = input[0] + output = [output] + shape = input[2] + shape = list(shape) + input_layer = myf("Input", name, [], output, input_param=dict(shape=dict(dim=shape))) + return input_layer + +def _convert_conv(node, graph, err): + weight_name = node.inputs[1] + input_name = str(node.inputs[0]) + output_name = str(node.outputs[0]) + node_name = node.name + W = None + if weight_name in node.input_tensors: + W = node.input_tensors[weight_name] + else: + err.missing_initializer(node, + "Weight tensor: {} not found in the graph initializer".format(weight_name,)) + is_deconv = False + if node.op_type.endswith("Transpose"): + is_deconv = True + bias_flag = False + bias = None + if len(node.inputs) > 2: + bias = node.input_tensors[node.inputs[2]] + bias_flag = True + dilations = node.attrs.get("dilations", [1, 1]) + # groups = 1 + groups = node.attrs.get("group", 1) + kernel_shape = node.attrs["kernel_shape"] + pads = node.attrs.get("pads", [0, 0, 0, 0]) + strides = node.attrs["strides"] + + layer = myf("Convolution", node_name, [input_name], [output_name], + kernel_h = kernel_shape[0],kernel_w = kernel_shape[1], + stride_h=strides[0], stride_w = strides[1], group = groups, + pad_h = pads[0], pad_w = pads[1], + num_output=W.shape[0], dilation = dilations[0], bias_term = bias_flag) + + graph.channel_dims[output_name] = W.shape[0] + return layer + +def _convert_relu(node,graph,err): + input_name = str(node.inputs[0]) + output_name = str(node.outputs[0]) + name = str(node.name) + + if input_name==output_name: + inplace = True + else: + inplace = False + + layer = myf("ReLU",name,[input_name],[output_name],in_place=inplace) + # l_top_relu1 = L.ReLU(l_bottom, name=name, in_place=True) + + graph.channel_dims[output_name] = graph.channel_dims[input_name] + + return layer + +def _convert_sigmoid(node,graph,err): + input_name = str(node.inputs[0]) + output_name = str(node.outputs[0]) + name = str(node.name) + + if input_name==output_name: + inplace = True + else: + inplace = False + + layer = myf("Sigmoid",name,[input_name],[output_name],in_place=inplace) + # l_top_relu1 = L.ReLU(l_bottom, name=name, in_place=True) + + graph.channel_dims[output_name] = graph.channel_dims[input_name] + + return layer + +def _convert_BatchNorm(node,graph,err): + epsilon = node.attrs.get("epsilon", 1e-5) + scale = node.input_tensors[node.inputs[1]] + bias = 
node.input_tensors[node.inputs[2]] + mean = node.input_tensors[node.inputs[3]] + var = node.input_tensors[node.inputs[4]] + node_name = node.name + + input_name = str(node.inputs[0]) + output_name = str(node.outputs[0]) + + if input_name==output_name: + inplace = True + else: + inplace = False + + bn_layer = myf("BatchNorm", node_name+"_bn",[input_name],[output_name],eps = epsilon, use_global_stats = True, in_place=inplace) + scale_layer = myf("Scale", node_name, [output_name],[output_name],in_place=True,bias_term=True) + + graph.channel_dims[output_name] = graph.channel_dims[input_name] + + return bn_layer,scale_layer + +def _convert_Add(node,graph,err): + input_name_list = [str(i) for i in node.inputs] + output_name = str(node.outputs[0]) + node_name = node.name + + max_dim = 0 + for name in input_name_list: + if graph.channel_dims[name]>max_dim: + max_dim = graph.channel_dims[name] + + if 'broadcast' in node.attrs: + if node.attrs['broadcast'] == 1: + input_node_number = len(input_name_list) + if input_node_number !=2: + return err.unsupported_op_configuration(node, "Broadcast Add must has 2 input, not {}".format(input_node_number)) + axis = node.attrs['axis'] + flat_layer = myf("Flatten",node_name+'_flat',[input_name_list[1]],[output_name+'_flat']) + layer = myf("Bias", node_name, [input_name_list[0],output_name+'_flat'], [output_name], axis = axis) + # layer = myf("Bias", node_name, input_name_list, [output_name], bias_term = False, axis = axis) + graph.channel_dims[output_name] = graph.channel_dims[input_name_list[0]] + return flat_layer,layer + + layer = myf("Eltwise",node_name,input_name_list,[output_name],operation=P.Eltwise.SUM) + graph.channel_dims[output_name] = graph.channel_dims[input_name_list[0]] + return layer + +def _convert_Mul(node,graph,err): + input_name_list = [str(i) for i in node.inputs] + output_name = str(node.outputs[0]) + node_name = node.name + print('Mul:', node.name, node.attrs, input_name_list, output_name) + if len(node.attrs)==0: + assert len(node.input_tensors)==1 + assert len(input_name_list)==2 + inp_tensor = node.input_tensors[input_name_list[1]] + scale_value = float(inp_tensor) + print(scale_value) + layer = myf("Scale", node_name, [input_name_list[0]], [output_name], bias_term = False, + scale_param = dict(filler = dict(value=scale_value), bias_term=False)) + return layer + #layer = myf("Reshape", node_name, [input_name], [output_name], reshape_param = dict(shape=dict(dim=list(shape)))) + #print(len(node.input_tensors)) + + # max_dim = 0 + # for name in input_name_list: + # if graph.channel_dims[name]>max_dim: + # max_dim = graph.channel_dims[name] + + if 'broadcast' in node.attrs: + if node.attrs['broadcast'] == 1: + input_node_number = len(input_name_list) + if input_node_number !=2: + return err.unsupported_op_configuration(node, "Broadcast Mul must has 2 input, not {}".format(input_node_number)) + axis = node.attrs['axis'] + flat_layer = myf("Flatten",node_name+'_flat',[input_name_list[1]],[output_name+'_flat']) + layer = myf("Scale", node_name, [input_name_list[0],output_name+'_flat'], [output_name], bias_term = False, axis = axis) + graph.channel_dims[output_name] = graph.channel_dims[input_name_list[0]] + return flat_layer,layer + + layer = myf("Eltwise",node_name,input_name_list,[output_name],operation=P.Eltwise.PROD) + graph.channel_dims[output_name] = graph.channel_dims[input_name_list[0]] + return layer + +def _convert_Reshape(node,graph,err): + node_name = node.name + input_name = str(node.inputs[0]) + output_name = str(node.outputs[0]) 
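+    # The target shape of an ONNX Reshape comes either from a 'shape'
+    # attribute (older opsets) or from a constant second input (opset >= 5);
+    # both cases are handled below before emitting a Caffe Reshape layer.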
+ if len(node.inputs)==1: + shape = tuple(node.attrs.get('shape', ())) + else: + shape = tuple(node.input_tensors[node.inputs[1]]) + # if shape == (): + + #print('reshape to', shape) + + if input_name==output_name: + inplace = True + else: + inplace = False + + graph.channel_dims[output_name] = shape[1] + layer = myf("Reshape", node_name, [input_name], [output_name], reshape_param = dict(shape=dict(dim=list(shape)))) + return layer + + #if len(shape) == 2: + # layer = myf("Flatten",node_name,[input_name],[output_name],in_place=inplace) + # graph.channel_dims[output_name] = shape[1] + # return layer + #elif len(shape) == 4: + # graph.channel_dims[output_name] = shape[1] + # layer = myf("Reshape", node_name, [input_name], [output_name], reshape_param = dict(shape=dict(dim=list(shape)))) + # return layer + #else: + # return err.unsupported_op_configuration(node, "Reshape dimention number shall be 2 or 4") + +def _convert_Flatten(node,graph,err): + node_name = node.name + input_name = str(node.inputs[0]) + output_name = str(node.outputs[0]) + # shape = tuple(node.attrs.get('shape', ())) + if input_name==output_name: + inplace = True + else: + inplace = False + layer = myf("Flatten", node_name, [input_name], [output_name], in_place=inplace) + # graph.channel_dims[output_name] = shape[1] + return layer + +def _convert_pool(node,graph,err): + node_name = node.name + input_name = str(node.inputs[0]) + output_name = str(node.outputs[0]) + if node.op_type.endswith("MaxPool"): + pool_type = P.Pooling.MAX + elif node.op_type.endswith("AveragePool"): + pool_type = P.Pooling.AVE + else: + return err.unsupported_op_configuration(node, "Unsupported pool type") + + kernel_shape = node.attrs["kernel_shape"] + strides = node.attrs.get('strides', [1, 1]) + pads = node.attrs.get('pads', [0, 0, 0, 0]) + + layer = myf("Pooling",node_name,[input_name],[output_name],pooling_param = dict(pool = pool_type, + kernel_h = kernel_shape[0], + kernel_w = kernel_shape[1], + stride_h = strides[0], + stride_w = strides[1], + pad_h = pads[0], + pad_w = pads[1])) + graph.channel_dims[output_name] = graph.channel_dims[input_name] + return layer + +def _convert_dropout(node,graph,err): + node_name = node.name + input_name = str(node.inputs[0]) + output_name = str(node.outputs[0]) + ratio = node.attrs.get('ratio', 0.5) + layer = myf("Dropout", node_name, [input_name], [output_name], dropout_ratio =ratio) + graph.channel_dims[output_name] = graph.channel_dims[input_name] + return layer + +def _convert_gemm(node,graph,err): + node_name = node.name + input_name = str(node.inputs[0]) + output_name = str(node.outputs[0]) + weight_name = node.inputs[1] + if weight_name in node.input_tensors: + W = node.input_tensors[weight_name] + else: + err.missing_initializer(node, + "Weight tensor: {} not found in the graph initializer".format(weight_name, )) + return + + if node.attrs["broadcast"] != 1 or node.attrs["transB"] != 1: + return err.unsupported_op_configuration(node,"Gemm is supported only for inner_product layer") + + b = None + bias_flag = False + if len(node.inputs) > 2: + b = node.input_tensors[node.inputs[2]] + + if len(W.shape) != 2 or (b is not None and len(b.shape) != 1): + return err.unsupported_op_configuration(node, "Gemm is supported only for inner_product layer") + if b is not None: + bias_flag = True + if W.shape[0] != b.shape[0]: + return err.unsupported_op_configuration(node, + "Gemm is supported only for inner_product layer") + + layer = myf("InnerProduct",node_name,[input_name],[output_name],num_output = 
W.shape[0],bias_term = bias_flag) + graph.channel_dims[output_name] = W.shape[0] + + return layer + +def _convert_upsample(node,graph,err): + factor = int(node.attrs["height_scale"]) + node_name = node.name + input_name = str(node.inputs[0]) + output_name = str(node.outputs[0]) + # input_shape = graph.shape_dict[input_name] + # channels = input_shape[1] + channels = graph.channel_dims[input_name] + pad = int(math.ceil((factor - 1) / 2.)) + # layer = myf("Deconvolution", node_name, [input_name], [output_name], + # kernel_size=2 * factor - factor % 2, + # stride=factor, group=channels, + # pad = pad, num_output=channels, bias_term = False) + mode = node.attrs["mode"] + #https://github.com/pytorch/pytorch/issues/6900 + if mode=="bilinear": + layer = myf("Deconvolution", node_name, [input_name], [output_name], + convolution_param=dict( + num_output=channels, + kernel_size=2 * factor - factor % 2, + stride=factor, + pad=pad, + group=channels, + bias_term=False, + weight_filler=dict(type="bilinear_upsampling") + )) + else: + layer = myf("Deconvolution", node_name, [input_name], [output_name], + convolution_param=dict( + num_output=channels, + kernel_size=factor, + stride=factor, + group=channels, + bias_term=False, + )) + + graph.channel_dims[output_name] = graph.channel_dims[input_name] + return layer + +def _convert_resize(node,graph,err): + #print(node, graph) + node_name = node.name + input_name = str(node.inputs[0]) + output_name = str(node.outputs[0]) + #print(node.attrs, node_name, input_name, output_name) + layer = myf("Upsample", node_name, [input_name], [output_name], + upsample_param=dict( + scale = 2 + )) + + graph.channel_dims[output_name] = graph.channel_dims[input_name] + return layer + +def _convert_transpose(node,graph,err): + #print(node, graph) + node_name = node.name + input_name = str(node.inputs[0]) + output_name = str(node.outputs[0]) + #print(node.attrs, node_name, input_name, output_name) + layer = myf("Permute", node_name, [input_name], [output_name], + permute_param=dict( + order = node.attrs['perm'] + )) + + graph.channel_dims[output_name] = graph.channel_dims[input_name] + return layer + +def _convert_softmax(node,graph,err): + #print(node, graph) + node_name = node.name + input_name = str(node.inputs[0]) + output_name = str(node.outputs[0]) + #print(node.attrs, node_name, input_name, output_name) + layer = myf("Softmax", node_name, [input_name], [output_name], + softmax_param=dict( + axis = node.attrs['axis'] + )) + + graph.channel_dims[output_name] = graph.channel_dims[input_name] + return layer + +def _convert_concat(node,graph,err): + node_name = node.name + input_name_list = [str(i) for i in node.inputs] + output_name = str(node.outputs[0]) + axis = node.attrs.get("axis", 1) + + layer = myf('Concat', node_name, input_name_list, [output_name], axis = axis) + if axis == 1: + dim = 0 + for name in input_name_list: + dim+=graph.channel_dims[name] + graph.channel_dims[output_name] = dim + else: + graph.channel_dims[output_name] = graph.channel_dims[input_name_list[0]] + + return layer + +def _convert_conv_transpose(node,graph,err): + input_name = str(node.inputs[0]) + output_name = str(node.outputs[0]) + node_name = node.name + weight_name = node.inputs[1] + W = None + if weight_name in node.input_tensors: + W = node.input_tensors[weight_name] + else: + err.missing_initializer(node, + "Weight tensor: {} not found in the graph initializer".format(weight_name,)) + bias_flag = False + bias = None + if len(node.inputs) > 2: + bias = node.input_tensors[node.inputs[2]] + 
bias_flag = True + dilations = node.attrs.get("dilations", [1, 1]) + # groups = 1 + groups = node.attrs.get("group", 1) + kernel_shape = node.attrs["kernel_shape"] + pads = node.attrs.get("pads", [0, 0, 0, 0]) + strides = node.attrs["strides"] + + layer = myf('Deconvolution', node_name, [input_name], [output_name], + convolution_param=dict( + num_output=W.shape[1], + kernel_h=kernel_shape[0],kernel_w=kernel_shape[1], + stride_h=strides[0],stride_w = strides[1], + group=groups, + pad_h=pads[0], pad_w=pads[1], + bias_term=bias_flag, + )) + + graph.channel_dims[output_name] = W.shape[1] + return layer + + # l_top = L.Deconvolution( + # l_bottom, + # name=name, + # convolution_param=dict( + # num_output=W.shape[1], + # kernel_h=kernel_h, + # kernel_w=kernel_w, + # stride_h=stride_h, + # stride_w=stride_w, + # pad_h=pad_h, + # pad_w=pad_w, + # group=groups, + # bias_term=bias_term)) + + + +_ONNX_NODE_REGISTRY = { + "Conv": _convert_conv, + "Relu": _convert_relu, + "BatchNormalization": _convert_BatchNorm, + "Add": _convert_Add, + "Mul": _convert_Mul, + "Reshape": _convert_Reshape, + "MaxPool": _convert_pool, + "AveragePool": _convert_pool, + "Dropout": _convert_dropout, + "Gemm": _convert_gemm, + "Upsample": _convert_upsample, + "Concat": _convert_concat, + "ConvTranspose": _convert_conv_transpose, + "Sigmoid": _convert_sigmoid, + "Flatten": _convert_Flatten, + "Resize": _convert_resize, + "Transpose": _convert_transpose, + "Softmax": _convert_softmax, +} diff --git a/tools/onnx2caffe/onnx2caffe/_transformers.py b/tools/onnx2caffe/onnx2caffe/_transformers.py new file mode 100644 index 0000000..4f8c815 --- /dev/null +++ b/tools/onnx2caffe/onnx2caffe/_transformers.py @@ -0,0 +1,520 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from typing import Sequence, Text, Dict, List +import numpy as np + +from onnx import TensorProto + +from ._graph import Graph, Node + + +class NodesFuser(object): + ''' + An abstract helper for merging nodes + ''' + def __init__(self, + num_nodes, # type: int + ): + # type: (...) -> None + assert num_nodes >= 2, "Algorithm only works if fusing multiple nodes" + self.num_nodes = num_nodes + + def __call__(self, graph): # type: (Graph) -> Graph + nodes = graph.nodes + merged_nodes = {} + for node in nodes: + nodes_window = [] # type: List[Node] + n = node + for _ in range(self.num_nodes - 1): + if len(n.parents) != 1: + # We're only fusing nodes with single parents + break + p = n.get_only_parent() + if len(p.children) != 1: + # We can only fuse a node if its parent's + # value isn't used by any other node. 
+ break + nodes_window.insert(0, n) + n = p + if len(nodes_window) > 0: + # add parent of chained nodes + first = nodes_window[0] + p = first.get_only_parent() + if len(p.children) == 1: + nodes_window.insert(0, p) + if len(nodes_window) != self.num_nodes: + continue + if not self.is_eligible(graph, nodes_window): + continue + merged = self.merge(graph, nodes_window) + first, last = nodes_window[0], nodes_window[-1] + for parent in first.parents: + parent.children.remove(first) + if merged[0] not in parent.children: + parent.add_child(merged[0]) + for child in last.children: + child.parents.remove(last) + if merged[-1] not in child.parents: + child.add_parent(merged[-1]) + for n in nodes_window: + merged_nodes[n.name] = merged + + transformed_nodes = [] + added_merged = [] # type: List[Node] + for node in nodes: + if node.name in merged_nodes: + merged = merged_nodes[node.name] + if merged[0] not in added_merged: + for n in merged: + transformed_nodes.append(n) + added_merged.append(merged[0]) + else: + transformed_nodes.append(node) + return Graph(transformed_nodes, graph.inputs, graph.outputs, graph.shape_dict) + + def is_eligible(self, graph, nodes): # type: (Graph, Sequence[Node]) -> bool + '''Returns true if this subset of nodes is eligible for fusion.''' + raise NotImplementedError('Must be implemented by subclass.') + + def merge(self, graph, nodes): # type: (Graph, Sequence[Node]) -> Sequence[Node] + '''Merge nodes''' + nodes[0].outputs = nodes[-1].outputs + return [nodes[0]] + + +class ConvAddFuser(NodesFuser): + ''' + Fuses Add layer into parent convolution layer. + ''' + def __init__(self): # type: () -> None + super(ConvAddFuser, self).__init__(2) + + def is_eligible(self, graph, nodes): # type: (Graph, Sequence[Node]) -> bool + parent, child = nodes[0], nodes[1] + if parent.op_type != 'Conv': + return False + if child.op_type != 'Add': + return False + if 'broadcast' not in child.attrs: + return False + if 'axis' not in child.attrs: + return False + if parent.inputs[1] not in parent.input_tensors: + return False + if len(parent.inputs) > 2 and parent.inputs[2] not in parent.input_tensors: + return False + if child.inputs[1] not in child.input_tensors: + return False + + broadcast = child.attrs['broadcast'] + if broadcast != 1: + return False + + axis = child.attrs['axis'] + if axis != 1: + return False + + return True + + def merge(self, graph, nodes): # type: (Graph, Sequence[Node]) -> Sequence[Node] + parent, child = nodes[0], nodes[1] + output_channels = parent.input_tensors[parent.inputs[1]].shape[0] + if len(parent.inputs) > 2: + bias_input_name = parent.inputs[2] + bias = parent.input_tensors[bias_input_name] + else: + bias_input_name = "{}_bias".format(parent.name,) + parent.inputs.append(bias_input_name) + bias = np.zeros( + (output_channels,), dtype=np.float32 + ) + parent.input_tensors[bias_input_name] = bias + bias = bias + child.input_tensors[child.inputs[1]] + parent.input_tensors[bias_input_name] = bias + parent.outputs = child.outputs + parent.children.remove(child) + child.parents.remove(parent) + return [parent] + + +class BNBroadcastedMulFuser(NodesFuser): + ''' + Fuses Mul into BatchNorm + ''' + def __init__(self): # type: () -> None + super(BNBroadcastedMulFuser, self).__init__(2) + + def is_eligible(self, graph, nodes): # type: (Graph, Sequence[Node]) -> bool + parent, child = nodes[0], nodes[1] + if parent.op_type != 'BatchNormalization': + return False + if child.op_type != 'Mul': + return False + if "broadcast" not in child.attrs: + return False + if 
child.attrs["broadcast"] != 1: + return False + if "axis" not in child.attrs: + return False + if child.attrs["axis"] != 1: + return False + if child.inputs[1] not in child.input_tensors: + return False + if parent.inputs[1] not in parent.input_tensors: + return False + if parent.inputs[2] not in parent.input_tensors: + return False + return True + + def merge(self, graph, nodes): # type: (Graph, Sequence[Node]) -> Sequence[Node] + parent, child = nodes[0], nodes[1] + weight = parent.input_tensors[parent.inputs[1]] + bias = parent.input_tensors[parent.inputs[2]] + W = child.input_tensors[child.inputs[1]] + parent.input_tensors[parent.inputs[1]] = np.multiply(weight, W) + parent.input_tensors[parent.inputs[2]] = np.multiply(bias, W) + parent.outputs = child.outputs + parent.children.remove(child) + child.parents.remove(parent) + return [parent] + + +class BNBroadcastedAddFuser(NodesFuser): + ''' + Fuses Add into BatchNorm + ''' + def __init__(self): # type: () -> None + super(BNBroadcastedAddFuser, self).__init__(2) + + def is_eligible(self, graph, nodes): # type: (Graph, Sequence[Node]) -> bool + parent, child = nodes[0], nodes[1] + if parent.op_type != 'BatchNormalization': + return False + if child.op_type != 'Add': + return False + if "broadcast" not in child.attrs: + return False + if child.attrs["broadcast"] != 1: + return False + if "axis" not in child.attrs: + return False + if child.attrs["axis"] != 1: + return False + if len(child.inputs) != 2: + return False + if child.inputs[1] not in child.input_tensors: + return False + if parent.inputs[2] not in parent.input_tensors: + return False + return True + + def merge(self, graph, nodes): # type: (Graph, Sequence[Node]) -> Sequence[Node] + parent, child = nodes[0], nodes[1] + bias = parent.input_tensors[parent.inputs[2]] + b = child.input_tensors[child.inputs[1]] + parent.input_tensors[parent.inputs[2]] = bias + b + parent.outputs = child.outputs + parent.children.remove(child) + child.parents.remove(parent) + return [parent] + + +class DropoutRemover(NodesFuser): + ''' + Removes Dropout layer + ''' + def __init__(self): # type: () -> None + super(DropoutRemover, self).__init__(2) + + def is_eligible(self, graph, nodes): # type: (Graph, Sequence[Node]) -> bool + child = nodes[1] + return child.op_type == "Dropout" + + def merge(self, graph, nodes): # type: (Graph, Sequence[Node]) -> Sequence[Node] + parent, child = nodes[0], nodes[1] + parent.children.remove(child) + child.parents.remove(parent) + parent.outputs = child.outputs + return [parent] + + +class ReshapeInitTensorFuser(object): + ''' + Fuses Reshape operator if it is used only to reshape blob in + graph initializer. We can reshape here instead of runtime. 
+ ''' + + def __call__(self, graph): # type: (Graph) -> Graph + nodes = graph.nodes + removed = [] + for node in nodes: + if node.op_type != 'Reshape': + continue + if not (len(node.input_tensors) == 2 or len(node.input_tensors) == 1): + continue + tensor_name = node.inputs[0] + if tensor_name not in node.input_tensors: + continue + if len(node.inputs) > 1: + shape_name = node.inputs[1] + if shape_name not in node.input_tensors: + continue + is_non_constant_parent = False + if len(node.parents) > 0: + for parent in node.parents: + if parent.op_type != 'Constant': + is_non_constant_parent = True + break + if is_non_constant_parent: + continue + + removed.append(node) + output_name = node.outputs[0] + + tensor = node.input_tensors[tensor_name] + if 'shape' in node.attrs: + shape = tuple(node.attrs["shape"]) + else: + shape = node.input_tensors[shape_name] # type: ignore + + # ONNX spec supports setting dimension to '0', in which case + # it should be taken from old dimension. + # This isn't supported in numpy, so don't transform. + # TODO Should we support this case? + if any([s == 0 for s in shape]): + continue + + reshaped_tensor = tensor.reshape(shape) + + for child in node.children: + child.parents.remove(node) + child.input_tensors[output_name] = reshaped_tensor + + transformed_nodes = [node for node in nodes if node not in removed] + return Graph(transformed_nodes, graph.inputs, graph.outputs, graph.shape_dict) + + +class OutputRenamer(object): + ''' + Rename outputs according to mapping + ''' + def __init__(self, + mapping, # type: Dict[Text, Text] + ): + # type: (...) -> None + self.mapping = mapping + + def __call__(self, graph): # type: (Graph) -> Graph + mapping = self.mapping.copy() + nodes = graph.nodes + for node in nodes: + for i in range(len(node.outputs)): + output = node.outputs[i] + if output not in mapping: + continue + node.outputs[i] = mapping[output] + for child in node.children: + for j in range(len(child.inputs)): + input_ = child.inputs[j] + if input_ != output: + continue + child.inputs[j] = mapping[output] + del mapping[output] + if len(mapping) == 0: + break + return graph + + +class PixelShuffleFuser(NodesFuser): + ''' + Fuses 3 operators reshape->transpose->reshape which is equivalent to + pytorch's pixel_shuffle layer + ''' + def __init__(self): # type: () -> None + super(PixelShuffleFuser, self).__init__(3) + self.num_added = 0 + + def is_eligible(self, graph, nodes): # type: (Graph, Sequence[Node]) -> bool + if nodes[0].op_type != 'Reshape': + return False + if nodes[1].op_type != 'Transpose': + return False + if nodes[2].op_type != 'Reshape': + return False + if nodes[0].inputs[1] not in nodes[0].input_tensors: + return False + if nodes[2].inputs[1] not in nodes[2].input_tensors: + return False + + shape = nodes[0].input_tensors[nodes[0].inputs[1]] + if len(shape) != 6: + return False + if shape[0] != 1 or shape[2] != shape[3]: + return False + + input_channels = shape[1] + scale_factor = shape[2] + input_height = shape[4] + input_width = shape[5] + + if nodes[1].attrs.get('perm', []) != [0, 1, 4, 2, 5, 3]: + return False + + shape = nodes[2].input_tensors[nodes[2].inputs[1]] + if len(shape) != 4: + return False + + output_channels = shape[1] + output_height = shape[2] + output_width = shape[3] + if input_channels != output_channels: + return False + if (input_height * scale_factor) != output_height: + return False + if (input_width * scale_factor) != output_width: + return False + + return True + + def get_unique_edge_name(self, graph, name): # type: 
(Graph, Text) -> Text + self.num_added += 1 + return graph.get_unique_edge_name(name + '_' + str(self.num_added)) + + def merge(self, graph, nodes): # type: (Graph, Sequence[Node]) -> Sequence[Node] + ''' + Pixel shuffle is implemented using 3 operators: + - Reshape(1, channels, scale, scale, height, width) + - Transpose(0, 1, 4, 2, 5, 3) + - Reshape(1, channels, height * scale, width * scale) + CoreML Reshape and Transpose layers don't support tensors with more + than 4 dimensions. Thus we change above sequence of operators to the + following equivalent sequence: + - Reshape(channels, scale * scale, height, width) + - Transpose(0, 2, 1, 3) + - Reshape(channels * height, scale, scale, width) + - Transpose(0, 1, 3, 2) + - Reshape(1, channels, height * scale, width * scale) + ''' + reshape_1 = nodes[0] + transpose_1 = nodes[1] + transpose_1.children = [] + + shape = reshape_1.input_tensors[reshape_1.inputs[1]] + + channels = shape[1] + scale = shape[2] + height = shape[4] + width = shape[5] + + reshape_1.input_tensors[reshape_1.inputs[1]] = np.asarray([channels, scale * scale, height, width]) + transpose_1.attrs['perm'] = [0, 2, 1, 3] + + reshape_output_name = 'pixel_shuffle_reshape' + transpose_output_name = 'pixel_shuffle_transpose' + + transpose_1.outputs = [ + self.get_unique_edge_name(graph, transpose_output_name) + ] + + shape_name_second_reshape = self.get_unique_edge_name(graph, reshape_output_name) + output_name_second_reshape = self.get_unique_edge_name(graph, reshape_output_name) + reshape_2 = Node( + reshape_output_name, + 'Reshape', + {}, + [transpose_1.outputs[0], shape_name_second_reshape], + [output_name_second_reshape] + ) + reshape_2.input_tensors[shape_name_second_reshape] = np.asarray([channels * height, scale, scale, width]) + transpose_1.add_child(reshape_2) + + transpose_2 = Node( + transpose_output_name, + 'Transpose', + {'perm': [0, 1, 3, 2]}, + reshape_2.outputs, + [self.get_unique_edge_name(graph, transpose_output_name)] + ) + reshape_2.add_child(transpose_2) + + final_reshape = nodes[2] + final_reshape.inputs = [transpose_2.outputs[0], nodes[2].inputs[1]] + final_reshape.parents = [] + transpose_2.add_child(final_reshape) + return [reshape_1, transpose_1, reshape_2, transpose_2, final_reshape] + + +class AddModelInputsOutputs(object): + ''' + Expose hidden states of recurrent layers as model inputs and outputs + ''' + def __call__(self, graph): # type: (Graph) -> Graph + input_names = [str(input_[0]) for input_ in graph.inputs] + output_names = [str(output_[0]) for output_ in graph.outputs] + for node in graph.nodes: + if str(node.op_type) == 'LSTM': + input_h = node.inputs[5] if len(node.inputs) > 5 else node.inputs[0] + '_h_input' + input_c = node.inputs[6] if len(node.inputs) > 6 else node.inputs[0] + '_c_input' + output_h = node.outputs[1] if len(node.outputs) > 1 else node.outputs[0] + '_h_output' + output_c = node.outputs[2] if len(node.outputs) > 2 else node.outputs[0] + '_c_output' + h = node.attrs["hidden_size"] + for input_ in [str(input_h), str(input_c)]: + if input_ not in input_names: + graph.inputs.append(tuple((input_, TensorProto.FLOAT, (h,)))) #type: ignore + if input_ not in graph.blob_to_op_type: + graph.blob_to_op_type[input_] = ['LSTM'] + for output_ in [str(output_h), str(output_c)]: + if output_ not in output_names: + graph.outputs.append(tuple((output_, TensorProto.FLOAT, (h,)))) #type: ignore + graph.blob_from_op_type[output_] = 'LSTM' + return graph + + +class ConstantsToInitializers(object): + ''' + Takes onnx Constant nodes and puts the 
tensor into graph initializers instead. + ''' + def __call__(self, graph): # type: (Graph) -> Graph + output_names = [str(output_[0]) for output_ in graph.outputs] + remaining_nodes = [] + for node in graph.nodes: + if node.op_type != 'Constant' or node.name in output_names: + remaining_nodes.append(node) + continue + for child in node.children: + child.input_tensors[node.outputs[0]] = node.attrs["value"] + + graph.nodes = remaining_nodes + return graph + + +class ImageScalerRemover(object): + ''' + Removes ImageScaler layer if connected to a model input and single parent child nodes + ''' + + def __call__(self, graph): # type: (Graph) -> Graph + input_names = [str(input_[0]) for input_ in graph.inputs] + nodes_to_be_removed = [] + for node in graph.nodes: + if (node.op_type != 'ImageScaler') or (len(node.parents) != 0) or (node.inputs[0] not in input_names): + continue + is_eligible = True + for child in node.children: + if not (len(child.parents) == 1 and child.inputs[0] == node.outputs[0]): + is_eligible = False + break + child.inputs[0] = node.inputs[0] + child.parents = [] + if not is_eligible: + continue + nodes_to_be_removed.append(node.name) + + transformed_nodes = [] + for node in graph.nodes: + if node.name not in nodes_to_be_removed: + transformed_nodes.append(node) + return Graph(transformed_nodes, graph.inputs, graph.outputs, graph.shape_dict) \ No newline at end of file diff --git a/tools/onnx2caffe/onnx2caffe/_weightloader.py b/tools/onnx2caffe/onnx2caffe/_weightloader.py new file mode 100644 index 0000000..206cffc --- /dev/null +++ b/tools/onnx2caffe/onnx2caffe/_weightloader.py @@ -0,0 +1,155 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals +# from caffe import params as P +import numpy as np +from ._graph import Node, Graph + +def _convert_conv(net, node, graph, err): + weight_name = node.inputs[1] + input_name = str(node.inputs[0]) + output_name = str(node.outputs[0]) + node_name = node.name + W = None + if weight_name in node.input_tensors: + W = node.input_tensors[weight_name] + else: + err.missing_initializer(node, + "Weight tensor: {} not found in the graph initializer".format(weight_name,)) + bias_flag = False + bias = None + if len(node.inputs) > 2: + bias = node.input_tensors[node.inputs[2]] + bias_flag = True + # net.params[node_name][0].data = W + # if bias_flag: + # net.params[node_name][1].data = bias + np.copyto(net.params[node_name][0].data,W,casting='same_kind') + if bias_flag: + np.copyto(net.params[node_name][1].data, bias, casting='same_kind') + +def _convert_relu(net, node, graph, err): + pass + +def _convert_sigmoid(net, node, graph, err): + pass + +def _convert_BatchNorm(net, node, graph, err): + scale = node.input_tensors[node.inputs[1]] + bias = node.input_tensors[node.inputs[2]] + mean = node.input_tensors[node.inputs[3]] + var = node.input_tensors[node.inputs[4]] + node_name = node.name + np.copyto(net.params[node_name + '_bn'][0].data, mean, casting='same_kind') + np.copyto(net.params[node_name + '_bn'][1].data, var, casting='same_kind') + net.params[node_name + '_bn'][2].data[...] 
= 1.0 + np.copyto(net.params[node_name][0].data, scale, casting='same_kind') + np.copyto(net.params[node_name][1].data, bias, casting='same_kind') + # net.params[node_name+'_bn'][1].data = var + # net.params[node_name][0].data = scale + # net.params[node_name][1].data = bias + +def _convert_Add(net, node, graph, err): + pass + +def _convert_Mul(net, node, graph, err): + pass + +def _convert_Reshape(net, node, graph, err): + pass + +def _convert_Flatten(net, node, graph, err): + pass + +def _convert_pool(net, node, graph, err): + pass + +def _convert_dropout(net, node, graph, err): + pass + +def _convert_gemm(net, node, graph, err): + node_name = node.name + weight_name = node.inputs[1] + if weight_name in node.input_tensors: + W = node.input_tensors[weight_name] + else: + err.missing_initializer(node, + "Weight tensor: {} not found in the graph initializer".format(weight_name, )) + if node.attrs["broadcast"] != 1 or node.attrs["transB"] != 1: + return err.unsupported_op_configuration(node, "Gemm is supported only for inner_product layer") + b = None + if len(node.inputs) > 2: + b = node.input_tensors[node.inputs[2]] + if len(W.shape) != 2 or (b is not None and len(b.shape) != 1): + return err.unsupported_op_configuration(node, "Gemm is supported only for inner_product layer") + if b is not None: + if W.shape[0] != b.shape[0]: + return err.unsupported_op_configuration(node, "Gemm is supported only for inner_product layer") + net.params[node_name][0].data[...] = W + net.params[node_name][1].data[...] = b + +def _convert_upsample(net, node, graph, err): + mode = node.attrs["mode"] + node_name = node.name + if mode == "nearest": + caffe_params = net.params[node_name][0].data + weights = np.ones(caffe_params.shape).astype("float32") + np.copyto(net.params[node_name][0].data, weights, casting='same_kind') + # net.params[node_name][0].data[] + +def _convert_resize(net, node, graph, err): + pass + +def _convert_transpose(net, node, graph, err): + pass +def _convert_concat(net, node, graph, err): + pass +def _convert_softmax(net, node, graph, err): + pass + +def _convert_conv_transpose(net, node, graph, err): + weight_name = node.inputs[1] + input_name = str(node.inputs[0]) + output_name = str(node.outputs[0]) + node_name = node.name + W = None + if weight_name in node.input_tensors: + W = node.input_tensors[weight_name] + else: + err.missing_initializer(node, + "Weight tensor: {} not found in the graph initializer".format(weight_name,)) + bias_flag = False + bias = None + if len(node.inputs) > 2: + bias = node.input_tensors[node.inputs[2]] + bias_flag = True + # net.params[node_name][0].data = W + # if bias_flag: + # net.params[node_name][1].data = bias + np.copyto(net.params[node_name][0].data,W,casting='same_kind') + if bias_flag: + np.copyto(net.params[node_name][1].data, bias, casting='same_kind') + +_ONNX_NODE_REGISTRY = { + "Conv": _convert_conv, + "Relu": _convert_relu, + "BatchNormalization": _convert_BatchNorm, + "Add": _convert_Add, + "Mul": _convert_Mul, + "Reshape": _convert_Reshape, + "MaxPool": _convert_pool, + "AveragePool": _convert_pool, + "Dropout": _convert_dropout, + "Gemm": _convert_gemm, + "Upsample": _convert_upsample, + "Concat": _convert_concat, + "ConvTranspose": _convert_conv_transpose, + "Sigmoid": _convert_sigmoid, + "Flatten": _convert_Flatten, + "Resize": _convert_resize, + "Transpose": _convert_transpose, + "Softmax": _convert_softmax, +} + +
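
A quick way to sanity-check a converted model is to push the same random input
through the original ONNX model and the generated Caffe model and compare the
outputs. The sketch below is illustrative only: it assumes onnxruntime is
available (it is not among the tool's listed dependencies), reuses the example
paths from the README, and simply pairs the first output of each backend.

```
import numpy as np
import onnxruntime as ort
import caffe

onnx_path = "./model/mmdet.onnx"
prototxt_path = "./model/a.prototxt"
caffemodel_path = "./model/a.caffemodel"

# Load the converted Caffe model and feed a random input of the right shape.
net = caffe.Net(prototxt_path, caffemodel_path, caffe.TEST)
input_name = list(net.blobs.keys())[0]
x = np.random.rand(*net.blobs[input_name].data.shape).astype(np.float32)
net.blobs[input_name].data[...] = x
caffe_out = net.forward()

# Run the same input through the original ONNX model.
sess = ort.InferenceSession(onnx_path)
onnx_out = sess.run(None, {sess.get_inputs()[0].name: x})

# Output ordering can differ between the two backends; adjust the pairing
# for multi-output detection heads if needed.
print("max abs diff:", np.abs(list(caffe_out.values())[0] - onnx_out[0]).max())
```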