Browse Source

Refactor the ONNX utils: handle graph nodes without shape information, dynamically infer sub-graph values, and add a mapper for Reshape.

tags/v1.1.0
liuchongming 5 years ago
parent
commit
b30e144f04
9 changed files with 198 additions and 104 deletions
  1. +36
    -0
      mindinsight/mindconverter/graph_based_converter/common/utils.py
  2. +1
    -1
      mindinsight/mindconverter/graph_based_converter/framework.py
  3. +2
    -4
      mindinsight/mindconverter/graph_based_converter/mapper/base.py
  4. +1
    -11
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py
  5. +44
    -0
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reshape_mapper.py
  6. +2
    -1
      mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json
  7. +6
    -5
      mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
  8. +7
    -5
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py
  9. +99
    -77
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py

+ 36
- 0
mindinsight/mindconverter/graph_based_converter/common/utils.py View File

@@ -13,6 +13,9 @@
# limitations under the License.
# ============================================================================
"""Define common utils."""
from importlib import import_module
from typing import List

from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP


@@ -27,3 +30,36 @@ def is_converted(operation: str):
bool, true or false.
"""
return operation and SEPARATOR_IN_ONNX_OP not in operation


def fetch_output_from_onnx_model(model, feed_dict: dict, output_nodes: List[str]):
    """
    Fetch specific nodes' output from an ONNX model by running inference.

    Notes:
        Only outputs without a batch dimension are supported.

    Args:
        model (ModelProto): ONNX model.
        feed_dict (dict): Feed-forward inputs, mapping input node name to value.
        output_nodes (list[str]): Names of the nodes whose outputs are wanted.

    Returns:
        dict, mapping each requested output node name to its computed value.

    Raises:
        TypeError: If `feed_dict` is not a dict or `output_nodes` is not a list.
    """
    if not isinstance(feed_dict, dict) or not isinstance(output_nodes, list):
        raise TypeError("`feed_dict` should be type of dict, and `output_nodes` "
                        "should be type of List[str].")

    ort = import_module("onnxruntime")

    input_nodes = list(feed_dict.keys())

    # Extract the minimal sub-graph between the fed inputs and the requested
    # outputs so onnxruntime does not have to execute the whole model.
    extractor = getattr(import_module("onnx.utils"), "Extractor")(model)
    extracted_model = extractor.extract_model(input_nodes, output_nodes)
    # `SerializeToString()` already returns bytes; no extra `bytes()` wrapper needed.
    sess = ort.InferenceSession(path_or_bytes=extracted_model.SerializeToString())
    fetched_res = sess.run(output_names=output_nodes, input_feed=feed_dict)
    # `sess.run` preserves the order of `output_names`, so zip pairs correctly.
    return dict(zip(output_nodes, fetched_res))

+ 1
- 1
mindinsight/mindconverter/graph_based_converter/framework.py View File

@@ -92,7 +92,7 @@ def tf_installation_validation(func):
output_folder: str, report_folder: str = None,
input_nodes: str = None, output_nodes: str = None):
# Check whether tensorflow is installed.
if not find_spec("tensorflow") or not find_spec("tf2onnx"):
if not find_spec("tensorflow") or not find_spec("tf2onnx") or not find_spec("onnxruntime"):
error = ModuleNotFoundError("Tensorflow, tf2onnx and onnxruntime are required when using "
"graph based scripts converter.")
log.error(str(error))


+ 2
- 4
mindinsight/mindconverter/graph_based_converter/mapper/base.py View File

@@ -104,11 +104,9 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
return None, dict(), None, dict()

try:
converter_name = op_name_converter(
params=params, weights=weights, op_name=op_name)
converter_name = op_name_converter(params=params, weights=weights, op_name=op_name)
converted_params = params_converter(params=params, weights=weights)
converted_weights = weights_converter(
weights=weights) if weights else dict()
converted_weights = weights_converter(weights=weights) if weights else dict()
converted_params.update(converted_weights)
converted_settings = settings_converter(params=params, weights=weights)
except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e:


+ 1
- 11
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py View File

@@ -13,7 +13,6 @@
# limitations under the License.
# ==============================================================================
"""Mapper module."""
import re
import numpy as np
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting
@@ -77,16 +76,7 @@ class ConvMapper(ONNXToMindSporeMapper):
weights = kwargs['weights']
params = kwargs['params']
# regex to find Conv weight
regex = r".+\/(Conv2D|depthwise)\/ReadVariableOp:0$"
regex2 = r"const_fold_opt__\d+"
weight = None
for w_name, w in weights.items():
if re.match(regex, w_name):
weight = w
break
if re.match(regex2, w_name):
weight = w
break
weight = list(weights.values())[0]
if weight is None:
raise ValueError("Conv. Mapper cannot get the weight.")



+ 44
- 0
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reshape_mapper.py View File

@@ -0,0 +1,44 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting


class ReshapeMapper(ONNXToMindSporeMapper):
    """Map the ONNX Reshape operator onto MindSpore `P.Reshape`."""

    @staticmethod
    def _operation_name_in_ms(*args, **kwargs):
        """Return the MindSpore operation name."""
        return "P.Reshape"

    @staticmethod
    def _convert_params(**kwargs):
        """`P.Reshape` takes no constructor parameters."""
        return dict()

    @staticmethod
    def _convert_trained_weights(**kwargs):
        """Reshape carries no trainable weights to convert."""
        return dict()

    @staticmethod
    def _convert_settings(**kwargs):
        """Build the `shape` extra input from the single recorded weight.

        The leading dimension is replaced with -1 so the batch size stays
        dynamic at runtime.
        """
        weights = kwargs.get("weights")
        if not weights:
            return Setting()
        if len(weights) > 1:
            raise ValueError("For reshape, `weights` length should equal to 1.")
        # assumes the weight tensor holds the target shape with batch first — TODO confirm
        target_shape = next(iter(weights.values()))
        shape = (-1, *target_shape[1:].tolist())
        return Setting(op_extra_input={"shape": shape})

+ 2
- 1
mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json View File

@@ -14,5 +14,6 @@
"onnx::Clip": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.relu_mapper.ReLUMapper",
"onnx::Transpose": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.transpose_mapper.TransposeMapper",
"onnx::MatMul": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.mat_mul_mapper.MatMulMapper",
"onnx::Softmax": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.softmax_mapper.SoftmaxMapper"
"onnx::Softmax": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.softmax_mapper.SoftmaxMapper",
"onnx::Reshape": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.reshape_mapper.ReshapeMapper"
}

+ 6
- 5
mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py View File

@@ -95,6 +95,8 @@ class Graph(BaseGraph, abc.ABC):
def __init__(self, model, **kwargs):
super(Graph, self).__init__()
self.model = model
self._raw_input_nodes = kwargs.get("input_nodes")
self._raw_output_nodes = kwargs.get("output_nodes")
self.checkpoint = kwargs.get("checkpoint", None)
self._nodes_collection = OrderedDict()
self._nodes_record = dict()
@@ -112,7 +114,7 @@ class Graph(BaseGraph, abc.ABC):
name (str): Node name.

Returns:
list, shape.
Union[list, int], shape.
"""
return self._input_shape.get(name)

@@ -124,7 +126,7 @@ class Graph(BaseGraph, abc.ABC):
name (str): Node name.

Returns:
list, shape.
Union[list, int], shape.
"""
return self._shape_dict.get(name)

@@ -254,8 +256,7 @@ class Graph(BaseGraph, abc.ABC):
cls, graph instance.
"""
src_graph = cls.load_graph(graph_path=model_path, **kwargs)
ckpt = cls.load_checkpoint(
ckpt_path=checkpoint) if checkpoint else None
ckpt = cls.load_checkpoint(ckpt_path=checkpoint) if checkpoint else None

if ckpt is not None:
# Create an instance of TensorflowGraph.
@@ -263,7 +264,7 @@ class Graph(BaseGraph, abc.ABC):
checkpoint=ckpt)

# Create an instance of PyTorchGraph.
return cls(src_graph, sample_shape=sample_shape)
return cls(src_graph, sample_shape=sample_shape, **kwargs)


class GraphNode(abc.ABC):


+ 7
- 5
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py View File

@@ -55,8 +55,8 @@ class OnnxGraph(Graph):
sample_shape (tuple): Input shape of the model.
"""

def __init__(self, model, sample_shape: tuple = None):
super(OnnxGraph, self).__init__(model=model)
def __init__(self, model, sample_shape: tuple = None, **kwargs):
super(OnnxGraph, self).__init__(model=model, **kwargs)

self.build(sample_shape)

@@ -118,7 +118,9 @@ class OnnxGraph(Graph):
Args:
input_shape (tuple): Input shape of model. Default: None
"""
model_data = OnnxDataLoader(self.model, graph_input_shape=input_shape)
model_data = OnnxDataLoader(self.model, graph_input_shape=input_shape,
input_nodes=self._raw_input_nodes,
output_nodes=self._raw_output_nodes)
from ..sub_graph_searcher import generate_scope_name
scope_name_list = generate_scope_name(model_data)

@@ -129,7 +131,7 @@ class OnnxGraph(Graph):
inputs = node.input_name_list
# check each input from node or tensors
for i in inputs:
if i in model_data.tensor_name_set:
if i in model_data.tensors_dict:
tensor = model_data.tensors_dict[i]
t_name = tensor.name
t_value = tensor.to_array()
@@ -142,7 +144,7 @@ class OnnxGraph(Graph):
self._build_connection(nd_ipt_name, node_name)

super(OnnxGraph, self).build(input_shape=input_shape)
self._collect_input_shape_of_each_node(input_shape) # diff than pyTorch
self._collect_input_shape_of_each_node(input_shape)

def _collect_input_shape_of_each_node(self, input_shape):
"""


+ 99
- 77
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py View File

@@ -19,14 +19,17 @@ from importlib import import_module
from collections import OrderedDict
from typing import Union

import numpy as np

from mindinsight.mindconverter.common.log import logger as log
from ..common.utils import fetch_output_from_onnx_model

from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \
ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL
from ...common.exceptions import GraphInitFail, ModelNotSupport


def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=None):
def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=12):
"""
Convert Tensorflow model to ONNX model.

@@ -104,18 +107,20 @@ class OnnxTensor:
raw_tensor (onnx.TensorProto): onnx.TensorProto instance.
"""

def __init__(self, raw_tensor):
def __init__(self, raw_tensor, name=None):
self.raw_tensor = raw_tensor
self.name = raw_tensor.name
self.type = raw_tensor.data_type
self.dim = raw_tensor.dims
self.name = raw_tensor.name if not isinstance(raw_tensor, np.ndarray) else name
self.type = raw_tensor.data_type if not isinstance(raw_tensor, np.ndarray) else raw_tensor.dtype
self.dim = raw_tensor.dims if not isinstance(raw_tensor, np.ndarray) else raw_tensor.shape
self.from_nodes = []
self.to_nodes = []

def to_array(self):
onnx = import_module("onnx")
# Convert binary data to np.array
return onnx.numpy_helper.to_array(self.raw_tensor)
if not isinstance(self.raw_tensor, (np.ndarray, list, tuple, int, float)):
return onnx.numpy_helper.to_array(self.raw_tensor)
return self.raw_tensor


class ParamsAttribute:
@@ -249,22 +254,27 @@ class OnnxDataLoader:

Args:
onnx_model (onnx.ModelProto): Original Onnx defined model.
input_nodes (Union[str, list]): Input nodes of ONNX model.
output_nodes (Union[str, list]): Output nodes of ONNX model.
infer_shape (bool): Enable the shape inference after conversion.
Default: True
"""

def __init__(self, onnx_model, graph_input_shape: Union[tuple, list] = None, infer_shape=True):
def __init__(self, onnx_model, graph_input_shape: Union[tuple, list],
input_nodes: list, output_nodes: list, infer_shape=True):
self.model = onnx_model
self.graph = onnx_model.graph
self.nodes = onnx_model.graph.node
self.graph_input_shape = graph_input_shape
self.input_nodes = input_nodes if isinstance(input_nodes, list) else [input_nodes]
self.output_nodes = output_nodes if isinstance(output_nodes, list) else [output_nodes]
# args for init
self._is_infer_shape = infer_shape

# params parsed in init
self.inferred_model = None

self.nodes_dict = OrderedDict() # {node_name: OnnxNode} NO INPUT NODE
self._nodes_dict = OrderedDict() # {node_name: OnnxNode} NO INPUT NODE
self.tensors_dict = {} # {tensor_name: OnnxTensor}
self.value_info_dict = {} # Not contains input and output nodes

@@ -274,9 +284,21 @@ class OnnxDataLoader:

# Key is edge of ONNX ir graph, value is the corresponding precursor node.
self.output_name_to_node_name = dict()
self.dynamic_reshape_node = list()
self.eliminated_nodes = list()

self.initialize()

@property
def nodes_dict(self):
    """Nodes mapping with every eliminated node filtered out."""
    dropped = self.eliminated_nodes
    return {name: node for name, node in self._nodes_dict.items()
            if name not in dropped}

def _check_initialization(self):
"""Define conditions checked before init."""
if all([self.model, self.graph, self.nodes]):
@@ -364,7 +386,7 @@ class OnnxDataLoader:
"""Parse each onnx nodes in the model."""
for node in self.nodes:
n = OnnxNode(node)
self.nodes_dict[n.name] = n
self._nodes_dict[n.name] = n
self.node_name_set.add(n.name)
if len(node.output) > 1:
raise ModelNotSupport(msg=f"{node.name} has multi-outputs which is not supported now.")
@@ -407,76 +429,25 @@ class OnnxDataLoader:

def get_node(self, node_name):
"""Get the OnnxNode instance by node name."""
return self.nodes_dict[node_name]
return self._nodes_dict[node_name]

def get_tensor(self, tensor_name):
"""Get the OnnxTensor instance by tensor name."""
return self.tensors_dict[tensor_name]

def build_tensor_dataflow(self):
"""Find the data from/to nodes of each tensor."""
for node_name, node in self.nodes_dict.items():
# for each input of a node
for input_name in node.input_name_list:
# if the input is a tensor
if input_name in self.tensor_name_set:
t = self.get_tensor(input_name)
t.to_nodes.append(node_name)

for output_name in node.output_name_list:
# if the output is a tensor
if output_name in self.tensor_name_set:
t = self.get_tensor(output_name)
t.from_nodes.append(node_name)

def build_nodes_connection(self):
"""Find the previous and next nodes of each node."""
for node_name, node in self.nodes_dict.items():
# for each input of a node
for node_name, node in self._nodes_dict.items():
if node_name in self.eliminated_nodes:
continue
for input_name in node.input_name_list:
# remove :0 in the name to ensure consistency in hierarchical tree.
input_name = input_name.split(':')[0]
if input_name in self.node_name_set:
# input is a node
# build connection
node.precursor_onnx_node_dict[input_name] = self.get_node(
input_name)

# Back tracing successor nodes
back_tracked_node = self.get_node(input_name)
back_tracked_node.successor_onnx_node_dict[node_name] = self.get_node(node_name)
if self.output_name_to_node_name.get(input_name):
input_node_name = self.output_name_to_node_name.get(input_name)
input_node = self.get_node(input_node_name)
node.precursor_onnx_node_dict[input_node_name] = input_node
input_node.successor_onnx_node_dict[node_name] = node
continue

# check if nodes connected by a tensor
if input_name in self.tensor_name_set:
# regex to remove the ':0' in the name
regex = r'(?P<node>.+)/(?P<op>.+:0)'
match = re.match(regex, input_name)
if not match:
continue

n_name = match.group('node')
if n_name in self.node_name_set:
# current node has a pre node via tensor
node.precursor_onnx_node_dict[n_name] = self.get_node(
n_name)

# Back tracing successor nodes
back_tracked_node = self.get_node(n_name)
back_tracked_node.successor_onnx_node_dict[n_name] = self.get_node(
n_name)
continue

# input_name not a node /tensor but intermediate
for nm, n in self.nodes_dict.items():
for out_name in n.output_name_list:
out_name = out_name.split(':')[0]
if out_name == input_name:
node.precursor_onnx_node_dict[nm] = n

# Back tracing
n.successor_onnx_node_dict[node_name] = node

def initialize(self):
"""Initialize the OnnxDataLoader."""

@@ -493,13 +464,6 @@ class OnnxDataLoader:
if self._is_infer_shape:
try:
self._infer_model()
except Exception as e:
log.error(str(e))
log.exception(e)
raise e

if self.inferred_model:
try:
self._parse_value_info()
self._parse_node_output_shape()
except Exception as e:
@@ -510,5 +474,63 @@ class OnnxDataLoader:
# 3. parse all tensors
self._parse_tensors()

# 4. build nodes connections
# 4. Optimize graph to eliminate some nodes.
self._find_nodes_to_be_eliminated()

# 5. build nodes connections
self.build_nodes_connection()

# 6. Run onnx model to fetch actual value of eliminated nodes.
self._fetch_eliminated_nodes_value()

def _fetch_eliminated_nodes_value(self):
    """Run ONNX inference to recover concrete values for eliminated shape chains.

    For every dynamic Reshape node, its shape operand tensor is fetched by
    executing the model on a random input, then stored in `tensors_dict` so
    later passes can treat it as a constant.
    """
    if not self.dynamic_reshape_node:
        return
    # Second input of Reshape is the (dynamic) target-shape tensor.
    shape_refs = [self._nodes_dict[name].input_name_list[1]
                  for name in self.dynamic_reshape_node]
    # NOTE(review): feeds only the first graph input with random data — assumes a
    # single input and that the fetched outputs depend only on its shape; confirm.
    feed = {self.input_nodes[0]: np.random.rand(*self.graph_input_shape).astype(np.float32)}
    fetched = fetch_output_from_onnx_model(self.model, feed_dict=feed, output_nodes=shape_refs)
    for tensor_name, value in fetched.items():
        self.tensors_dict[tensor_name] = OnnxTensor(value, tensor_name)

def _find_nodes_to_be_eliminated(self):
    """Run every graph-optimization PASS over the parsed nodes."""
    for name, inst in self._nodes_dict.items():
        self._pass_of_shape(name, inst)

def _pass_of_shape(self, nd_name, nd_inst):
    """PASS that marks the shape-producing chain feeding a dynamic Reshape.

    When a Reshape's target shape is not a constant tensor, the chain of
    shape-manipulating ops that produces it is recorded in
    `eliminated_nodes`, and the Reshape itself in `dynamic_reshape_node`.
    """
    foldable_ops = {"Cast", "Concat", "Squeeze", "Unsqueeze", "Slice",
                    "Gather", "Shape"}

    def _collect_chain(edge_name):
        # Walk producer nodes upward; stop once an op outside the foldable set
        # (or a constant tensor input) is reached.
        producer = self._nodes_dict[self.output_name_to_node_name[edge_name]]
        if producer.op_type not in foldable_ops:
            return []
        collected = [producer.name]
        for upstream in producer.input_name_list:
            if upstream not in self.tensors_dict:
                collected.extend(_collect_chain(upstream))
        return collected

    if nd_inst.op_type != "Reshape":
        return
    shape_edge = nd_inst.input_name_list[1]
    if shape_edge in self.tensors_dict:
        # Target shape is already a constant initializer; nothing to eliminate.
        return
    self.dynamic_reshape_node.append(nd_name)
    self.eliminated_nodes += _collect_chain(shape_edge)

Loading…
Cancel
Save