|
|
|
@@ -19,14 +19,17 @@ from importlib import import_module |
|
|
|
from collections import OrderedDict |
|
|
|
from typing import Union |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
|
from mindinsight.mindconverter.common.log import logger as log |
|
|
|
from ..common.utils import fetch_output_from_onnx_model |
|
|
|
|
|
|
|
from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ |
|
|
|
ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL |
|
|
|
from ...common.exceptions import GraphInitFail, ModelNotSupport |
|
|
|
|
|
|
|
|
|
|
|
def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=None): |
|
|
|
def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=12): |
|
|
|
""" |
|
|
|
Convert Tensorflow model to ONNX model. |
|
|
|
|
|
|
|
@@ -104,18 +107,20 @@ class OnnxTensor: |
|
|
|
raw_tensor (onnx.TensorProto): onnx.TensorProto instance. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, raw_tensor, name=None):
    """
    Wrap a raw tensor value and record its name, type and dimensions.

    Args:
        raw_tensor: Either a proto-like object exposing `name`, `data_type`
            and `dims`, or an `np.ndarray` holding an already-materialized
            value.
        name (str): Name to use when `raw_tensor` is an `np.ndarray`, which
            carries no name of its own. Ignored otherwise. Default: None.
    """
    self.raw_tensor = raw_tensor
    # Decide once which flavour of tensor we hold; reading `.name`,
    # `.data_type` or `.dims` unconditionally raises AttributeError on
    # an ndarray.
    if isinstance(raw_tensor, np.ndarray):
        self.name = name
        self.type = raw_tensor.dtype
        self.dim = raw_tensor.shape
    else:
        self.name = raw_tensor.name
        self.type = raw_tensor.data_type
        self.dim = raw_tensor.dims
    # Names of the nodes producing / consuming this tensor; filled in later.
    self.from_nodes = []
    self.to_nodes = []
|
|
|
|
|
|
|
def to_array(self):
    """
    Return the tensor value in array form.

    Values that are already an `np.ndarray`, list, tuple, int or float are
    returned unchanged. Anything else is converted with
    `onnx.numpy_helper.to_array`.
    """
    if isinstance(self.raw_tensor, (np.ndarray, list, tuple, int, float)):
        return self.raw_tensor
    # Import onnx lazily and only on the path that actually needs it.
    onnx = import_module("onnx")
    # Convert binary data to np.array
    return onnx.numpy_helper.to_array(self.raw_tensor)
|
|
|
|
|
|
|
|
|
|
|
class ParamsAttribute: |
|
|
|
@@ -249,22 +254,27 @@ class OnnxDataLoader: |
|
|
|
|
|
|
|
Args: |
|
|
|
onnx_model (onnx.ModelProto): Original Onnx defined model. |
|
|
|
input_nodes (Union[str, list]): Input nodes of ONNX model. |
|
|
|
output_nodes (Union[str, list]): Output nodes of ONNX model. |
|
|
|
infer_shape (bool): Enable the shape inference after conversion. |
|
|
|
Default: True |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, onnx_model, graph_input_shape: Union[tuple, list] = None, infer_shape=True): |
|
|
|
def __init__(self, onnx_model, graph_input_shape: Union[tuple, list], |
|
|
|
input_nodes: list, output_nodes: list, infer_shape=True): |
|
|
|
self.model = onnx_model |
|
|
|
self.graph = onnx_model.graph |
|
|
|
self.nodes = onnx_model.graph.node |
|
|
|
self.graph_input_shape = graph_input_shape |
|
|
|
self.input_nodes = input_nodes if isinstance(input_nodes, list) else [input_nodes] |
|
|
|
self.output_nodes = output_nodes if isinstance(output_nodes, list) else [output_nodes] |
|
|
|
# args for init |
|
|
|
self._is_infer_shape = infer_shape |
|
|
|
|
|
|
|
# params parsed in init |
|
|
|
self.inferred_model = None |
|
|
|
|
|
|
|
self.nodes_dict = OrderedDict() # {node_name: OnnxNode} NO INPUT NODE |
|
|
|
self._nodes_dict = OrderedDict() # {node_name: OnnxNode} NO INPUT NODE |
|
|
|
self.tensors_dict = {} # {tensor_name: OnnxTensor} |
|
|
|
self.value_info_dict = {} # Not contains input and output nodes |
|
|
|
|
|
|
|
@@ -274,9 +284,21 @@ class OnnxDataLoader: |
|
|
|
|
|
|
|
# Key is edge of ONNX ir graph, value is the corresponding precursor node. |
|
|
|
self.output_name_to_node_name = dict() |
|
|
|
self.dynamic_reshape_node = list() |
|
|
|
self.eliminated_nodes = list() |
|
|
|
|
|
|
|
self.initialize() |
|
|
|
|
|
|
|
@property
def nodes_dict(self):
    """
    Return a filtered copy of the node mapping.

    Returns:
        dict, entries of `self._nodes_dict`, in their original order, whose
        key is not listed in `self.eliminated_nodes`. A new dict is built on
        every access.
    """
    # `eliminated_nodes` is a list; build a set once so each membership test
    # is O(1) instead of a linear scan per node.
    eliminated = set(self.eliminated_nodes)
    return {name: node for name, node in self._nodes_dict.items()
            if name not in eliminated}
|
|
|
|
|
|
|
def _check_initialization(self): |
|
|
|
"""Define conditions checked before init.""" |
|
|
|
if all([self.model, self.graph, self.nodes]): |
|
|
|
@@ -364,7 +386,7 @@ class OnnxDataLoader: |
|
|
|
"""Parse each onnx nodes in the model.""" |
|
|
|
for node in self.nodes: |
|
|
|
n = OnnxNode(node) |
|
|
|
self.nodes_dict[n.name] = n |
|
|
|
self._nodes_dict[n.name] = n |
|
|
|
self.node_name_set.add(n.name) |
|
|
|
if len(node.output) > 1: |
|
|
|
raise ModelNotSupport(msg=f"{node.name} has multi-outputs which is not supported now.") |
|
|
|
@@ -407,76 +429,25 @@ class OnnxDataLoader: |
|
|
|
|
|
|
|
def get_node(self, node_name):
    """
    Get the OnnxNode instance by node name.

    The lookup goes to the unfiltered `self._nodes_dict`, so nodes listed in
    `self.eliminated_nodes` can still be resolved, and no filtered copy of
    the mapping is built per call.

    Raises:
        KeyError: If `node_name` is not a key of `self._nodes_dict`.
    """
    return self._nodes_dict[node_name]
|
|
|
|
|
|
|
def get_tensor(self, tensor_name):
    """
    Look up a tensor wrapper by its name.

    Returns:
        The value stored under `tensor_name` in `self.tensors_dict`.

    Raises:
        KeyError: If `tensor_name` is not a key of `self.tensors_dict`.
    """
    known_tensors = self.tensors_dict
    return known_tensors[tensor_name]
|
|
|
|
|
|
|
def build_tensor_dataflow(self):
    """
    Record, on every known tensor, which nodes read it and which write it.

    For each node in `self.nodes_dict`, every input name found in
    `self.tensor_name_set` gets the node name appended to that tensor's
    `to_nodes`; every such output name gets it appended to `from_nodes`.
    """
    for node_name, node in self.nodes_dict.items():
        # Inputs first, then outputs; each pair is (edge names, tensor attribute).
        directions = (
            (node.input_name_list, "to_nodes"),
            (node.output_name_list, "from_nodes"),
        )
        for edge_names, attr_name in directions:
            for edge_name in edge_names:
                # Only edges that are known tensors are tracked.
                if edge_name not in self.tensor_name_set:
                    continue
                tensor = self.get_tensor(edge_name)
                getattr(tensor, attr_name).append(node_name)
|
|
|
|
|
|
|
def build_nodes_connection(self):
    """
    Find the previous and next nodes of each node.

    Nodes listed in `self.eliminated_nodes` are skipped as consumers. For
    every other node, each input edge is resolved to its producing node via
    `self.output_name_to_node_name`; the producer is recorded in the
    consumer's `precursor_onnx_node_dict` and the consumer in the producer's
    `successor_onnx_node_dict`. Inputs with no producing node in that map
    are ignored.
    """
    # Build the set once so the per-node membership test is O(1).
    eliminated = set(self.eliminated_nodes)
    for node_name, node in self._nodes_dict.items():
        if node_name in eliminated:
            continue
        for input_name in node.input_name_list:
            producer_name = self.output_name_to_node_name.get(input_name)
            if not producer_name:
                # No producing node is recorded for this edge.
                continue
            # Resolve through the unfiltered mapping so a producer is found
            # even when it is itself marked as eliminated.
            producer = self._nodes_dict[producer_name]
            node.precursor_onnx_node_dict[producer_name] = producer
            # Back tracing successor nodes: key by the consumer, not the producer.
            producer.successor_onnx_node_dict[node_name] = node
|
|
|
|
|
|
|
def initialize(self): |
|
|
|
"""Initialize the OnnxDataLoader.""" |
|
|
|
|
|
|
|
@@ -493,13 +464,6 @@ class OnnxDataLoader: |
|
|
|
if self._is_infer_shape: |
|
|
|
try: |
|
|
|
self._infer_model() |
|
|
|
except Exception as e: |
|
|
|
log.error(str(e)) |
|
|
|
log.exception(e) |
|
|
|
raise e |
|
|
|
|
|
|
|
if self.inferred_model: |
|
|
|
try: |
|
|
|
self._parse_value_info() |
|
|
|
self._parse_node_output_shape() |
|
|
|
except Exception as e: |
|
|
|
@@ -510,5 +474,63 @@ class OnnxDataLoader: |
|
|
|
# 3. parse all tensors |
|
|
|
self._parse_tensors() |
|
|
|
|
|
|
|
# 4. build nodes connections |
|
|
|
# 4. Optimize graph to eliminate some nodes. |
|
|
|
self._find_nodes_to_be_eliminated() |
|
|
|
|
|
|
|
# 5. build nodes connections |
|
|
|
self.build_nodes_connection() |
|
|
|
|
|
|
|
# 6. Run onnx model to fetch actual value of eliminated nodes. |
|
|
|
self._fetch_eliminated_nodes_value() |
|
|
|
|
|
|
|
def _fetch_eliminated_nodes_value(self):
    """
    Fetch eliminated nodes values by running onnx inference.

    For every node named in `self.dynamic_reshape_node`, its second input
    edge is requested as an output from `fetch_output_from_onnx_model`. The
    model is fed random float32 data of shape `self.graph_input_shape` on
    the first entry of `self.input_nodes`. Each fetched value is stored in
    `self.tensors_dict` as an `OnnxTensor` under the edge name. Does nothing
    when there are no dynamic reshape nodes.
    """
    reshape_names = self.dynamic_reshape_node
    if not reshape_names:
        return

    # Index 1 of a node's input list is the edge carrying the target shape.
    shape_edges = [self._nodes_dict[reshape_name].input_name_list[1]
                   for reshape_name in reshape_names]

    first_input = self.input_nodes[0]
    random_input = np.random.rand(*self.graph_input_shape).astype(np.float32)
    fetched = fetch_output_from_onnx_model(self.model,
                                           feed_dict={first_input: random_input},
                                           output_nodes=shape_edges)
    for edge_name, edge_value in fetched.items():
        self.tensors_dict[edge_name] = OnnxTensor(edge_value, edge_name)
|
|
|
|
|
|
|
def _find_nodes_to_be_eliminated(self):
    """
    Run every optimization PASS over the graph.

    Currently the only PASS is `_pass_of_shape`, invoked once per entry of
    `self._nodes_dict` with the node name and the node instance.
    """
    all_nodes = self._nodes_dict
    for node_name in all_nodes:
        self._pass_of_shape(node_name, all_nodes[node_name])
|
|
|
|
|
|
|
def _pass_of_shape(self, nd_name, nd_inst):
    """
    Create a PASS to optimize shape and reshape operations in ONNX ir graph.

    Only `Reshape` nodes whose shape input (second input) is not a key of
    `self.tensors_dict` are handled. Such a node is appended to
    `self.dynamic_reshape_node`, and the chain of shape-computing nodes
    feeding its shape input is appended to `self.eliminated_nodes`.

    Args:
        nd_name (str): Name of the node being inspected.
        nd_inst: Node instance exposing `op_type` and `input_name_list`.
    """
    to_be_eliminated_op = {"Cast", "Concat", "Squeeze", "Unsqueeze", "Slice",
                           "Gather", "Shape"}

    def _traceback_precursor_nodes_until_shape_op(node_ref):
        """Collect, in depth-first order, the eliminable producers behind edge `node_ref`."""
        e_nodes = []
        visited = set()
        pending = [node_ref]
        while pending:
            edge_name = pending.pop()
            # Edges with no producing node (e.g. a graph input) end the trace
            # instead of raising KeyError.
            producer_name = self.output_name_to_node_name.get(edge_name)
            if producer_name is None or producer_name in visited:
                continue
            node = self._nodes_dict.get(producer_name)
            if node is None or node.op_type not in to_be_eliminated_op:
                continue
            visited.add(producer_name)
            e_nodes.append(node.name)
            if node.op_type == "Shape":
                # Whatever feeds `Shape` is ordinary data flow, not shape
                # arithmetic, so the trace must not continue past it.
                continue
            upstream = [ipt for ipt in node.input_name_list
                        if ipt not in self.tensors_dict]
            # Reversed so the first input is popped, and so visited, first.
            pending.extend(reversed(upstream))
        return e_nodes

    if nd_inst.op_type != "Reshape":
        return

    # Find its shape input; a Reshape without a second input has nothing to trace.
    if len(nd_inst.input_name_list) < 2:
        return
    to_shape = nd_inst.input_name_list[1]
    if to_shape in self.tensors_dict:
        # Then its shape input is constant.
        return

    self.dynamic_reshape_node.append(nd_name)
    # Several reshapes may share part of a shape sub-graph; record each
    # eliminated node only once.
    already_eliminated = set(self.eliminated_nodes)
    for name in _traceback_precursor_nodes_until_shape_op(to_shape):
        if name not in already_eliminated:
            already_eliminated.add(name)
            self.eliminated_nodes.append(name)