Browse Source

tf2ms dev:

Add ops and nn mappers
Replace import method in onnx_utils
Support multi args in statement generation
Sub graph search path bug fix
Add shape check in onnx_utils
tags/v1.1.0
liangtianshu 5 years ago
parent
commit
266478c101
17 changed files with 1039 additions and 79 deletions
  1. +36
    -0
      mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py
  2. +58
    -13
      mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py
  3. +4
    -2
      mindinsight/mindconverter/graph_based_converter/mapper/base.py
  4. +3
    -3
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py
  5. +84
    -22
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py
  6. +36
    -0
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py
  7. +30
    -3
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py
  8. +40
    -0
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/softmax_mapper.py
  9. +43
    -0
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/transpose_mapper.py
  10. +4
    -1
      mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json
  11. +1
    -1
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py
  12. +8
    -3
      mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py
  13. +17
    -9
      mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
  14. +50
    -1
      mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py
  15. +207
    -0
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py
  16. +378
    -0
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph_node.py
  17. +40
    -21
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py

+ 36
- 0
mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py View File

@@ -13,14 +13,42 @@
# limitations under the License.
# ==============================================================================
"""Hierarchical tree module."""
import re
from mindinsight.mindconverter.common.log import logger as log
from .hierarchical_tree import HierarchicalTree
from ..third_party_graph.onnx_graph_node import OnnxGraphNode

__all__ = [
"HierarchicalTreeFactory"
]


def _tf_model_node_name_reformat(node: "OnnxGraphNode", node_name):
    """
    Rename the node name by combining scope name and its original name.

    Args:
        node (OnnxGraphNode): OnnxGraphNode instance.
        node_name (str): Node name saved in Graph.

    Returns:
        str, re-formatted node name.
    """
    scope_name = node.scope_name
    # Wrap the raw name in '$' markers so the generated code can later locate
    # and strip them (see hierarchical_tree code generation).
    node_name = '$' + node_name + '$'

    regex = r"(?P<parent>.+/)(?P<op>\w+)"
    match = re.match(regex, scope_name) if scope_name else None
    if match is None:
        # Empty scope, or a scope with no parent path (no '/'): fall back to
        # the wrapped node name. The previous code called match.group() here
        # unconditionally and raised AttributeError in this case.
        return node_name
    return match.group("parent") + node_name


class HierarchicalTreeFactory:
"""Hierarchical tree factory."""

@@ -36,6 +64,7 @@ class HierarchicalTreeFactory:
HierarchicalTree, tree.
"""
tree = HierarchicalTree()
node_scope_name = dict()
for _, node_name in enumerate(graph.nodes_in_topological_order):
node_inst = graph.get_node(node_name)
node_input = graph.get_input_shape(node_name)
@@ -44,6 +73,13 @@ class HierarchicalTreeFactory:
err_msg = f"This model is not supported now. " \
f"Cannot find {node_name}'s input shape."
log.error(err_msg)
if isinstance(node_inst, OnnxGraphNode):
node_name_with_scope = _tf_model_node_name_reformat(
node_inst, node_name)
node_scope_name[node_name] = node_name_with_scope
node_name = node_name_with_scope

tree.insert(node_inst, node_name, node_input, node_output)
if node_scope_name:
return tree, node_scope_name
return tree

+ 58
- 13
mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py View File

@@ -27,6 +27,7 @@ from mindinsight.mindconverter.common.log import logger as log
from .name_mgr import ModuleNameMgr, GlobalVarNameMgr
from ..mapper.base import Mapper
from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode
from ..third_party_graph.onnx_graph_node import OnnxGraphNode
from ..constant import SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT, CodeFormatConfig
from ..constant import NEW_LINE, SECOND_LEVEL_INDENT
from ..constant import NodeType
@@ -59,6 +60,8 @@ class HierarchicalTree(Tree):
# Manage variable name in a module.
self._vars_mgr_in_module = dict()
self._module_vars = dict()
# scope name mapping record for easy node searching
self._scope_name_map = dict()

@property
def tree_identifier(self):
@@ -70,13 +73,23 @@ class HierarchicalTree(Tree):
"""
return self.identifier

def insert(self, node: PyTorchGraphNode, node_name: str, input_shape, output_shape):
def get_node(self, nid):
    """Override get_node to also resolve tf-generated scope-name aliases."""
    known = nid is not None and self.contains(nid)
    if known:
        return self._nodes[nid]
    # Unknown id: try the scope-name alias table built for tf models.
    if self._scope_name_map and nid in self._scope_name_map:
        return self._nodes[self._scope_name_map.get(nid)]
    return None

def insert(self, node: Union[PyTorchGraphNode, OnnxGraphNode],
node_name: str, input_shape, output_shape):
"""
Insert node into hierarchical tree.

Args:
node_name (str): Node name.
node (PyTorchGraphNode): Node to be inserted.
node (Union[PyTorchGraphNode, OnnxGraphNode]): Node to be inserted.
output_shape (tuple): Output tensor shape.
input_shape (tuple): Input tensor shape.

@@ -102,7 +115,12 @@ class HierarchicalTree(Tree):

if not self.contains(identifier):
# Insert node into tree.
tgt_node = node if idx == len(scopes) - 1 else PyTorchGraphNode()
if isinstance(node, OnnxGraphNode):
tgt_node = node if idx == len(
scopes) - 1 else OnnxGraphNode()
else:
tgt_node = node if idx == len(
scopes) - 1 else PyTorchGraphNode()
tgt_node.successor_nodes = node.successor_nodes
tgt_node.precursor_nodes = node.precursor_nodes
tgt_node.node_type = (NodeType.OPERATION if idx == len(scopes) - 1
@@ -154,7 +172,8 @@ class HierarchicalTree(Tree):

def save_source_files(self, out_folder: str, mapper: Mapper,
model_name: str,
report_folder: str = None) -> NoReturn:
report_folder: str = None,
scope_name_map: dict = None) -> NoReturn:
"""
Save source codes to target folder.

@@ -165,6 +184,8 @@ class HierarchicalTree(Tree):
out_folder (str): Output folder.

"""
if scope_name_map:
self._scope_name_map = scope_name_map
try:
self._adjust_structure()
code_fragments = self._generate_codes(mapper)
@@ -217,7 +238,8 @@ class HierarchicalTree(Tree):
Node, node.
"""
if module_key in self._merged_module_args:
node = self._clear_unused_args(node, self._merged_module_args[module_key])
node = self._clear_unused_args(
node, self._merged_module_args[module_key])
else:
node.data.clear_args_of_declaration()
return node
@@ -341,12 +363,21 @@ class HierarchicalTree(Tree):
nd_inst = self._preprocess_node_args(nd_inst, module_key)
# 4. Post-process child node args.
for _, scsr_nd_name in enumerate(nd_inst.successors(self.tree_identifier)):
self._postprocess_node_args(self.get_node(scsr_nd_name), module_key)
self._postprocess_node_args(
self.get_node(scsr_nd_name), module_key)
# 5. Generate code.
snippets.add(func(nd_inst, nd_inst.data.module_name, module_key))
snippets.add(
func(nd_inst, nd_inst.data.module_name, module_key))

code_blocks.extend(snippets)

if self._scope_name_map: # from tf. conversion
c_blocks = []
for s in code_blocks:
s = s.replace('$', '')
c_blocks.append(s)
code_blocks = c_blocks

formatted_code, _ = FormatCode("".join(code_blocks),
style_config=CodeFormatConfig.PEP8.value)
report_generator = ReportGenerator()
@@ -469,8 +500,16 @@ class HierarchicalTree(Tree):
# Generate code statement.
init, construct = self._generate_stat(nd_inst, node, idx)

construct_block.append(construct)
init_block.append(init)
# support multiple construct and init block returns:
if isinstance(construct, list):
construct_block += construct
else:
construct_block.append(construct)

if isinstance(init, list):
init_block += init
else:
init_block.append(init)

class_construct = f"{NEW_LINE}{FIRST_LEVEL_INDENT}def construct(self, x):" \
f"{NEW_LINE}{SECOND_LEVEL_INDENT}"
@@ -507,7 +546,8 @@ class HierarchicalTree(Tree):

if idx != 0:
# Get previous node output variable name.
ipt_args_in_construct = self._get_previous_opt_var(cur_nd_inst, pre_nd_inst)
ipt_args_in_construct = self._get_previous_opt_var(
cur_nd_inst, pre_nd_inst)
if idx != len(pre_nd_inst.successors(self.tree_identifier)) - 1:
# Set opt variable name.
opt_arg_in_construct = cur_nd_inst.data.opt_var_name
@@ -652,7 +692,8 @@ class HierarchicalTree(Tree):
nd_inst.data.variable_name = self._module_vars[module_key][idx]
else:
variable_name = nd_inst.data.op_name or nd_inst.data.module_name
variable_name = self._vars_mgr_in_module[module_key].get_name(variable_name)
variable_name = self._vars_mgr_in_module[module_key].get_name(
variable_name)
nd_inst.data.variable_name = variable_name

# Generation of params must behind variable assigment.
@@ -662,7 +703,8 @@ class HierarchicalTree(Tree):
module_settings.update(nd_inst.data.settings_in_code)

if not created:
self._module_vars[module_key].append(nd_inst.data.variable_name)
self._module_vars[module_key].append(
nd_inst.data.variable_name)

node.data.args_in_code = module_args

@@ -727,5 +769,8 @@ class HierarchicalTree(Tree):
Returns:
str, imported module.
"""
return f"from mindspore import nn{NEW_LINE}" \
return f"import numpy as np{NEW_LINE}" \
f"import mindspore{NEW_LINE}" \
f"from mindspore import nn{NEW_LINE}" \
f"from mindspore import Tensor{NEW_LINE}" \
f"from mindspore.ops import operations as P{NEW_LINE * 3}"

+ 4
- 2
mindinsight/mindconverter/graph_based_converter/mapper/base.py View File

@@ -104,9 +104,11 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
return None, dict(), dict()

try:
converter_name = op_name_converter(params=params, weights=weights, op_name=op_name)
converter_name = op_name_converter(
params=params, weights=weights, op_name=op_name)
converted_params = params_converter(params=params, weights=weights)
converted_weights = weights_converter(weights=weights) if weights else dict()
converted_weights = weights_converter(
weights=weights) if weights else dict()
converted_params.update(converted_weights)
converted_settings = settings_converter(params=params)
except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e:


+ 3
- 3
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py View File

@@ -28,9 +28,9 @@ class BatchNormMapper(ONNXToMindSporeMapper):
def _convert_params(**kwargs):
params = kwargs['params']
return {
'num_features': params['output_shape'][1],
'eps': params['epsilon'],
'momentum': params['momentum']
'num_features': params.get('output_shape')[1],
'eps': params.get('epsilon', 1e-5),
'momentum': params.get('momentum', 0.9)
}

@staticmethod


+ 84
- 22
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py View File

@@ -13,24 +13,33 @@
# limitations under the License.
# ==============================================================================
"""Mapper module."""
import re
import numpy as np
from ...base import ONNXToMindSporeMapper


class ConvMapper(ONNXToMindSporeMapper):
"""Conv2d mapper."""
def _convert_padding(**kwargs):
    """Map ONNX `pads` onto a MindSpore (pad_mode, padding) pair."""
    params = kwargs['params']
    pads_onnx = params.get('pads')
    if not pads_onnx:
        return '\"same\"', 0
    if sum(pads_onnx) == 0:
        return '\"valid\"', 0
    # ONNX stores [begin_0..begin_k, end_0..end_k]; MindSpore expects the
    # begin/end values interleaved per spatial dimension.
    half = len(pads_onnx) // 2
    interleaved = []
    for begin, end in zip(pads_onnx[:half], pads_onnx[half:]):
        interleaved.extend((begin, end))
    return '\"pad\"', tuple(interleaved)

@staticmethod
def _operation_name_in_ms(*args, **kwargs):
weight = kwargs['weights']['weight'].numpy()
dim = weight.ndim - 2
return f"nn.Conv{dim}d"

class ConvMapper(ONNXToMindSporeMapper):
"""Conv2d mapper."""
@staticmethod
def _convert_params(**kwargs):
def convert_params_torch(**kwargs):
"""Convert params from PyTorch to MindSpore"""
weights = kwargs['weights']
params = kwargs['params']

weight = weights['weight'].numpy()
weight = np.transpose(weight, list(range(2, weight.ndim)) + [1, 0])
if isinstance(params['dilations'], list):
@@ -49,7 +58,7 @@ class ConvMapper(ONNXToMindSporeMapper):
kernel_size = kernel_size[0]
else:
kernel_size = tuple(kernel_size)
pad_mode, padding = ConvMapper._convert_padding(params=params)
pad_mode, padding = _convert_padding(params=params)
return {
'in_channels': in_channels,
'out_channels': out_channels,
@@ -61,21 +70,74 @@ class ConvMapper(ONNXToMindSporeMapper):
'group': params['group']}

@staticmethod
def _convert_trained_weights(**kwargs):
return dict()
def convert_params_tf(**kwargs):
    """
    Convert params from Tensorflow to MindSpore.

    Args:
        weights (dict): Mapping from weight tensor name to tensor; the conv
            kernel is located by name pattern, not by a fixed key.
        params (dict): ONNX node attributes (dilations, strides,
            kernel_shape, pads, group).

    Returns:
        dict, keyword arguments for the MindSpore nn.Conv cell.
    """
    weights = kwargs['weights']
    params = kwargs['params']
    # regex to find Conv weight: tf2onnx names the kernel read op like
    # ".../Conv2D/ReadVariableOp:0".
    regex = r".+\/Conv2D\/ReadVariableOp:0$"
    weight = None
    for w_name, w in weights.items():
        if re.match(regex, w_name):
            weight = w
            break
    if weight is None:
        raise ValueError("Conv. Mapper cannot get the weight.")

    # tmp tf translated ver. mapping: normalize list attrs to tuples so they
    # render as valid MindSpore constructor arguments.
    if isinstance(params.get('dilations'), list):
        dilation = tuple(params.get('dilations'))
    else:
        dilation = params.get('dilations')

    if isinstance(params.get('strides'), list):
        stride = tuple(params.get('strides'))
    else:
        stride = params.get('strides')

    kernel_size = params.get('kernel_shape')
    # NOTE(review): assumes the weight tensor layout is
    # (out_channels, in_channels, ...) at this point -- confirm against the
    # tf2onnx conversion performed in onnx_utils.
    in_channels = weight.shape[1]
    out_channels = weight.shape[0]
    # Collapse a 1-element kernel shape to a scalar, per MindSpore convention.
    if len(kernel_size) == 1:
        kernel_size = kernel_size[0]
    else:
        kernel_size = tuple(kernel_size)

    pad_mode, padding = _convert_padding(params=params)

    return {
        'in_channels': in_channels,
        'out_channels': out_channels,
        'kernel_size': kernel_size,
        'stride': stride,
        'padding': padding,
        'pad_mode': pad_mode,
        'dilation': dilation,
        'group': params.get('group', 1)}

@staticmethod
def _operation_name_in_ms(*args, **kwargs):
if not kwargs['weights'].get('weight'): # is from tf
kernel_size = kwargs['params'].get('kernel_shape')
dim = len(kernel_size)
return f"nn.Conv{dim}d"

weight = kwargs['weights']['weight'].numpy()
dim = weight.ndim - 2
return f"nn.Conv{dim}d"

@staticmethod
def _convert_padding(**kwargs):
"""Convert padding."""
def _convert_params(**kwargs):
weights = kwargs['weights']
params = kwargs['params']
if sum(params['pads']) == 0:
return '\"valid\"', 0
pads_onnx = params['pads']
half_index = len(pads_onnx) // 2
padding = []
for num_begin, num_end in zip(pads_onnx[:half_index], pads_onnx[half_index:]):
padding += [num_begin, num_end]
return '\"pad\"', tuple(padding)
if not weights.get('weight'): # is from tf
return ConvMapper.convert_params_tf(params=params, weights=weights)
return ConvMapper.convert_params_torch(params=params, weights=weights)
@staticmethod
def _convert_trained_weights(**kwargs):
return dict()

@staticmethod
def _convert_settings(**kwargs):


+ 36
- 0
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py View File

@@ -0,0 +1,36 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mapper module."""
from ...base import ONNXToMindSporeMapper


class MatMulMapper(ONNXToMindSporeMapper):
    """Mapper for the ONNX MatMul operator."""

    @staticmethod
    def _operation_name_in_ms(*args, **kwargs):
        # Only the operator name maps over; MatMul is parameter-free.
        return "nn.MatMul"

    @staticmethod
    def _convert_params(**kwargs):
        # nn.MatMul takes no constructor arguments.
        return {}

    @staticmethod
    def _convert_trained_weights(**kwargs):
        # No trained weights to migrate.
        return {}

    @staticmethod
    def _convert_settings(**kwargs):
        # No extra code-generation settings needed.
        return {}

+ 30
- 3
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py View File

@@ -16,6 +16,25 @@
from ...base import ONNXToMindSporeMapper


def _padding_format_convert(padding: list):
    """Convert Onnx padding format to Mindspore"""
    num = len(padding)
    if num % 2 != 0:
        raise ValueError(f"Padding list should be even length but got {num}")

    # ONNX lays pads out as [begin_0..begin_k, end_0..end_k]; MindSpore wants
    # one (begin, end) pair per dimension.
    half = num // 2
    return tuple(zip(padding[:half], padding[half:]))


class PadMapper(ONNXToMindSporeMapper):
"""Pad mapper."""

@@ -26,16 +45,24 @@ class PadMapper(ONNXToMindSporeMapper):
@staticmethod
def _convert_params(**kwargs):
params = kwargs['params']
if params['mode'] == 'constant':
mode = params.get('mode', 'constant')
if mode == 'constant' and params.get('value') is None:
if params.get('pads'):
pads_onnx = params.get('pads')
if isinstance(pads_onnx, list):
paddings = _padding_format_convert(pads_onnx)
return {'paddings': paddings,
'mode': '\"CONSTANT\"'}
if mode == 'constant':
if params['value'] == 0:
mode = '\"CONSTANT\"'
else:
msg = "{UNSUPPORTED: value is NOT 0}\"CONSTANT\""
mode = msg
elif params['mode'] == 'reflect':
elif mode == 'reflect':
mode = '\"REFLECT\"'
else:
msg = f"{{UNSUPPORTED: \"{params['mode']}\"}}\"UNKNOWN\""
msg = f"{{UNSUPPORTED: \"{mode}\"}}\"UNKNOWN\""
mode = msg
pads_onnx = params['pads']
half_index = len(pads_onnx) // 2


+ 40
- 0
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/softmax_mapper.py View File

@@ -0,0 +1,40 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mapper module."""
from ...base import ONNXToMindSporeMapper


class SoftmaxMapper(ONNXToMindSporeMapper):
    """Mapper for the ONNX Softmax operator."""

    @staticmethod
    def _operation_name_in_ms(*args, **kwargs):
        return "nn.Softmax"

    @staticmethod
    def _convert_params(**kwargs):
        # Forward `axis` only when present and truthy, matching the
        # truthiness convention used across the mappers.
        params = kwargs.get('params')
        axis = params.get('axis')
        return {'axis': axis} if axis else {}

    @staticmethod
    def _convert_trained_weights(**kwargs):
        # Softmax carries no trained weights.
        return {}

    @staticmethod
    def _convert_settings(**kwargs):
        # No extra code-generation settings needed.
        return {}

+ 43
- 0
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/transpose_mapper.py View File

@@ -0,0 +1,43 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mapper module."""
from ...base import ONNXToMindSporeMapper


class TransposeMapper(ONNXToMindSporeMapper):
    """Transpose mapper."""

    @staticmethod
    def _operation_name_in_ms(*args, **kwargs):
        # Maps to the primitive operator (P.*), not an nn cell.
        return "P.Transpose"

    @staticmethod
    def _convert_params(**kwargs):
        # P.Transpose takes no constructor arguments.
        return dict()

    @staticmethod
    def _convert_trained_weights(**kwargs):
        # Transpose carries no trained weights.
        return dict()

    @staticmethod
    def _convert_settings(**kwargs):
        # The permutation is passed at construct time as an extra input
        # ('values' settings entry), not as a constructor argument.
        converted_params = {}
        params = kwargs.get('params')
        perm = params.get('perm')
        if perm and isinstance(perm, list):
            perm = tuple(perm)
            # NOTE(review): source indentation is ambiguous here; assumed the
            # assignment happens only when a valid perm list exists -- confirm
            # whether a missing perm should still emit 'input_perm'.
            converted_params['input_perm'] = perm

        return {'values': converted_params}

+ 4
- 1
mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json View File

@@ -11,5 +11,8 @@
"onnx::Pad": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.pad_mapper.PadMapper",
"onnx::ReduceMean": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.reduce_mean_mapper.ReduceMeanMapper",
"onnx::Concat": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.concat_mapper.ConcatMapper",
"onnx::Clip": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.relu_mapper.ReLUMapper"
"onnx::Clip": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.relu_mapper.ReLUMapper",
"onnx::Transpose": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.transpose_mapper.TransposeMapper",
"onnx::MatMul": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.mat_mul_mapper.MatMulMapper",
"onnx::Softmax": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.softmax_mapper.SoftmaxMapper"
}

+ 1
- 1
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py View File

@@ -98,7 +98,7 @@ class MergedONNXNode(BaseNode):
"""Define merged onnx node."""

def __init__(self, name, module_name, ori_nodes):
super(MergedONNXNode, self).__init__(name, module_name)
super(MergedONNXNode, self).__init__(node_name=name, op_type=module_name)
self.nodes = ori_nodes

def get_name(self):


+ 8
- 3
mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py View File

@@ -16,13 +16,17 @@
from .base import Graph
from .pytorch_graph import PyTorchGraph
from .pytorch_graph_node import PyTorchGraphNode
from .onnx_graph import OnnxGraph
from .onnx_graph_node import OnnxGraphNode


class GraphFactory:
"""Graph factory."""

@classmethod
def init(cls, graph_path: str, sample_shape: tuple, checkpoint: str = None):
def init(cls, graph_path: str,
input_nodes: str, output_nodes: str,
sample_shape: tuple):
"""
Init an instance of graph.

@@ -34,8 +38,9 @@ class GraphFactory:
Returns:
Graph, graph instance.
"""
if checkpoint:
pass
if all([input_nodes, output_nodes]):
return OnnxGraph.load(model_path=graph_path, input_nodes=input_nodes,
output_nodes=output_nodes, sample_shape=sample_shape)

return PyTorchGraph.load(model_path=graph_path, sample_shape=sample_shape)



+ 17
- 9
mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py View File

@@ -26,7 +26,7 @@ class GraphParser(metaclass=abc.ABCMeta):

@classmethod
@abc.abstractmethod
def parse(cls, model_path: str):
def parse(cls, model_path: str, **kwargs):
"""Parse graph into readable format."""


@@ -54,18 +54,19 @@ class BaseGraph(metaclass=abc.ABCMeta):

@staticmethod
@abc.abstractmethod
def load_graph(graph_path: str):
def load_graph(graph_path: str, **kwargs):
"""Load graph file."""

@classmethod
@abc.abstractmethod
def load(cls, model_path: str, sample_shape: tuple = None,
checkpoint: str = None):
checkpoint: str = None, **kwargs):
"""Factory method to initialize an graph object."""

def __new__(cls, *args, **kwargs):
"""Control the create action of graph."""
model_param = args[0] if args else kwargs.get(cls._REQUIRED_PARAM_OF_MODEL)
model_param = args[0] if args else kwargs.get(
cls._REQUIRED_PARAM_OF_MODEL)
if not model_param:
error = ValueError(f"`{cls._REQUIRED_PARAM_OF_MODEL}` "
f"can not be None.")
@@ -229,12 +230,12 @@ class Graph(BaseGraph, abc.ABC):
raise NotImplementedError

@staticmethod
def load_graph(graph_path: str):
def load_graph(graph_path: str, **kwargs):
raise NotImplementedError

@classmethod
def load(cls, model_path: str, sample_shape: tuple = None,
checkpoint: str = None) -> BaseGraph:
checkpoint: str = None, **kwargs) -> BaseGraph:
"""
Load third party graph, metadata and checkpoint.

@@ -245,12 +246,19 @@ class Graph(BaseGraph, abc.ABC):
model_path (str): Graph or model file path.
sample_shape (tuple): Input shape of the model.
checkpoint (str): Checkpoint file path.
input_nodes (list[str]): list of input nodes' name
output_nodes (list[str]): list of output nodes' name

Returns:
cls, graph instance.
"""
src_graph = cls.load_graph(graph_path=model_path)
ckpt = cls.load_checkpoint(ckpt_path=checkpoint) if checkpoint else None
tf_input_nodes = kwargs.get('input_nodes')
tf_output_nodes = kwargs.get('output_nodes')
src_graph = cls.load_graph(graph_path=model_path,
input_nodes=tf_input_nodes,
output_nodes=tf_output_nodes)
ckpt = cls.load_checkpoint(
ckpt_path=checkpoint) if checkpoint else None

if ckpt is not None:
# Create an instance of TensorflowGraph.
@@ -258,7 +266,7 @@ class Graph(BaseGraph, abc.ABC):
checkpoint=ckpt)

# Create an instance of PyTorchGraph.
return cls(model=src_graph, sample_shape=sample_shape)
return cls(src_graph, sample_shape=sample_shape)


class GraphNode(abc.ABC):


+ 50
- 1
mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py View File

@@ -22,7 +22,7 @@ class PyTorchGraphParser(GraphParser):
"""Define pytorch graph parser."""

@classmethod
def parse(cls, model_path: str):
def parse(cls, model_path: str, **kwargs):
"""
Parser pytorch graph.

@@ -61,3 +61,52 @@ class PyTorchGraphParser(GraphParser):
raise Exception(error_msg)

return model


class TFGraphParser(GraphParser):
    """Define TF graph parser."""

    @classmethod
    def parse(cls, model_path: str, **kwargs):
        """
        Parse TF Computational Graph File (.pb)

        Args:
            model_path (str): Model file path.
            input_nodes (str): Input nodes of the tf graph (via kwargs).
            output_nodes (str): Output nodes of the tf graph (via kwargs).

        Returns:
            object, ONNX model.

        Raises:
            FileNotFoundError: If `model_path` does not exist.
            ModuleNotFoundError: If model scripts cannot be imported.
            Exception: For any other failure while loading the model.
        """

        # Imported lazily so that tf2onnx/onnx are only required when a tf
        # model is actually converted.
        from .onnx_utils import convert_tf_graph_to_onnx

        tf_input_nodes = kwargs.get('input_nodes')
        tf_output_nodes = kwargs.get('output_nodes')
        if not os.path.exists(model_path):
            error = FileNotFoundError("`model_path` must be assigned with "
                                      "an existed file path.")
            log.error(str(error))
            log.exception(error)
            raise error

        try:
            model = convert_tf_graph_to_onnx(model_path,
                                             model_inputs=tf_input_nodes,
                                             model_outputs=tf_output_nodes,
                                             )  # need pass more args

        except ModuleNotFoundError:
            error_msg = \
                "Cannot find model scripts in system path, " \
                "set `--project_path` to the path of model scripts folder correctly."
            error = ModuleNotFoundError(error_msg)
            log.error(error_msg)
            log.exception(error)
            raise error
        except Exception as e:
            # NOTE(review): broad catch deliberately converts any load failure
            # into a generic Exception with a user-facing hint; the original
            # traceback is preserved in the log via log.exception.
            error_msg = "Error occurs in loading model, make sure model.pb correct."
            log.error(error_msg)
            log.exception(e)
            raise Exception(error_msg)

        return model

+ 207
- 0
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py View File

@@ -0,0 +1,207 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define ONNX graph."""
from typing import Dict, NoReturn

from mindinsight.mindconverter.common.log import logger as log
from .base import Graph
from .input_node import InputNode
from .onnx_graph_node import OnnxGraphNode
from .graph_parser import TFGraphParser
from .onnx_utils import OnnxDataLoader


NONE_SCOPE_OP = {
"onnx::Add": "Add",
"onnx::Flatten": "Flatten",
"onnx::Concat": "Concat",
"onnx::Squeeze": "Squeeze",
"onnx::Unsqueeze": "Unsqueeze",
}


def normalize_node_name(node):
    """
    Rename the node name by removing :0

    Args:
        node (Node, str): ONNX node instance or node name string.

    Returns:
        str, normalized node name.
    """
    # Accept either a raw name or a node object exposing `.name`, then drop
    # the tensorflow-style ":<output-index>" suffix.
    raw_name = node if isinstance(node, str) else node.name
    return raw_name.split(':')[0]


class OnnxGraph(Graph):
    """
    Define ONNX graph.

    Args:
        model (onnx.ModelProto): Onnx defined model proto.
        sample_shape (tuple): Input shape of the model.
    """

    def __init__(self, model, sample_shape: tuple = None):
        super(OnnxGraph, self).__init__(model=model)

        # Build the node collection and edges immediately on construction.
        self.build(sample_shape)

    def _extract_shape(self, shape):
        """
        Extract shape from string-type shape.

        Any malformed component (empty, non-digit after cleanup) makes the
        whole shape unusable, so an empty list is returned in that case.

        Args:
            shape (str): Shape value in string-type.

        Returns:
            list, shape.
        """
        if "," not in shape:
            return []

        shape_arr = []
        for s in shape.split(","):
            s = s.strip()
            if not s:
                return []
            if ":" in s:
                # Keep only the dimension value before a ':' annotation.
                s = s.split(":")[0]
            # '!' markers are layout annotations, not part of the number.
            s = s.replace("!", "")
            if not s.isdigit():
                return []
            shape_arr.append(int(s))
        return shape_arr

    def _build_connection(self, src, tgt) -> NoReturn:
        """
        Build connection between source node and target node.

        Args:
            src (str): Source node name.
            tgt (str): Target node name.
        """
        # If src and tgt are the same node, src not in node_collection or
        # tgt not in node_collection, then skip this edge.
        src = normalize_node_name(src)
        tgt = normalize_node_name(tgt)
        if src == tgt or src not in self._nodes_collection or tgt not in self._nodes_collection:
            if src.split(':')[0] not in self._nodes_collection:
                log.warning(
                    "Graph construct a self-loop node %s. Ignored.", src)
            return
        # NOTE(review): src/tgt are already normalized above, so the extra
        # split(':')[0] calls below are redundant but harmless.
        if tgt not in self._nodes_collection[src.split(':')[0]].successor_nodes:
            self._nodes_collection[src.split(
                ':')[0]].successor_nodes.append(tgt)
        if src not in self._nodes_collection[tgt].precursor_nodes:
            self._nodes_collection[tgt.split(
                ':')[0]].precursor_nodes.append(src)

    def build(self, input_shape=None):
        """
        Build graph tree.

        Args:
            input_shape (tuple): Input shape of model. Default: None
        """
        model_data = OnnxDataLoader(self.model, graph_input_shape=input_shape)
        # Imported here to avoid a circular import with sub_graph_searcher.
        from ..sub_graph_searcher import generate_scope_name
        scope_name_list = generate_scope_name(model_data)

        self._shape_dict = model_data.normalize_dict_key(
            model_data.node_output_shape_dict)
        # NOTE(review): scope names are matched to nodes positionally; this
        # assumes generate_scope_name preserves nodes_dict iteration order.
        for ind, (node_name, node) in enumerate(model_data.nodes_dict.items()):
            node_weight = {}
            node.scope_name = scope_name_list[ind]
            inputs = node.input_name_list
            # check each input from node or tensors
            for i in inputs:
                if i in model_data.tensor_name_set:
                    # Inputs that are initializer tensors become node weights.
                    tensor = model_data.tensors_dict[i]
                    t_name = tensor.name
                    t_value = tensor.to_array()
                    node_weight[t_name] = t_value
            self._nodes_collection[node_name] = OnnxGraphNode(
                node, node_weight)
            self._nodes_record[node_name] = node_name

            for node_input in node.input_name_list:
                self._build_connection(node_input, node_name)

        super(OnnxGraph, self).build(input_shape=input_shape)
        self._collect_input_shape_of_each_node(
            input_shape)  # diff than pyTorch

    def _collect_input_shape_of_each_node(self, input_shape):
        """
        Collect input tensor shape of each node.

        Args:
            input_shape (tuple): Input shape.
        """
        input_node = InputNode(input_shape)
        input_node_name = "{}InputNode"
        for node_name, node in self._nodes_collection.items():
            if node_name in self._input_nodes:
                # NOTE(review): the synthetic input-node name is formatted
                # from input_node.scope_name BEFORE set_scope_name is called
                # below -- confirm this ordering is intentional.
                ipt_nd_name = input_node_name.format(input_node.scope_name)
                input_node.set_scope_name(node.scope_name)
                node.precursor_nodes.insert(0, ipt_nd_name)
                input_node.set_successor_nodes(node_name)
                self._shape_dict[ipt_nd_name] = input_node.output_shape

            ipt_shape = []
            for p_nd in node.precursor_nodes:
                shp = self._shape_dict.get(p_nd)
                ipt_shape.append(tuple(shp))

            # Single-precursor nodes store a bare shape tuple; multi-input
            # nodes store a list of shapes.
            self._input_shape[node_name] = ipt_shape[0] if len(
                ipt_shape) == 1 else ipt_shape

    def sub_graph_merging(self):
        # Sub-graph merging is handled elsewhere for onnx-sourced graphs.
        raise NotImplementedError()

    @staticmethod
    def load_checkpoint(ckpt_path: str) -> Dict:
        # Checkpoints are not applicable to the tf->onnx path.
        raise NotImplementedError()

    @staticmethod
    def load_metadata(**kwargs):
        # Metadata loading is not applicable to the tf->onnx path.
        raise NotImplementedError()

    @staticmethod
    def load_graph(graph_path: str, **kwargs):
        """
        Load graph.

        Note:
            The input/output nodes are optional for
            tf saved model format. But required for .pb & .ckpt

        Args:
            graph_path (str): Graph path.
            input_nodes (str): Input nodes of tf graph (via kwargs).
            output_nodes (str): Output nodes of tf graph (via kwargs).

        Returns:
            object, ONNX model.
        """
        tf_input_nodes = kwargs.get('input_nodes')
        tf_output_nodes = kwargs.get('output_nodes')
        onnx_model = TFGraphParser.parse(graph_path,
                                         input_nodes=tf_input_nodes,
                                         output_nodes=tf_output_nodes)
        return onnx_model

+ 378
- 0
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph_node.py View File

@@ -0,0 +1,378 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define ONNX graph node."""

from copy import deepcopy

from mindinsight.mindconverter.common.exceptions import NodeInputTypeNotSupport

from .base import GraphNode
from ..constant import NodeType, SEPARATOR_IN_SCOPE, \
    SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, SEPARATOR_IN_ONNX_OP, \
    InputType
from ..mapper.base import Mapper


class OnnxGraphNode(GraphNode):
    """
    ONNX Graph Node.

    Wraps a single `OnnxNode` together with everything needed to emit the
    corresponding MindSpore statements: raw op params, scope/variable names,
    and the node's weight tensors.

    Args:
        node (OnnxNode): OnnxNode Object.
        weight (dict): Dictionary records weight and bias.
    """
    # Class-level defaults; the froze_* methods set instance attributes of
    # the same name, so freezing is tracked per instance.
    _type_frozen = False
    _module_name_frozen = False

    def __init__(self, node=None, weight=None):
        super(OnnxGraphNode, self).__init__(node=node)
        # Raw onnx attributes of the node, keyed by attribute name.
        self._op_params = self._get_raw_params(node.raw_node) if node else None
        # Canonical op identifier, e.g. "onnx::Conv".
        self._op_name = "onnx::" + node.op_type if node else None
        self._scope_name = node.scope_name if node else None
        # Output variable name; assigned later in to_code().
        self._opt_var_name = None
        self._variable_name = self._extract_var_name(self._scope_name)
        self._module_name = None
        # {tensor_name: np.ndarray} of the node's weight/bias tensors.
        self._weight = weight

    def clear_args_of_declaration(self):
        """Clear `self._args_in_code`."""
        self._args_in_code = dict()

    def _get_arg_name(self, arg):
        """
        Get arg name, suffixed with this node's variable name.

        Args:
            arg (str): Base arg name.

        Returns:
            str, arg name in function or class declaration.
        """
        return f"{arg}_{self._variable_name}"

    @property
    def hash_key(self):
        """
        Return unique hash key of current node.

        Returns:
            str, hash key.
        """
        # Plain operation nodes derive their key from the onnx op name;
        # class/func/module nodes keep whatever was set via the setter.
        if self._node_type not in {NodeType.CLASS.value,
                                   NodeType.FUNC.value,
                                   NodeType.MODULE.value}:
            self._hash_key = self._op_name.lower()
        return self._hash_key

    @hash_key.setter
    def hash_key(self, h):
        """
        Setter of hash key.

        Args:
            h (str): Key.
        """
        self._hash_key = h

    @property
    def variable_name(self):
        """
        Variable name.

        Returns:
            str, variable name declared in init.
        """
        return self._variable_name

    @variable_name.setter
    def variable_name(self, v):
        """
        Setter of variable name.

        Args:
            v (str): Variable name.
        """
        self._variable_name = v

    @property
    def module_name(self):
        """
        Module name.

        Returns:
            str, module name.
        """
        # Until the module name is frozen, fall back to the node's tag.
        if not self._module_name_frozen:
            module_name = self.tag
            return module_name

        return self._module_name

    def _froze_module_name(self, m):
        """
        Once module_name is set, then it's unchangeable.

        Args:
            m (str): Module name.
        """
        if not self._module_name_frozen:
            self._module_name = m
            self._module_name_frozen = True

    @property
    def op_name(self):
        """
        Op name in onnx.

        Returns:
            str, op name
        """
        return self._op_name

    @property
    def real_name(self):
        # Intentionally returns None for ONNX nodes.
        return

    def add_input_and_output_shape(self, input_shape, output_shape):
        """
        Add the node input and output shape.

        Args:
            input_shape (tuple): Input tensor shape.
            output_shape (tuple): Output tensor shape.
        """
        self._ipt_shape = input_shape
        self._opt_shape = output_shape

    def _add_tensor_args_to_code(self, op_name: str, t_identifier: str, declare, args):
        """
        Add nn used tensors to args in init and construct blocks.

        Args:
            op_name (str): Add the tensor to args if the current node has this
                op_name.
            t_identifier (str): The unique string appeared in the target tensor
                name.
            declare (Union[str, list]): Declare statement(s) generated in to_code().
            args (str): Input args statement generated in to_code().

        Returns:
            declare_list list, multiple declare statements.
            input_args str, input args statement with the tensor appended.
        """
        # Pass through unchanged when the op doesn't match.
        if not self._op_name == op_name:
            return declare, args
        declare_list = []
        tensor = None
        # find target tensor by substring match on its name
        for t_name, t_value in self._weight.items():
            if t_identifier in t_name:
                tensor = t_value
                break
        if tensor is None:
            return declare, args
        declare_list.append(declare)
        # NOTE(review): only the tensor's SHAPE is used — the generated code
        # fills the tensor with random values, not the exported weights.
        declare_t = f"self.{self._variable_name}_w = Tensor(" \
                    f"np.random.uniform(0, 1, {str(tensor.shape)}), mindspore.float32)"
        declare_list.append(declare_t)
        args += f", self.{self._variable_name}_w"
        return declare_list, args

    def to_code(self, ipt_args_in_construct: str, output_var: str):
        """
        Generate statements.

        Args:
            ipt_args_in_construct (str): Args of input.
            output_var (str): Output variable name in construct.

        Returns:
            Union[str, str], declare in init and call in construct.
        """
        operator = self.op_in_ms or self.module_name
        self._opt_var_name = output_var

        args = self.args_in_code
        settings = self.settings_in_code
        # Unconverted operations keep their shapes in the args so the
        # generated code documents what the op must accept/produce.
        if self._node_type == NodeType.OPERATION.value and not self.convert_successful():
            args.update({"input_shape": self.input_shape,
                         "output_shape": self.output_shape})

        if self._node_type == NodeType.OPERATION.value:
            expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}"
                              for k, v in args.items()])
            ipt_args_settings_in_construct = \
                self._generate_ipt_args_settings_in_construct(
                    ipt_args_in_construct,
                    settings)
        else:
            # For module, class or func nodes there is no var suffix to strip
            # and settings do not apply.
            expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}"
                              for k, v in args.items()])
            ipt_args_settings_in_construct = ipt_args_in_construct
        declare = f"self.{self._variable_name} = {operator}({expr})"

        # Extra Tensor generator for nn.MatMul
        declare, ipt_args_settings_in_construct = self._add_tensor_args_to_code(
            'onnx::MatMul', 'MatMul', declare, ipt_args_settings_in_construct)

        # Extra Tensor generator for BiasAdd.
        # NOTE(review): this call still matches op 'onnx::MatMul' (with the
        # 'BiasAdd' tensor identifier) — confirm it shouldn't be a BiasAdd op.
        declare, ipt_args_settings_in_construct = self._add_tensor_args_to_code(
            'onnx::MatMul', 'BiasAdd', declare, ipt_args_settings_in_construct)

        call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_settings_in_construct})"

        return declare, call

    @staticmethod
    def _generate_ipt_args_settings_in_construct(ipt_args_in_construct, settings):
        """
        Generate input with args and settings in construct.

        Args:
            ipt_args_in_construct(str): Input args in construct.
            settings(dict): Settings in operator.

        Returns:
            str, args of each node in generated construct statement.

        Raises:
            NodeInputTypeNotSupport: If settings carry an unknown input_type.
        """
        if settings.get('input_type'):
            input_type = settings['input_type']
            if input_type == InputType.TENSOR.value:
                ipt_args_settings_in_construct = ipt_args_in_construct
            elif input_type == InputType.LIST.value:
                # Wrap args into a single tuple argument.
                ipt_args_settings_in_construct = f"({ipt_args_in_construct})"
            else:
                raise NodeInputTypeNotSupport(
                    f"Input type[{input_type}] is not supported now.")
        else:
            ipt_args_settings_in_construct = ipt_args_in_construct

        # Append any extra setting values after the input args.
        if settings.get('values'):
            settings_value = settings['values']
            if settings_value:
                settings_in_construct = ', '.join(
                    [f"{setting_val}" for _, setting_val in settings_value.items()])
                ipt_args_settings_in_construct = ', '.join(
                    (ipt_args_settings_in_construct, settings_in_construct))

        return ipt_args_settings_in_construct

    def to_ir(self):
        """No need to implement for now."""
        raise NotImplementedError

    def _get_raw_params(self, node):
        """
        Get params in onnx.
        Note: parameters are attributes in node.

        Args:
            node (onnx.NodeProto): Onnx defined node proto.

        Returns:
            dict, raw params.
        """
        # Local import keeps onnx an optional dependency at module load time.
        import onnx

        raw_params = dict()

        if not node:
            return raw_params

        for attribute in node.attribute:
            name = attribute.name
            value = onnx.helper.get_attribute_value(attribute)
            raw_params[name] = value

        return raw_params

    def replace_with_arg(self, src_arg, tgt_arg):
        """
        Replace actual parameter with formal parameter.

        Args:
            src_arg (str): Original arg name.
            tgt_arg (str): Target arg name.

        """
        self._args_in_code[src_arg] = tgt_arg

    @staticmethod
    def _extract_var_name(scope_name: str):
        """Extract variable name from scope name (last segment, lowercased)."""
        if not scope_name:
            return None
        var = scope_name.split(SEPARATOR_IN_SCOPE)[-1].lower()
        # Brackets in the scope name become id separators, e.g. "conv[3]" -> "conv_3".
        var = var.replace(LEFT_BUCKET, SEPARATOR_BTW_NAME_AND_ID).replace(
            RIGHT_BUCKET, "")
        return var

    def param_transform(self, mapper: Mapper):
        """
        Transform ONNX operator params into MindSpore format.

        Args:
            mapper (Mapper): Mapper of params.

        Returns:
            tuple, (op_in_ms, params_in_ms, settings_in_ms) for operation
            nodes; (None, None) for other node types.
        """
        # NOTE(review): non-operation nodes return a 2-tuple while operation
        # nodes return a 3-tuple — confirm callers handle both arities.
        if self._node_type != NodeType.OPERATION.value:
            args = deepcopy(self._args_in_code)
            self._args_in_code = dict()
            for arg, value in args.items():
                self._args_in_code[self._get_arg_name(arg)] = value
            return None, None

        # Transform at most once; later calls reuse the cached results.
        if not self.transformed:
            _, _, _ = super(OnnxGraphNode, self).param_transform(mapper)

            for arg, value in self._params_in_ms.items():
                self._args_in_code[self._get_arg_name(arg)] = value

            for arg, value in self._settings_in_ms.items():
                self._settings_in_code[arg] = value

            self.transformed = True

        return self._op_in_ms, self._params_in_ms, self._settings_in_ms

    def froze_node_type_and_module_name(self, node_type, module_name):
        """
        Froze node type and module name.

        After node_type is frozen, then the `module_name`
        will be affected when `node_type` is `class`.
        Thus, this line must be placed before `nd_inst.data.module_name`.

        Args:
            module_name: Modified module name.
            node_type (str): Node type, class or func.

        """
        if not self._type_frozen:
            self._node_type = node_type
            self._type_frozen = True

        if not self._module_name_frozen:
            self._froze_module_name(module_name)

    def convert_successful(self):
        """
        Whether convert successfully.

        Returns:
            bool, true or false.
        """
        # Converted ops have a concrete MindSpore name without the
        # "onnx::" separator left in it.
        if self._op_in_ms and SEPARATOR_IN_ONNX_OP not in self._op_in_ms:
            return True
        return False

+ 40
- 21
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py View File

@@ -15,7 +15,9 @@
"""Define ONNX related operations."""
import re
import abc
from importlib import import_module
from collections import OrderedDict
from typing import Union

from mindinsight.mindconverter.common.log import logger as log

@@ -39,10 +41,15 @@ def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=None
Returns:
onnx.ModelProto, onnx defined model proto.
"""
import tensorflow as tf
from tf2onnx.tfonnx import process_tf_graph
from tf2onnx import constants, utils, optimizer
from tf2onnx import tf_loader
tf = import_module('tensorflow')
tf2onnx = import_module("tf2onnx")
tfonnx = getattr(tf2onnx, "tfonnx")
process_tf_graph = getattr(tfonnx, "process_tf_graph")
constants = getattr(tf2onnx, "constants")
utils = getattr(tf2onnx, "utils")
optimizer = getattr(tf2onnx, "optimizer")
tf_loader = getattr(tf2onnx, "tf_loader")

target = ",".join(constants.DEFAULT_TARGET)
shape_override = None

@@ -80,13 +87,9 @@ class OnnxTensor:
"""
Define Onnx Tensor structure for convenience.

Note:
parameter from_nodes and to_nodes.

Args:
raw_tensor (onnx.TensorProto): onnx.TensorProto instance.
"""
import onnx

def __init__(self, raw_tensor):
self.raw_tensor = raw_tensor
@@ -97,7 +100,8 @@ class OnnxTensor:
self.to_nodes = []

def to_array(self):
"""Convert binary data to np.array"""
onnx = import_module("onnx")
# Convert binary data to np.array
return onnx.numpy_helper.to_array(self.raw_tensor)


@@ -136,7 +140,6 @@ class ParamsAttribute:
for attribute in attrs:
self.attribute_name_list.append(attribute.name)
type_num = attribute.type

# get attribute value by determining its type
# Can Convert to np.array if needed
if type_num == ONNX_TYPE_INTS:
@@ -219,8 +222,8 @@ class OnnxNode(BaseNode):
self.raw_node = raw_node
self.params = ParamsAttribute(raw_node.attribute, raw_node)
self.scope_name = None
self.input_name_list = raw_node.input
self.output_name_list = raw_node.output
self.input_name_list = getattr(raw_node, 'input')
self.output_name_list = getattr(raw_node, 'output')


class OnnxDataLoader:
@@ -238,11 +241,11 @@ class OnnxDataLoader:
Default: True
"""

def __init__(self, onnx_model, infer_shape=True):
def __init__(self, onnx_model, graph_input_shape: Union[tuple, list] = None, infer_shape=True):
self.model = onnx_model
self.graph = onnx_model.graph
self.nodes = onnx_model.graph.node
self.graph_input_shape = graph_input_shape
# args for init
self._is_infer_shape = infer_shape

@@ -251,9 +254,7 @@ class OnnxDataLoader:

self.nodes_dict = OrderedDict() # {node_name: OnnxNode} NO INPUT NODE
self.tensors_dict = {} # {tensor_name: OnnxTensor}
self.weight_dict = {} # {tensor_name: OnnxTensor} NOT USED
self.bias_dict = {} # {tensor_name: OnnxTensor} NOT USED
# {node_name : (type, dim)} NO INPUT & OUTPUT NODE!
# {node_name : (type, dim)} NO INPUT & OUTPUT NODE
self.value_info_dict = {}

self.tensor_name_set = set() # [str]
@@ -265,6 +266,20 @@ class OnnxDataLoader:
def _check_initialization(self):
"""Define conditions checked before init."""
if all([self.model, self.graph, self.nodes]):
if self.graph_input_shape is None: # do not check
return True
onnx = import_module("onnx")
# check input shape eligible
input_node = getattr(self.graph, 'input')[0]
type_str = onnx.helper.printable_type(input_node.type)
regex = r".*(unk.+)x(?P<h>\d+)x(?P<w>\d+)x(?P<c>\d+)"
match = re.match(regex, type_str)
h = int(match.group('h'))
w = int(match.group('w'))
c = int(match.group('c'))
if [h, w, c] != list(self.graph_input_shape)[1:4]:
raise ValueError(
f"Shape given should be (N, {h}, {w}, {c}) but got {self.graph_input_shape}")
return True
return False

@@ -276,12 +291,12 @@ class OnnxDataLoader:
The method will be replaced by self-implemented
in future development.
"""
import onnx
onnx = import_module("onnx")
self.inferred_model = onnx.shape_inference.infer_shapes(self.model)

def _parse_value_info(self): # no input node & output node
"""Parse onnx defined value_info class attributes."""
import onnx
"""Parse onnx defined value_info class attribtues"""
onnx = import_module("onnx")

def _parse_value_info_re(i):
"""
@@ -341,7 +356,11 @@ class OnnxDataLoader:
# replace unknown shape by '-1'
for s in shape_list:
if 'unk' in s:
s = '-1'
if self.graph_input_shape is not None:
s = self.graph_input_shape[0]
else:
s = '1'

# convert str to int
s = int(s)
lst.append(s)


Loading…
Cancel
Save