
!917 Refactor GraphNode and optimize code generation.

From: @liuchongming74
tags/v1.1.0
mindspore-ci-bot committed on Gitee 5 years ago
parent commit 42486f3f9a
32 changed files with 584 additions and 566 deletions
  1. +15 -0 mindinsight/mindconverter/graph_based_converter/common/__init__.py
  2. +218 -0 mindinsight/mindconverter/graph_based_converter/common/code_fragment.py
  3. +29 -0 mindinsight/mindconverter/graph_based_converter/common/utils.py
  4. +5 -4 mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py
  5. +62 -55 mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py
  6. +4 -0 mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py
  7. +5 -5 mindinsight/mindconverter/graph_based_converter/mapper/base.py
  8. +34 -0 mindinsight/mindconverter/graph_based_converter/mapper/gen_setting.py
  9. +2 -1 mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py
  10. +3 -1 mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py
  11. +2 -1 mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py
  12. +2 -1 mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/flatten_mapper.py
  13. +3 -3 mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/global_pool_mapper.py
  14. +10 -1 mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py
  15. +2 -1 mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py
  16. +2 -1 mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py
  17. +2 -1 mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/relu_mapper.py
  18. +2 -1 mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/softmax_mapper.py
  19. +10 -1 mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py
  20. +2 -1 mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/concat_mapper.py
  21. +2 -1 mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reduce_mean_mapper.py
  22. +2 -1 mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/transpose_mapper.py
  23. +80 -102 mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
  24. +0 -1 mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py
  25. +5 -21 mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py
  26. +0 -1 mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py
  27. +35 -180 mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph_node.py
  28. +3 -3 mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py
  29. +2 -2 mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py
  30. +19 -156 mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py
  31. +1 -1 tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_hierarchical_tree.py
  32. +21 -19 tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py

+15 -0 mindinsight/mindconverter/graph_based_converter/common/__init__.py

@@ -0,0 +1,15 @@
# Copyright 2020 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common instance and utils of graph based converter."""

+218 -0 mindinsight/mindconverter/graph_based_converter/common/code_fragment.py

@@ -0,0 +1,218 @@
# Copyright 2020 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define CodeLine object."""
import abc


class TrainableParams:
"""Trainable parameters."""

def __init__(self, shape, dtype, reference):
self.param_name = None
self.shape = shape
self.dtype = dtype
self.reference = reference # Weight name in global npy.


class CodeSetting:
"""Code generation settings."""

def __init__(self):
self.output_vars_suffix = []
self.operation_input_type = None # Construct input type, tensor or list.
self.operation_extra_input = dict() # `values` in original setting dict.
self.operation_extra_tensor = None # For `MatMul`, `BiasAdd` op, need a tensor


class Fragment(abc.ABC):
"""
Define comment attributes of code generation.

Args:
operation (str): Operation name in MindSpore.
actual_args (dict): Actual arg values.
settings (namedTuple): Code generation setting.
"""

def __init__(self, operation, actual_args, input_shape, output_shape, settings=None):
self._operation = operation
self._input_shape = input_shape
self._output_shape = output_shape
self._declared_variable_name = None
self._output_var_name = list() # Output variable name(could be multi-opt).
self._operation_inputs = list() # Index indices the order of input.
self._operation_extra_inputs = settings
self._code_setting = settings
self._formal_args_list = dict()
self._actual_args_list = actual_args # Key is the param_key, value is the corresponding value.
self._node_type = ""

@property
def code_setting(self):
return self._code_setting

@property
def node_type(self):
"""Node type getter."""
return self._node_type

@node_type.setter
def node_type(self, t):
"""Node type setter."""
self._node_type = t

@property
def operation_extra_inputs(self):
"""Getter of extra operation inputs."""
return self._operation_extra_inputs

@property
def declared_var_name(self):
"""Declared variable name getter."""
return self._declared_variable_name

@declared_var_name.setter
def declared_var_name(self, var):
"""Setter of declared variable name."""
self._declared_variable_name = var

@property
def output_var_name(self) -> str:
"""Getter of output variable name."""
return ", ".join(self._output_var_name)

@output_var_name.setter
def output_var_name(self, opt_vars):
"""
Output variable name setter.

Args:
opt_vars (list[str]): Output variable name.
"""
self._output_var_name = opt_vars

@property
def operation_inputs(self):
"""
Operation getter.

Returns:
list[Fragment], list of inputs.
"""
return self._operation_inputs

def update_operation_inputs(self, ipt):
"""
Update operation inputs.

Args:
ipt (Fragment): Where input comes from.
"""
self._operation_inputs.append(ipt)

@property
def operation(self):
"""
Operation getter.

Returns:
str, operation name to be initialized.
"""
return self._operation

@operation.setter
def operation(self, op: str):
"""
Operation setter.

Args:
op (str): Operation name.
"""
self._operation = op

@property
def actual_args(self) -> dict:
"""Getter of actual args."""
return self._actual_args_list

@property
def formal_args(self) -> dict:
"""Get formal args."""
return self._formal_args_list

def update_formal_args(self, formal_args: dict):
"""
Update formal args.

Args:
formal_args (dict): To be updated args.
"""
return self._formal_args_list.update(formal_args)

@property
def input_shape(self):
return self._input_shape

@property
def output_shape(self):
return self._output_shape


class CodeFragment(Fragment):
"""
Manage the variables related with code generation.

For single operation type node, the variables in `CodeLine` stands for:
```python
class Module(nn.Cell):
def __init__ (self, ...):
super(Module, self).__init__()
self.<CodeLine.declared_variable_name> = <CodeLine.operation>(<CodeLine.scalar_args>,
<CodeLine.init_trainable_params>)
self.<CodeLine.trainable_params[k].param_name> = Tensor(<CodeLine.trainable_params[k].shape>,
dtype=<CodeLine._trainable_params[k].dtype>)

def construct(self, x, ...):
<CodeLine.output_var_name> = self.<CodeLine.declared_variable_name>(<CodeLine.operation_inputs>)
...
return output
```

Args:
operation (str): Operation name in MindSpore.
actual_args (dict): Actual arg values.
settings (namedTuple): Code generation setting.
"""

def __init__(self, operation, actual_args, settings, input_shape, output_shape,
trainable_params=None):
super(CodeFragment, self).__init__(operation=operation, actual_args=actual_args,
input_shape=input_shape, output_shape=output_shape,
settings=settings)
self._trainable_params = dict() # External weights, like Matmul.
self._init_trainable_params = trainable_params # Can put into operation init method, like Conv2d.

@property
def trainable_params(self):
return self._trainable_params


class ModuleFragment(Fragment):
"""Manage module type code variables."""

def __init__(self, operation, actual_args, settings, input_shape, output_shape):
super(ModuleFragment, self).__init__(operation=operation, actual_args=actual_args,
input_shape=input_shape, output_shape=output_shape,
settings=settings)
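
Note: a minimal usage sketch of the new `CodeFragment` API (not part of the commit); the operation, shapes, and names below are illustrative.

```python
from mindinsight.mindconverter.graph_based_converter.common.code_fragment import CodeFragment

# Build a fragment the way param_transform() does for a converted node.
fragment = CodeFragment(operation="nn.Conv2d",
                        actual_args={"in_channels": 3, "out_channels": 64, "kernel_size": 7},
                        settings=None,
                        input_shape=(1, 3, 224, 224),
                        output_shape=(1, 64, 112, 112))
fragment.declared_var_name = "conv2d_0"      # Variable declared in __init__.
fragment.output_var_name = ["conv2d_0_opt"]  # Setter takes a list; the getter joins with ", ".

assert fragment.output_var_name == "conv2d_0_opt"
assert fragment.actual_args["kernel_size"] == 7
```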

+29 -0 mindinsight/mindconverter/graph_based_converter/common/utils.py

@@ -0,0 +1,29 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define common utils."""
from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP


def is_converted(operation: str):
"""
Whether convert successful.

Args:
operation (str): Operation name.

Returns:
bool, true or false.
"""
return operation and SEPARATOR_IN_ONNX_OP not in operation
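
Behavior sketch (assuming `SEPARATOR_IN_ONNX_OP` is the `::` seen in raw op names such as `onnx::Conv`):

```python
from mindinsight.mindconverter.graph_based_converter.common.utils import is_converted

assert is_converted("nn.Conv2d")       # Mapped to MindSpore -> converted.
assert not is_converted("onnx::Conv")  # Still carries the ONNX separator -> not converted.
assert not is_converted("")            # Empty operation -> not converted.
```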

+5 -4 mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py

@@ -14,6 +14,7 @@
# ==============================================================================
"""Hierarchical tree module."""
import re

from mindinsight.mindconverter.common.log import logger as log
from .hierarchical_tree import HierarchicalTree
from ..third_party_graph.onnx_graph_node import OnnxGraphNode
@@ -36,7 +37,6 @@ def _tf_model_node_name_reformat(node: OnnxGraphNode, node_name):
"""
scope_name = node.scope_name
new_name = None
parent = ""
regex = r"(?P<parent>.+/)(?P<op>\w+)"
match = re.match(regex, scope_name)
parent = match.group("parent")
@@ -74,12 +74,13 @@ class HierarchicalTreeFactory:
f"Cannot find {node_name}'s input shape."
log.error(err_msg)
if isinstance(node_inst, OnnxGraphNode):
node_name_with_scope = _tf_model_node_name_reformat(
node_inst, node_name)
node_name_with_scope = _tf_model_node_name_reformat(node_inst, node_name)
node_scope_name[node_name] = node_name_with_scope
node_name = node_name_with_scope

tree.insert(node_inst, node_name, node_input, node_output)
node_inst.add_input_and_output_shape(node_input, node_output)
tree.insert(node_inst, node_name)

if node_scope_name:
return tree, node_scope_name
return tree

+62 -55 mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py

@@ -25,17 +25,18 @@ from treelib import Tree, Node
from mindinsight.mindconverter.common.log import logger as log

from .name_mgr import ModuleNameMgr, GlobalVarNameMgr
from ..common.utils import is_converted
from ..mapper.base import Mapper
from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode
from ..third_party_graph.onnx_graph_node import OnnxGraphNode
from ..constant import SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT, CodeFormatConfig
from ..constant import SEPARATOR_IN_SCOPE
from ..constant import CodeFormatConfig
from ..constant import SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT
from ..constant import NEW_LINE, SECOND_LEVEL_INDENT
from ..constant import NodeType
from ..report_generator import ReportGenerator
from ...common.exceptions import NodeTypeNotSupport

GLOBAL_VAR_NAME_MGR = GlobalVarNameMgr()


class HierarchicalTree(Tree):
"""Define hierarchical tree."""
@@ -46,6 +47,8 @@ class HierarchicalTree(Tree):
_root_created = False
ROOT_LEVEL = 0

GLOBAL_VAR_NAME_MGR = GlobalVarNameMgr()

def __init__(self):
super(HierarchicalTree, self).__init__()
self._hierarchical_order = dict()
@@ -62,6 +65,7 @@ class HierarchicalTree(Tree):
self._module_vars = dict()
# scope name mapping record for easy node searching
self._scope_name_map = dict()
self.code_fragment_recorder = dict()

@property
def tree_identifier(self):
@@ -82,19 +86,15 @@ class HierarchicalTree(Tree):
return None
return self._nodes[nid]

def insert(self, node: Union[PyTorchGraphNode, OnnxGraphNode],
node_name: str, input_shape, output_shape):
def insert(self, node: Union[PyTorchGraphNode, OnnxGraphNode], node_name: str):
"""
Insert node into hierarchical tree.

Args:
node_name (str): Node name.
node (Union[PyTorchGraphNode, OnnxGraphNode]): Node to be inserted.
output_shape (tuple): Output tensor shape.
input_shape (tuple): Input tensor shape.

"""
node.add_input_and_output_shape(input_shape, output_shape)
scopes = node_name.split(SEPARATOR_IN_SCOPE)
for idx, scope in enumerate(scopes):
parent = SEPARATOR_IN_SCOPE.join(scopes[:idx])
@@ -125,10 +125,9 @@ class HierarchicalTree(Tree):
tgt_node.precursor_nodes = node.precursor_nodes
tgt_node.node_type = (NodeType.OPERATION if idx == len(scopes) - 1
else NodeType.MODULE).value
tgt_node.tag = scope.split(SEPARATOR_BTW_NAME_AND_ID)[0]
tgt_node.variable_name = self._get_var_name(identifier)
self.create_node(
tag=tgt_node.tag,
tag=scope.split(SEPARATOR_BTW_NAME_AND_ID)[0],
identifier=identifier,
parent=parent,
data=tgt_node
@@ -276,8 +275,7 @@ class HierarchicalTree(Tree):
node.data.replace_with_arg(arg, arg)
return node

@staticmethod
def _clear_unused_args(node, used_args):
def _clear_unused_args(self, node, used_args):
"""
Clear unused args.

@@ -290,7 +288,9 @@ class HierarchicalTree(Tree):
"""
args_in_code = list(node.data.args_in_code.keys())
for arg in args_in_code:
ori_arg = arg.replace(f"_{node.data.variable_name}", "")
ori_arg = arg.replace(
f"_{self.code_fragment_recorder[node.identifier].declared_var_name}", ""
)
if ori_arg not in used_args:
node.data.args_in_code.pop(arg)
return node
@@ -323,6 +323,8 @@ class HierarchicalTree(Tree):
# 1. Generate args for each node in this level.
if node.data.node_type == NodeType.MODULE.value:
self._create_module_args_and_vars(node, mapper)
if depth == depths[-1]:
self.code_fragment_recorder[node.identifier] = node.data.param_transform(mapper, "")

# Module merging based on all nodes.
self._module_merging()
@@ -345,30 +347,29 @@ class HierarchicalTree(Tree):
# then assign the created module name to current node,
# and delete unused args.
module_name = self._created_module[module_key]
nd_inst.data.froze_node_type_and_module_name(node_type,
module_name)
self.code_fragment_recorder[nd_inst.identifier].operation = module_name
self.code_fragment_recorder[nd_inst.identifier].node_type = node_type
self._preprocess_node_args(nd_inst, module_key)
continue

module_name = nd_inst.data.module_name
module_name = nd_inst.tag

if node_type == NodeType.CLASS.value:
module_name = f"{module_name[0].upper()}{module_name[1:]}"

# After node_type and module_name is frozen,
# then it's unchangeable.
module_name = self._module_mgr.get_name(module_name)
nd_inst.data.froze_node_type_and_module_name(node_type,
module_name)
self.code_fragment_recorder[nd_inst.identifier].operation = module_name
self.code_fragment_recorder[nd_inst.identifier].node_type = node_type

# 3. Pre-process node args.
nd_inst = self._preprocess_node_args(nd_inst, module_key)
# 4. Post-process child node args.
for _, scsr_nd_name in enumerate(nd_inst.successors(self.tree_identifier)):
self._postprocess_node_args(
self.get_node(scsr_nd_name), module_key)
self._postprocess_node_args(self.get_node(scsr_nd_name), module_key)
# 5. Generate code.
snippets.add(
func(nd_inst, nd_inst.data.module_name, module_key))
snippets.add(func(nd_inst, self.code_fragment_recorder[nd_inst.identifier].operation, module_key))

code_blocks.extend(snippets)

@@ -437,7 +438,7 @@ class HierarchicalTree(Tree):
module_list = []
for node_name in node.successors(self.tree_identifier):
c_nd = self.get_node(node_name)
operator = c_nd.data.op_in_ms or c_nd.data.module_name
operator = self.code_fragment_recorder[c_nd.identifier].operation

if c_nd.data.node_type != NodeType.OPERATION.value:
hash_key = c_nd.data.hash_key or self.hash_key(c_nd)
@@ -445,14 +446,16 @@ class HierarchicalTree(Tree):
operator = self._created_module[hash_key]

args = c_nd.data.args_in_code
if c_nd.data.node_type == NodeType.OPERATION.value and \
not c_nd.data.convert_successful():
if c_nd.data.node_type == NodeType.OPERATION.value and not is_converted(
self.code_fragment_recorder[c_nd.identifier].operation):
args.update({"input_shape": c_nd.data.input_shape,
"output_shape": c_nd.data.output_shape})

# Generate code statement.
expr = ", ".join([f"{k.replace(f'_{c_nd.data.variable_name}', '')}={v}"
for k, v in args.items()])
expr = ", ".join(
[f"{k.replace(f'_{self.code_fragment_recorder[c_nd.identifier].declared_var_name}', '')}={v}"
for k, v in args.items()]
)
code_line = f"{operator}({expr})"
module_list.append(code_line)

@@ -547,14 +550,16 @@ class HierarchicalTree(Tree):

if idx != 0:
# Get previous node output variable name.
ipt_args_in_construct = self._get_previous_opt_var(
cur_nd_inst, pre_nd_inst)
ipt_args_in_construct = self._get_previous_opt_var(cur_nd_inst, pre_nd_inst)
if idx != len(pre_nd_inst.successors(self.tree_identifier)) - 1:
# Set opt variable name.
opt_arg_in_construct = cur_nd_inst.data.opt_var_name
opt_arg_in_construct = f"{self.code_fragment_recorder[cur_nd_inst.identifier].declared_var_name}_opt"

declare, call = cur_nd_inst.data.to_code(ipt_args_in_construct=ipt_args_in_construct,
output_var=opt_arg_in_construct)
variable_name=self.code_fragment_recorder[
cur_nd_inst.identifier].declared_var_name,
output_var=opt_arg_in_construct,
code_fragment=self.code_fragment_recorder[cur_nd_inst.identifier])

return declare, call

@@ -588,7 +593,9 @@ class HierarchicalTree(Tree):
if e not in pre_nd.successors(self.tree_identifier):
while True:
if p_nd.identifier in pre_nd.successors(self.tree_identifier):
ipt_lst.append(p_nd.data.opt_var_name)
ipt_lst.append(
f"{self.code_fragment_recorder[p_nd.identifier].declared_var_name}_opt"
)
break
pre_nd_name = p_nd.predecessor(self.tree_identifier)
if not pre_nd_name:
@@ -597,7 +604,9 @@ class HierarchicalTree(Tree):
p_nd = self.get_node(pre_nd_name)
continue

ipt_lst.append(p_nd.data.opt_var_name)
ipt_lst.append(
f"{self.code_fragment_recorder[p_nd.identifier].declared_var_name}_opt"
)
return ipt_lst

def _get_previous_opt_var(self, cur_nd, pre_nd):
@@ -619,12 +628,11 @@ class HierarchicalTree(Tree):
cur_nd = self.get_node(p_nd[0])
return ", ".join(self._find_all_previous_opt_var_(cur_nd, pre_nd))

def hash_key(self, node, depth: int = 0):
def hash_key(self, node):
"""
Generate hash key for each node.

Args:
depth (int): Recursion depth.
node (Node): Node.

Returns:
@@ -633,13 +641,17 @@ class HierarchicalTree(Tree):
scsr_topo_order = []
for s in node.successors(self.tree_identifier):
cur_nd = self.get_node(s)
if cur_nd.data.hash_key:
scsr_topo_order.append(f"{cur_nd.data.hash_key}[{depth}]")
continue
if cur_nd.data.node_type in {NodeType.MODULE.value,
NodeType.FUNC.value,
NodeType.CLASS.value}:
scsr_topo_order.append(self.hash_key(cur_nd, depth + 1))
if cur_nd.data.hash_key:
scsr_topo_order.append(f"({cur_nd.data.hash_key})")
continue

raise ValueError("Current node doesn't have hash key.")

if cur_nd.data.hash_key:
scsr_topo_order.append(cur_nd.data.hash_key)
continue
unique_key = "->".join(scsr_topo_order)
node.data.hash_key = unique_key
@@ -675,12 +687,11 @@ class HierarchicalTree(Tree):
"""
# All args and value pair in current node module.
module_args = dict()
module_settings = dict()
module_key = self.hash_key(node)
created = False

if module_key not in self._vars_mgr_in_module:
self._vars_mgr_in_module[module_key] = GLOBAL_VAR_NAME_MGR
self._vars_mgr_in_module[module_key] = self.GLOBAL_VAR_NAME_MGR
self._module_vars[module_key] = []
else:
created = True
@@ -688,33 +699,29 @@ class HierarchicalTree(Tree):
# Sub-modules in the module could have arg name conflicts.
for idx, successor_name in enumerate(node.successors(self.tree_identifier)):
nd_inst = self.get_node(successor_name)
# Generate variable name here, then
# to generate args.
# Generation of params must come after variable assignment.
if created:
nd_inst.data.variable_name = self._module_vars[module_key][idx]
variable_name = self._module_vars[module_key][idx]
else:
variable_name = nd_inst.data.op_name or nd_inst.data.module_name
variable_name = self._vars_mgr_in_module[module_key].get_name(
variable_name)
nd_inst.data.variable_name = variable_name
variable_name = nd_inst.data.op_name or nd_inst.tag
variable_name = self._vars_mgr_in_module[module_key].get_name(variable_name)

# Generation of params must come after variable assignment.
nd_inst.data.param_transform(mapper)
code_fragment = nd_inst.data.param_transform(mapper, variable_name)
code_fragment.declared_var_name = variable_name
self.code_fragment_recorder[nd_inst.identifier] = code_fragment

module_args.update(nd_inst.data.args_in_code)
module_settings.update(nd_inst.data.settings_in_code)

if not created:
self._module_vars[module_key].append(
nd_inst.data.variable_name)
self._module_vars[module_key].append(variable_name)

node.data.args_in_code = module_args

# Collect module args of `module_key`.
if module_key not in self._merged_module:
self._merged_module[module_key] = [node.data.args_in_code]
self._merged_module[module_key] = [deepcopy(node.data.args_in_code)]
else:
self._merged_module[module_key].append(node.data.args_in_code)
self._merged_module[module_key].append(deepcopy(node.data.args_in_code))

@staticmethod
def _create_operation_args(node, mapper):
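
Note on the `hash_key` change above: nested module keys are now parenthesized before joining, so structurally different trees no longer collide. A small illustration (keys are fabricated):

```python
# Leaf operations contribute their own hash keys; a nested module's key is
# wrapped in parentheses, so "A->(B->C)" != "A->B->C".
child_keys = ["nn.Conv2d", "(nn.BatchNorm2d->nn.ReLU)"]
module_key = "->".join(child_keys)  # "nn.Conv2d->(nn.BatchNorm2d->nn.ReLU)"
```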


+4 -0 mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py

@@ -63,6 +63,10 @@ START_IDX = 0
class GlobalVarNameMgr:
    """Global variable name mgr."""

    def __init__(self):
        global_op_namespace.clear()
        global_var_namespace.clear()

    @staticmethod
    def _get_name(name):
        """Deal with op name."""


+5 -5 mindinsight/mindconverter/graph_based_converter/mapper/base.py

@@ -87,7 +87,7 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
module_name = TABLE.get(op_name)

if not module_name:
return None, dict(), dict()
return None, dict(), None, dict()

pos = module_name.rfind(".")
try:
@@ -101,7 +101,7 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
# If mapper can not be found, then skip it.
err_msg = f"Converting {op_name} failed, see {str(e)}"
log.error(err_msg)
return None, dict(), dict()
return None, dict(), None, dict()

try:
converter_name = op_name_converter(
@@ -110,13 +110,13 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
converted_weights = weights_converter(
weights=weights) if weights else dict()
converted_params.update(converted_weights)
converted_settings = settings_converter(params=params)
converted_settings = settings_converter(params=params, weights=weights)
except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e:
err_msg = f"Converting {op_name} failed, see {str(e)}"
log.error(err_msg)
return None, dict(), dict()
return None, dict(), None, dict()

return converter_name, converted_params, converted_settings
return converter_name, converted_params, converted_settings, converted_weights

@staticmethod
def _operation_name_in_ms(*args, **kwargs):
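
The `convert` contract widens from a 3-tuple to a 4-tuple, adding the converted weights; every failure path now returns `(None, dict(), None, dict())`. A caller-side sketch (the op name and params are fabricated):

```python
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper

op, params, settings, weights = ONNXToMindSporeMapper.convert(
    op_name="onnx::Conv", params={"kernel_shape": [3, 3]}, weights=dict())
if op is None:
    pass  # Unsupported op or conversion failure: fall back to the raw operation.
```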


+34 -0 mindinsight/mindconverter/graph_based_converter/mapper/gen_setting.py

@@ -0,0 +1,34 @@
# Copyright 2020 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operation mapping setting."""
from collections import namedtuple
import numpy as np

from mindinsight.mindconverter.graph_based_converter.constant import InputType

Tensor = namedtuple("Tensor", ["shape", "dtype", "reference"])

Setting = namedtuple("Setting", ["opt_vars_suffix",
"op_ipt_type",
"op_extra_input",
"op_extra_tensor"])
Setting.__new__.__defaults__ = ("_opt", InputType.TENSOR.value, dict(), None)


def get_dtype(tensor: np.ndarray):
"""Get tensor dtype."""
if tensor.dtype == np.float16:
return "mindspore.float16"
return "mindspore.float32"

+2 -1 mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py

@@ -14,6 +14,7 @@
# ==============================================================================
"""Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting


class BatchNormMapper(ONNXToMindSporeMapper):
@@ -39,4 +40,4 @@ class BatchNormMapper(ONNXToMindSporeMapper):

@staticmethod
def _convert_settings(**kwargs):
return dict()
return Setting()

+3 -1 mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py

@@ -16,6 +16,7 @@
import re
import numpy as np
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting


def _convert_padding(**kwargs):
@@ -35,6 +36,7 @@ def _convert_padding(**kwargs):

class ConvMapper(ONNXToMindSporeMapper):
"""Conv2d mapper."""

@staticmethod
def convert_params_torch(**kwargs):
"""Convert params from PyTorch to MindSpore"""
@@ -148,4 +150,4 @@ class ConvMapper(ONNXToMindSporeMapper):

@staticmethod
def _convert_settings(**kwargs):
return dict()
return Setting()

+2 -1 mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py

@@ -14,6 +14,7 @@
# ==============================================================================
"""Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting


class DenseMapper(ONNXToMindSporeMapper):
@@ -41,4 +42,4 @@ class DenseMapper(ONNXToMindSporeMapper):

@staticmethod
def _convert_settings(**kwargs):
return dict()
return Setting()

+2 -1 mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/flatten_mapper.py

@@ -14,6 +14,7 @@
# ==============================================================================
"""Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting


class FlattenMapper(ONNXToMindSporeMapper):
@@ -33,4 +34,4 @@ class FlattenMapper(ONNXToMindSporeMapper):

@staticmethod
def _convert_settings(**kwargs):
return dict()
return Setting()

+3 -3 mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/global_pool_mapper.py

@@ -14,6 +14,7 @@
# ==============================================================================
"""Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting


class GlobalPoolMapper(ONNXToMindSporeMapper):
@@ -25,8 +26,7 @@ class GlobalPoolMapper(ONNXToMindSporeMapper):
op_name = 'nn.AvgPool{}d'
else:
op_name = 'nn.MaxPool{}d'
dim = 1 if len(kwargs['params']['input_shape']) == 3\
else 2
dim = 1 if len(kwargs['params']['input_shape']) == 3 else 2
return op_name.format(dim)

@staticmethod
@@ -49,4 +49,4 @@ class GlobalPoolMapper(ONNXToMindSporeMapper):

@staticmethod
def _convert_settings(**kwargs):
return dict()
return Setting()

+10 -1 mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py

@@ -14,6 +14,7 @@
# ==============================================================================
"""Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting, Tensor, get_dtype


class MatMulMapper(ONNXToMindSporeMapper):
@@ -33,4 +34,12 @@ class MatMulMapper(ONNXToMindSporeMapper):

@staticmethod
def _convert_settings(**kwargs):
return dict()
weights = kwargs.get("weights")
if not weights:
return Setting()
tensor, ref = None, ""
for t_name, t_value in weights.items():
tensor = t_value
ref = t_name
return Setting(op_extra_tensor=Tensor(shape=tensor.shape,
dtype=get_dtype(tensor), reference=ref))

+2 -1 mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py

@@ -14,6 +14,7 @@
# ==============================================================================
"""Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting


def _padding_format_convert(padding: list):
@@ -77,4 +78,4 @@ class PadMapper(ONNXToMindSporeMapper):

@staticmethod
def _convert_settings(**kwargs):
return dict()
return Setting()

+2 -1 mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py

@@ -14,6 +14,7 @@
# ==============================================================================
"""Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting


class PoolMapper(ONNXToMindSporeMapper):
@@ -49,4 +50,4 @@ class PoolMapper(ONNXToMindSporeMapper):

@staticmethod
def _convert_settings(**kwargs):
return dict()
return Setting()

+2 -1 mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/relu_mapper.py

@@ -14,6 +14,7 @@
# ==============================================================================
"""Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting


class ReLUMapper(ONNXToMindSporeMapper):
@@ -45,4 +46,4 @@ class ReLUMapper(ONNXToMindSporeMapper):

@staticmethod
def _convert_settings(**kwargs):
return dict()
return Setting()

+2 -1 mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/softmax_mapper.py

@@ -14,6 +14,7 @@
# ==============================================================================
"""Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting


class SoftmaxMapper(ONNXToMindSporeMapper):
@@ -37,4 +38,4 @@ class SoftmaxMapper(ONNXToMindSporeMapper):

@staticmethod
def _convert_settings(**kwargs):
return dict()
return Setting()

+10 -1 mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py

@@ -14,6 +14,7 @@
# ==============================================================================
"""Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting, Tensor, get_dtype


class AddMapper(ONNXToMindSporeMapper):
@@ -33,4 +34,12 @@ class AddMapper(ONNXToMindSporeMapper):

@staticmethod
def _convert_settings(**kwargs):
return dict()
weights = kwargs.get("weights")
if not weights:
return Setting()
tensor, ref = None, ""
for t_name, t_value in weights.items():
tensor = t_value
ref = t_name
return Setting(op_extra_tensor=Tensor(shape=tensor.shape,
dtype=get_dtype(tensor), reference=ref))

+2 -1 mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/concat_mapper.py

@@ -15,6 +15,7 @@
"""Mapper module."""
from mindinsight.mindconverter.graph_based_converter.constant import InputType
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting


class ConcatMapper(ONNXToMindSporeMapper):
@@ -36,4 +37,4 @@ class ConcatMapper(ONNXToMindSporeMapper):
@staticmethod
def _convert_settings(**kwargs):
input_type = InputType.LIST.value
return {'input_type': input_type}
return Setting(op_ipt_type=input_type)

+2 -1 mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reduce_mean_mapper.py

@@ -14,6 +14,7 @@
# ==============================================================================
"""Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting


class ReduceMeanMapper(ONNXToMindSporeMapper):
@@ -40,4 +41,4 @@ class ReduceMeanMapper(ONNXToMindSporeMapper):
axis = params['axes'][0] if len(params['axes']) == 1 else tuple(params['axes'])
else:
axis = tuple()
return {'values': {'axis': axis}}
return Setting(op_extra_input={'axis': axis})

+2 -1 mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/transpose_mapper.py

@@ -14,6 +14,7 @@
# ==============================================================================
"""Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting


class TransposeMapper(ONNXToMindSporeMapper):
@@ -40,4 +41,4 @@ class TransposeMapper(ONNXToMindSporeMapper):
perm = tuple(perm)
converted_params['input_perm'] = perm

return {'values': converted_params}
return Setting(op_extra_input=converted_params)

+80 -102 mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py

@@ -15,10 +15,13 @@
"""Define graph entity."""
import abc
from collections import OrderedDict
from copy import deepcopy

from mindinsight.mindconverter.common.log import logger as log
from ..constant import SEPARATOR_IN_ONNX_OP
from ..common.code_fragment import CodeFragment
from ..constant import NodeType, InputType
from ..mapper.base import Mapper
from ...common.exceptions import NodeInputTypeNotSupport


class GraphParser(metaclass=abc.ABCMeta):
@@ -287,26 +290,10 @@ class GraphNode(abc.ABC):
self._op_params = dict()
self._scope_name = None
self._op_shape = None
# Operation in mindspore.
self._op_in_ms = None
# Params in mindspore.
self._params_in_ms = dict()
# Settings in mindspore.
self._settings_in_ms = dict()
# Node type of current node, e.g. class, module, operation.
self._node_type = None
# Tag name on tree.
self._tag_on_tree = None
# Function, class or operation needed args.
self._args_in_code = dict()
# Operation needed settings.
self._settings_in_code = dict()
# Variable name declared in init block.
self._variable_name = None
# Output variable name declared in construct block.
self._opt_var_name = None
# Function or class name in code.
self._module_name = None
# Unique key of node.
self._hash_key = None
# Input shape of current op.
@@ -317,37 +304,18 @@ class GraphNode(abc.ABC):
self._weight = None

@property
def opt_var_name(self):
def weight(self):
return self._weight

@staticmethod
def get_opt_var_name(variable_name):
"""
Output variable name.

Returns:
str, variable name.
"""
return f"{self.variable_name}_opt"

@opt_var_name.setter
def opt_var_name(self, v):
"""
Set variable name.

Args:
v (str): Name.

"""
self._opt_var_name = v

@property
def op_in_ms(self):
"""
Operation in mindspore.

Returns:
str, operation name.
"""
if self._op_in_ms and SEPARATOR_IN_ONNX_OP in self._op_in_ms:
return self._op_in_ms.replace(SEPARATOR_IN_ONNX_OP, ".")
return self._op_in_ms
return f"{variable_name}_opt"

@property
def args_in_code(self):
@@ -370,27 +338,6 @@ class GraphNode(abc.ABC):
"""
self._args_in_code = args

@property
def settings_in_code(self):
"""
Settings in code.

Returns:
dict, settings.
"""
return self._settings_in_code

@settings_in_code.setter
def settings_in_code(self, settings):
"""
Settings in code.

Args:
settings(dict): Settings.

"""
self._settings_in_code = settings

@property
def input_shape(self):
"""
@@ -411,16 +358,6 @@ class GraphNode(abc.ABC):
"""
return self._opt_shape

@property
def tag(self):
"""Tag on hierarchical tree."""
return self._tag_on_tree

@tag.setter
def tag(self, t):
"""Tag on hierarchical tree."""
self._tag_on_tree = t

def is_empty(self):
"""
Whether is empty.
@@ -536,7 +473,7 @@ class GraphNode(abc.ABC):
"""Replace actual parameter with formal parameter."""

@abc.abstractmethod
def _get_arg_name(self, arg):
def _get_arg_name(self, arg, variable_name):
"""Get arg name for func or class."""

@abc.abstractmethod
@@ -553,13 +490,8 @@ class GraphNode(abc.ABC):
def real_name(self, **kwargs):
"""Setter of `real_name`."""

@property
@abc.abstractmethod
def variable_name(self):
"""Getter of `variable_name`."""

@abc.abstractmethod
def to_code(self, ipt_args_in_construct: str, output_var: str):
def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: str, code_fragment):
"""Graph node to MindSpore code."""

@abc.abstractmethod
@@ -570,40 +502,86 @@ class GraphNode(abc.ABC):
def add_input_and_output_shape(self, input_shape, output_shape):
"""Add the node input shape."""

@abc.abstractmethod
def froze_node_type_and_module_name(self, node_type, module_name):
"""Make node_type can not be changed."""
@staticmethod
def _generate_ipt_args_settings_in_construct(ipt_args_in_construct, settings):
"""
Generate input with args and settings in construct.

@abc.abstractmethod
def convert_successful(self):
"""Whether convert successful."""
Args:
ipt_args_in_construct (str): Input args in construct.
settings (Setting): Settings in operator.

Returns:
str, args of each node in generated construct statement.
"""
if settings and settings.op_ipt_type:
input_type = settings.op_ipt_type
if input_type == InputType.TENSOR.value:
ipt_args_settings_in_construct = ipt_args_in_construct
elif input_type == InputType.LIST.value:
ipt_args_settings_in_construct = f"({ipt_args_in_construct})"
else:
raise NodeInputTypeNotSupport(f"Input type[{input_type}] is not supported now.")
else:
ipt_args_settings_in_construct = ipt_args_in_construct

if settings and settings.op_extra_input:
settings_value = settings.op_extra_input
if settings_value:
settings_in_construct = ', '.join([f"{setting_val}" for _, setting_val in settings_value.items()])
ipt_args_settings_in_construct = ', '.join((ipt_args_settings_in_construct, settings_in_construct))

def param_transform(self, mapper: Mapper):
return ipt_args_settings_in_construct

def param_transform(self, mapper: Mapper, variable_name):
"""
Transform param in pytorch operation into mindspore.
Transform param in PyTorch operation into MindSpore.

Args:
variable_name (str): Variable name.
mapper (ONNXToMindSporeMapper): Mapper between onnx operation
and mindspore.
and MindSpore.

Returns:
dict, transformed params.
"""
import copy
params = copy.deepcopy(self._op_params)
if self._node_type != NodeType.OPERATION.value:
args = deepcopy(self._args_in_code)
self._args_in_code = dict()
for arg, value in args.items():
self._args_in_code[self._get_arg_name(arg, variable_name)] = value
return CodeFragment(operation="", actual_args=args, settings=None,
input_shape=self.input_shape, output_shape=self.output_shape)

if self.transformed:
raise ValueError("Already transformed.")

params = deepcopy(self._op_params)
params.update({"input_shape": self.input_shape,
"output_shape": self.output_shape})

op_name_in_mindspore, ms_params, ms_settings = mapper.convert(op_name=self.op_name,
params=params,
weights=self._weight)
if op_name_in_mindspore:
self._op_in_ms = op_name_in_mindspore
self._params_in_ms = ms_params
self._settings_in_ms = ms_settings
ms_op, ms_params, ms_settings, ms_weights = mapper.convert(op_name=self.op_name,
params=params,
weights=self._weight)

if ms_op:
code_fragment = CodeFragment(operation=ms_op,
actual_args=ms_params,
settings=ms_settings,
input_shape=self.input_shape,
output_shape=self.output_shape,
trainable_params=ms_weights)
else:
self._op_in_ms = self._op_name
self._params_in_ms = self._op_params
self._settings_in_ms = dict()
code_fragment = CodeFragment(operation=self._op_name,
actual_args=self._op_params,
settings=None,
input_shape=self.input_shape,
output_shape=self.output_shape,
trainable_params=self._weight)

for arg, value in code_fragment.actual_args.items():
self._args_in_code[self._get_arg_name(arg, variable_name)] = value

self.transformed = True

return self._op_in_ms, self._params_in_ms, self._settings_in_ms
return code_fragment
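
Expected behavior of the lifted `_generate_ipt_args_settings_in_construct` helper (a sketch using the `Setting` tuple from gen_setting.py; the method is private, so this is for illustration only):

```python
from mindinsight.mindconverter.graph_based_converter.constant import InputType
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting
from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import GraphNode

# LIST input type wraps the args, so ops like Concat receive one tuple argument.
assert GraphNode._generate_ipt_args_settings_in_construct(
    "x_0, x_1", Setting(op_ipt_type=InputType.LIST.value)) == "(x_0, x_1)"

# Extra inputs (e.g. ReduceMean's axis) are appended after the tensor args.
assert GraphNode._generate_ipt_args_settings_in_construct(
    "x_0", Setting(op_extra_input={"axis": 1})) == "x_0, 1"
```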

+0 -1 mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py

@@ -38,7 +38,6 @@ class PyTorchGraphParser(GraphParser):
error = FileNotFoundError("`model_path` must be assigned with "
"an existed file path.")
log.error(str(error))
log.exception(error)
raise error

try:


+5 -21 mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py

@@ -21,24 +21,18 @@ from ..constant import SEPARATOR_IN_SCOPE, NodeType

class InputNode(GraphNode):
"""
Pytorch Input Node.
PyTorch Input Node.

Args:
input_shape: Input shape of module.

"""

def convert_successful(self):
"""
Whether convert successful.

Returns:
bool, true or false.
"""
return False
def _get_arg_name(self, arg, variable_name):
raise NotImplementedError()

def froze_node_type_and_module_name(self, node_type, module_name):
pass
def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: str, code_fragment):
raise NotImplementedError()

def _get_raw_params(self, node):
pass
@@ -56,9 +50,6 @@ class InputNode(GraphNode):
def replace_with_arg(self, src_arg, tgt_arg):
pass

def _get_arg_name(self, arg):
pass

def add_input_and_output_shape(self, input_shape, output_shape):
pass

@@ -116,15 +107,8 @@ class InputNode(GraphNode):
def real_name(self):
return

@property
def variable_name(self):
return

def to_ir(self):
"""
No need to implement for now.
"""
raise NotImplementedError()

def to_code(self, ipt_args_in_construct: str, output_var: str):
raise NotImplementedError()

+0 -1 mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py

@@ -22,7 +22,6 @@ from .onnx_graph_node import OnnxGraphNode
from .graph_parser import TFGraphParser
from .onnx_utils import OnnxDataLoader


NONE_SCOPE_OP = {
"onnx::Add": "Add",
"onnx::Flatten": "Flatten",


+35 -180 mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph_node.py

@@ -13,14 +13,13 @@
# limitations under the License.
# ==============================================================================
"""Define ONNX graph node."""
from importlib import import_module

from copy import deepcopy
from .base import GraphNode
from ..common.utils import is_converted

from ..constant import NodeType, SEPARATOR_IN_SCOPE, \
SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, SEPARATOR_IN_ONNX_OP, InputType
from ..mapper.base import Mapper
from ...common.exceptions import NodeInputTypeNotSupport
from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \
SEPARATOR_IN_ONNX_OP


class OnnxGraphNode(GraphNode):
@@ -39,16 +38,13 @@ class OnnxGraphNode(GraphNode):
self._op_params = self._get_raw_params(node.raw_node) if node else None
self._op_name = "onnx::" + node.op_type if node else None
self._scope_name = node.scope_name if node else None
self._opt_var_name = None
self._variable_name = self._extract_var_name(self._scope_name)
self._module_name = None
self._weight = weight

def clear_args_of_declaration(self):
"""Clear `self._args_in_code`."""
self._args_in_code = dict()

def _get_arg_name(self, arg):
def _get_arg_name(self, arg, variable_name):
"""
Get arg name.

@@ -58,7 +54,7 @@ class OnnxGraphNode(GraphNode):
Returns:
str, arg name in function or class declaration.
"""
return f"{arg}_{self._variable_name}"
return f"{arg}_{variable_name}"

@property
def hash_key(self):
@@ -84,51 +80,6 @@ class OnnxGraphNode(GraphNode):
"""
self._hash_key = h

@property
def variable_name(self):
"""
Variable name.

Returns:
str, variable name declared in init.
"""
return self._variable_name

@variable_name.setter
def variable_name(self, v):
"""
Setter of variable name.

Args:
v (str): Variable name.
"""
self._variable_name = v

@property
def module_name(self):
"""
Module name.

Returns:
str, module name.
"""
if not self._module_name_frozen:
module_name = self.tag
return module_name

return self._module_name

def _froze_module_name(self, m):
"""
Once module_name is set, then it's unchangeable.

Args:
m (str): Module name.
"""
if not self._module_name_frozen:
self._module_name = m
self._module_name_frozen = True

@property
def op_name(self):
"""
@@ -154,15 +105,13 @@ class OnnxGraphNode(GraphNode):
self._ipt_shape = input_shape
self._opt_shape = output_shape

def _add_tensor_args_to_code(self, op_name: str, t_identifier: str, declare, args):
def _add_tensor_args_to_code(self, op_name: str, settings, declare, args, variable_name):
"""
Add nn used tensors to args in init and construct blocks.

Args:
op_name (str): Add the tensor to args if the current node has this
op_name.
t_identifier (str): The unique string appeared in the target tensor
name.
op_name.
declare (str): Declare statement generated in to_code().
args (str): Args statement generated in to_code().

@@ -172,103 +121,68 @@ class OnnxGraphNode(GraphNode):
"""
if not self._op_name == op_name:
return declare, args
declare_list = []
tensor = None
# find target tensor
for t_name, t_value in self._weight.items():
if t_identifier in t_name:
tensor = t_value
break
if tensor is None:
if not settings or not settings.op_extra_tensor:
return declare, args
declare_list.append(declare)
declare_t = f"self.{self._variable_name}_w = Tensor(" \
f"np.random.uniform(0, 1, {str(tensor.shape)}), mindspore.float32)"
declare_list = [declare]
declare_t = f"self.{variable_name}_w = Tensor(" \
f"np.random.uniform(0, 1, {str(settings.op_extra_tensor.shape)}), " \
f"{settings.op_extra_tensor.dtype})"
declare_list.append(declare_t)
args += f", self.{self._variable_name}_w"
args += f", self.{variable_name}_w"
return declare_list, args

def to_code(self, ipt_args_in_construct: str, output_var: str):
def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: str,
code_fragment):
"""
Generate statements.

Args:
variable_name (str): Variable name.
ipt_args_in_construct (str): Args of input.
output_var (str): Output variable name in construct.
code_fragment (CodeFragment): CodeFragment instance.

Returns:
Union[str, str], declare in init and call in construct.
"""
operator = self.op_in_ms or self.module_name
self._opt_var_name = output_var
operator = code_fragment.operation

args = self.args_in_code
settings = self.settings_in_code
if self._node_type == NodeType.OPERATION.value and not self.convert_successful():
settings = code_fragment.code_setting

if self._node_type == NodeType.OPERATION.value and not is_converted(code_fragment.operation):
args.update({"input_shape": self.input_shape,
"output_shape": self.output_shape})

if self._node_type == NodeType.OPERATION.value:
expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}"
expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}"
for k, v in args.items()])
ipt_args_settings_in_construct = \
self._generate_ipt_args_settings_in_construct(
ipt_args_in_construct,
settings)
ipt_args_settings_in_construct = self._generate_ipt_args_settings_in_construct(
ipt_args_in_construct, settings)
else:
# When its type is module, class or func,
# it's not necessary to replace var.
expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}"
expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}"
for k, v in args.items()])
ipt_args_settings_in_construct = ipt_args_in_construct
declare = f"self.{self._variable_name} = {operator}({expr})"

if SEPARATOR_IN_ONNX_OP in operator:
operator = operator.replace(SEPARATOR_IN_ONNX_OP, ".")

declare = f"self.{variable_name} = {operator}({expr})"

# Extra Tensor generator for nn.MatMul
declare, ipt_args_settings_in_construct = self._add_tensor_args_to_code(
'onnx::MatMul', 'MatMul', declare, ipt_args_settings_in_construct)
'onnx::MatMul', settings, declare, ipt_args_settings_in_construct, variable_name)

# Extra Tensor generator for onnx::Add
declare, ipt_args_settings_in_construct = self._add_tensor_args_to_code(
'onnx::Add', 'BiasAdd', declare, ipt_args_settings_in_construct)
'onnx::Add', settings, declare, ipt_args_settings_in_construct, variable_name)

call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_settings_in_construct})"
call = f"{output_var} = self.{variable_name}({ipt_args_settings_in_construct})"

return declare, call

@staticmethod
def _generate_ipt_args_settings_in_construct(ipt_args_in_construct, settings):
"""
Generate input with args and settings in construct.

Args:
ipt_args_in_construct(str): Input args in construct.
settings(dict): Settings in operator.

Returns:
str, args of each node in generated construct statement.
"""
if settings.get('input_type'):
input_type = settings['input_type']
if input_type == InputType.TENSOR.value:
ipt_args_settings_in_construct = ipt_args_in_construct
elif input_type == InputType.LIST.value:
ipt_args_settings_in_construct = f"({ipt_args_in_construct})"
else:
raise NodeInputTypeNotSupport(
f"Input type[{input_type}] is not supported now.")
else:
ipt_args_settings_in_construct = ipt_args_in_construct

if settings.get('values'):
settings_value = settings['values']
if settings_value:
settings_in_construct = ', '.join(
[f"{setting_val}" for _, setting_val in settings_value.items()])
ipt_args_settings_in_construct = ', '.join(
(ipt_args_settings_in_construct, settings_in_construct))

return ipt_args_settings_in_construct

def to_ir(self):
"""No need to implement for now."""
raise NotImplementedError
@@ -284,7 +198,7 @@ class OnnxGraphNode(GraphNode):
Returns:
dict, raw params.
"""
import onnx
onnx = import_module("onnx")

raw_params = dict()

@@ -318,62 +232,3 @@ class OnnxGraphNode(GraphNode):
var = var.replace(LEFT_BUCKET, SEPARATOR_BTW_NAME_AND_ID).replace(
RIGHT_BUCKET, "")
return var

def param_transform(self, mapper: Mapper):
"""
Transform tensorflow params into mindspore.

Args:
mapper (Mapper): Mapper of params.

"""
if self._node_type != NodeType.OPERATION.value:
args = deepcopy(self._args_in_code)
self._args_in_code = dict()
for arg, value in args.items():
self._args_in_code[self._get_arg_name(arg)] = value
return None, None

if not self.transformed:
_, _, _ = super(OnnxGraphNode, self).param_transform(mapper)

for arg, value in self._params_in_ms.items():
self._args_in_code[self._get_arg_name(arg)] = value

for arg, value in self._settings_in_ms.items():
self._settings_in_code[arg] = value

self.transformed = True

return self._op_in_ms, self._params_in_ms, self._settings_in_ms

def froze_node_type_and_module_name(self, node_type, module_name):
"""
Froze node type and module name.

After node_type is frozen, then the `module_name`
will be affected when `node_type` is `class`.
Thus, this line must be placed before `nd_inst.data.module_name`.

Args:
module_name: Modified module name.
node_type (str): Node type, class of func.

"""
if not self._type_frozen:
self._node_type = node_type
self._type_frozen = True

if not self._module_name_frozen:
self._froze_module_name(module_name)

def convert_successful(self):
"""
Whether convert successfully.

Returns:
bool, true or false.
"""
if self._op_in_ms and SEPARATOR_IN_ONNX_OP not in self._op_in_ms:
return True
return False
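
For reference, a trimmed re-enactment (not the real method) of how the reworked `to_code()` assembles its two statements from the recorded fragment; all names and values are made up:

```python
variable_name = "conv2d_0"
operator = "nn.Conv2d"
args = {"in_channels": 3, "out_channels": 64, "kernel_size": 7}
output_var = "conv2d_0_opt"

# Strip the "_<variable_name>" suffix that module-level args carry.
expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}" for k, v in args.items()])
declare = f"self.{variable_name} = {operator}({expr})"
call = f"{output_var} = self.{variable_name}(x)"

# declare == "self.conv2d_0 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7)"
# call    == "conv2d_0_opt = self.conv2d_0(x)"
```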

+3 -3 mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py

@@ -87,7 +87,8 @@ def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=None
inputs_as_nchw=None
)
opt_map = getattr(optimizer.back_to_back_optimizer, '_func_map')
opt_map.pop(('Conv', 'BatchNormalization'))
if ('Conv', 'BatchNormalization') in opt_map:
opt_map.pop(('Conv', 'BatchNormalization'))
onnx_graph = optimizer.optimize_graph(g)
model_proto = onnx_graph.make_model("converted from {}".format(model_path))

@@ -228,8 +229,7 @@ class OnnxNode(BaseNode):
"""

def __init__(self, raw_node):
super(OnnxNode, self).__init__(
node_name=raw_node.name, op_type=raw_node.op_type)
super(OnnxNode, self).__init__(node_name=raw_node.name, op_type=raw_node.op_type)
self.raw_node = raw_node
self.params = ParamsAttribute(raw_node.attribute, raw_node)
self.scope_name = None


+2 -2 mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py

@@ -99,8 +99,8 @@ class PyTorchGraph(Graph):

for item in input_shape:
if not isinstance(item, int):
err_msg = f"Only support model with one input now, " \
f"and each shape value in `input_shape` should be int."
err_msg = "Only support model with one input now, " \
"and each shape value in `input_shape` should be int."
log.error(err_msg)
raise ValueError(err_msg)



+19 -156 mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py

@@ -13,14 +13,11 @@
# limitations under the License.
# ==============================================================================
"""Define PyTorch graph node."""
from copy import deepcopy

from .base import GraphNode
from ..common.utils import is_converted

from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \
SEPARATOR_IN_ONNX_OP, InputType
from ..mapper.base import Mapper
from ...common.exceptions import NodeInputTypeNotSupport
SEPARATOR_IN_ONNX_OP


class PyTorchGraphNode(GraphNode):
@@ -40,9 +37,6 @@ class PyTorchGraphNode(GraphNode):
self._op_params = self._get_raw_params(node)
self._op_name = node.kind() if node else None
self._scope_name = node.scopeName() if node else None
self._opt_var_name = None
self._variable_name = self._extract_var_name(self._scope_name)
self._module_name = None
self._weight = weight

def clear_args_of_declaration(self):
@@ -51,7 +45,7 @@ class PyTorchGraphNode(GraphNode):
"""
self._args_in_code = dict()

def _get_arg_name(self, arg):
def _get_arg_name(self, arg, variable_name):
"""
Get arg name.

@@ -61,7 +55,7 @@ class PyTorchGraphNode(GraphNode):
Returns:
str, arg name in function or class declaration.
"""
return f"{arg}_{self._variable_name}"
return f"{arg}_{variable_name}"

@property
def hash_key(self):
@@ -88,53 +82,6 @@ class PyTorchGraphNode(GraphNode):
"""
self._hash_key = h

@property
def variable_name(self):
"""
Variable name.

Returns:
str, variable name declared in init.
"""
return self._variable_name

@variable_name.setter
def variable_name(self, v):
"""
Setter of variable name.

Args:
v (str): Variable name.

"""
self._variable_name = v

@property
def module_name(self):
"""
Module name.

Returns:
str, module name.
"""
if not self._module_name_frozen:
module_name = self.tag
return module_name

return self._module_name

def _froze_module_name(self, m):
"""
Once module_name is set, then it's unchangeable.

Args:
m (str): Module name.

"""
if not self._module_name_frozen:
self._module_name = m
self._module_name_frozen = True

@property
def op_name(self):
"""
@@ -172,72 +119,47 @@ class PyTorchGraphNode(GraphNode):
self._ipt_shape = input_shape
self._opt_shape = output_shape

def to_code(self, ipt_args_in_construct: str, output_var: str):
def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: str, code_fragment):
"""
Generate statements.

Args:
variable_name (str): Variable name.
ipt_args_in_construct (str): Args of input.
output_var (str): Output variable name in construct.
code_fragment (CodeFragment): CodeFragment instance.

Returns:
Union[str, str], declare in init and call in construct.
"""
operator = self.op_in_ms or self.module_name
self._opt_var_name = output_var
operator = code_fragment.operation

args = self.args_in_code
settings = self.settings_in_code
settings = code_fragment.code_setting

if self._node_type == NodeType.OPERATION.value and not self.convert_successful():
if self._node_type == NodeType.OPERATION.value and not is_converted(code_fragment.operation):
args.update({"input_shape": self.input_shape,
"output_shape": self.output_shape})

if self._node_type == NodeType.OPERATION.value:
expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}"
expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}"
for k, v in args.items()])
ipt_args_settings_in_construct = self._generate_ipt_args_settings_in_construct(ipt_args_in_construct,
settings)
ipt_args_settings_in_construct = self._generate_ipt_args_settings_in_construct(
ipt_args_in_construct, settings)
else:
# When its type is module, class or func,
# it's not necessary to replace var.
expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}"
expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}"
for k, v in args.items()])
ipt_args_settings_in_construct = ipt_args_in_construct

declare = f"self.{self._variable_name} = {operator}({expr})"
call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_settings_in_construct})"

return declare, call

@staticmethod
def _generate_ipt_args_settings_in_construct(ipt_args_in_construct, settings):
"""
Generate input with args and settings in construct.

Args:
ipt_args_in_construct (str): Input args in construct.
settings (dict): Settings in operator.

"""
if settings.get('input_type'):
input_type = settings['input_type']
if input_type == InputType.TENSOR.value:
ipt_args_settings_in_construct = ipt_args_in_construct
elif input_type == InputType.LIST.value:
ipt_args_settings_in_construct = f"({ipt_args_in_construct})"
else:
raise NodeInputTypeNotSupport(f"Input type[{input_type}] is not supported now.")
else:
ipt_args_settings_in_construct = ipt_args_in_construct
if SEPARATOR_IN_ONNX_OP in operator:
operator = operator.replace(SEPARATOR_IN_ONNX_OP, ".")

if settings.get('values'):
settings_value = settings['values']
if settings_value:
settings_in_construct = ', '.join([f"{setting_val}" for _, setting_val in settings_value.items()])
ipt_args_settings_in_construct = ', '.join((ipt_args_settings_in_construct, settings_in_construct))
declare = f"self.{variable_name} = {operator}({expr})"
call = f"{output_var} = self.{variable_name}({ipt_args_settings_in_construct})"

return ipt_args_settings_in_construct
return declare, call
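A sketch of how a caller might drive the refactored signature; `fragment` and the literal values below are illustrative, not the actual generator code:

# Assumes fragment.operation is, e.g., "nn.Conv2d" and fragment.code_setting
# holds the converted settings for this node.
declare, call = node.to_code(
    ipt_args_in_construct="x",
    variable_name="conv1",
    output_var="opt_conv1",
    code_fragment=fragment,
)
# declare -> "self.conv1 = nn.Conv2d(...)"
# call    -> "opt_conv1 = self.conv1(x)"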

def to_ir(self):
"""
@@ -288,62 +210,3 @@ class PyTorchGraphNode(GraphNode):
var = var.replace(LEFT_BUCKET, SEPARATOR_BTW_NAME_AND_ID).replace(
RIGHT_BUCKET, "")
return var
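Assuming LEFT_BUCKET/RIGHT_BUCKET are "[" and "]" and SEPARATOR_BTW_NAME_AND_ID is "_" (constants not shown in this hunk), the replacement above behaves like:

# "Conv2d[conv1]" -> "Conv2d_conv1"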

def param_transform(self, mapper: Mapper):
"""
Transform torch params into mindspore.

Args:
mapper (Mapper): Mapper of params.

"""
if self._node_type != NodeType.OPERATION.value:
args = deepcopy(self._args_in_code)
self._args_in_code = dict()
for arg, value in args.items():
self._args_in_code[self._get_arg_name(arg)] = value
return None, None, None

if not self.transformed:
_, _, _ = super(PyTorchGraphNode, self).param_transform(mapper)

for arg, value in self._params_in_ms.items():
self._args_in_code[self._get_arg_name(arg)] = value

for arg, value in self._settings_in_ms.items():
self._settings_in_code[arg] = value

self.transformed = True

return self._op_in_ms, self._params_in_ms, self._settings_in_ms

def froze_node_type_and_module_name(self, node_type, module_name):
"""
Freeze node type and module name.

Once node_type is frozen, `module_name` is affected when `node_type` is `class`;
thus, this call must be placed before accessing `nd_inst.data.module_name`.

Args:
node_type (str): Node type, class or func.
module_name (str): Modified module name.

"""
if not self._type_frozen:
self._node_type = node_type
self._type_frozen = True

if not self._module_name_frozen:
self._froze_module_name(module_name)

def convert_successful(self):
"""
Whether the node was converted successfully.

Returns:
bool, true or false.
"""
if self._op_in_ms and SEPARATOR_IN_ONNX_OP not in self._op_in_ms:
return True
return False
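The removed check is superseded by `is_converted` from `common/utils.py`; a plausible sketch that mirrors the deleted logic (the real implementation may differ):

def is_converted(operation: str) -> bool:
    # An operation counts as converted once it no longer carries the
    # ONNX separator (e.g., "onnx::Conv" is unconverted, "nn.Conv2d" is).
    return bool(operation) and SEPARATOR_IN_ONNX_OP not in operation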

+ 1
- 1
tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_hierarchical_tree.py View File

@@ -42,7 +42,7 @@ class TestHierarchicalTree:
get_raw_params.return_value = []
tree = HierarchicalTree()
pt_node = PyTorchGraphNode()
tree.insert(pt_node, 'ResNet', (1, 3, 224, 224), (1, 64, 112, 112))
tree.insert(pt_node, 'ResNet')
assert tree.root == 'ResNet'

def test_remove(self):


+ 21
- 19
tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py View File

@@ -17,11 +17,13 @@ import numpy as np
import pytest

from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting
from tests.utils import mindspore


class TestMappers:
"""Test Mappers."""

@pytest.mark.parametrize('params', [{
'input': {'op_name': 'onnx::Conv',
'params': {'dilations': [1, 1],
@@ -38,7 +40,7 @@ class TestMappers:
'pad_mode': '\"pad\"',
'dilation': (1, 1),
'group': 1},
'converted_settings': dict()}
'converted_settings': Setting()}
}, {
'input': {'op_name': 'onnx::Conv',
'params': {'dilations': [1, 1],
@@ -55,7 +57,7 @@ class TestMappers:
'pad_mode': '\"valid\"',
'dilation': (1, 1),
'group': 1},
'converted_settings': dict()}
'converted_settings': Setting()}
}, {
'input': {'op_name': 'onnx::Gemm',
'params': dict(),
@@ -65,7 +67,7 @@ class TestMappers:
'converted_params': {'in_channels': 3,
'out_channels': 10,
'has_bias': True},
'converted_settings': dict()}
'converted_settings': Setting()}
}, {
'input': {'op_name': 'onnx::BatchNormalization',
'params': {'epsilon': 1e-5,
@@ -76,14 +78,14 @@ class TestMappers:
'converted_params': {'num_features': 6,
'eps': 1e-5,
'momentum': 0.9},
'converted_settings': dict()}
'converted_settings': Setting()}
}, {
'input': {'op_name': 'onnx::Relu',
'params': dict(),
'weights': dict()},
'expected_output': {'converter_name': 'nn.ReLU',
'converted_params': dict(),
'converted_settings': dict()}
'converted_settings': Setting()}
}, {
'input': {'op_name': 'onnx::MaxPool',
'params': {'kernel_shape': [3, 3],
@@ -94,7 +96,7 @@ class TestMappers:
'converted_params': {'kernel_size': (3, 3),
'stride': (2, 2),
'pad_mode': '"same"'},
'converted_settings': dict()}
'converted_settings': Setting()}
}, {
'input': {'op_name': 'onnx::AveragePool',
'params': {'kernel_shape': [3, 3],
@@ -105,7 +107,7 @@ class TestMappers:
'converted_params': {'kernel_size': (3, 3),
'stride': (2, 2),
'pad_mode': '"same"'},
'converted_settings': dict()}
'converted_settings': Setting()}
}, {
'input': {'op_name': 'onnx::GlobalAveragePool',
'params': {'input_shape': (1, 3, 10, 10),
@@ -113,21 +115,21 @@ class TestMappers:
'weights': ''},
'expected_output': {'converter_name': 'nn.AvgPool2d',
'converted_params': {'kernel_size': (10, 10)},
'converted_settings': dict()}
'converted_settings': Setting()}
}, {
'input': {'op_name': 'onnx::Flatten',
'params': dict(),
'weights': dict()},
'expected_output': {'converter_name': 'nn.Flatten',
'converted_params': dict(),
'converted_settings': dict()}
'converted_settings': Setting()}
}, {
'input': {'op_name': 'onnx::Add',
'params': dict(),
'weights': dict()},
'expected_output': {'converter_name': 'P.TensorAdd',
'converted_params': dict(),
'converted_settings': dict()}
'converted_settings': Setting()}
}, {
'input': {'op_name': 'onnx::Pad',
'params': {'pads': [0, 1, 2, 3],
@@ -137,7 +139,7 @@ class TestMappers:
'expected_output': {'converter_name': 'nn.Pad',
'converted_params': {'paddings': ((0, 2), (1, 3)),
'mode': '\"CONSTANT\"'},
'converted_settings': dict()}
'converted_settings': Setting()}
}, {
'input': {'op_name': 'onnx::Pad',
'params': {'pads': [0, 1, 2, 3],
@@ -146,7 +148,7 @@ class TestMappers:
'expected_output': {'converter_name': 'nn.Pad',
'converted_params': {'paddings': ((0, 2), (1, 3)),
'mode': '\"REFLECT\"'},
'converted_settings': dict()}
'converted_settings': Setting()}
}, {
'input': {'op_name': 'onnx::Pad',
'params': {'pads': [0, 1, 2, 3],
@@ -156,7 +158,7 @@ class TestMappers:
'expected_output': {'converter_name': 'nn.Pad',
'converted_params': {'paddings': ((0, 2), (1, 3)),
'mode': '{UNSUPPORTED: value is NOT 0}\"CONSTANT\"'},
'converted_settings': dict()}
'converted_settings': Setting()}
}, {
'input': {'op_name': 'onnx::Pad',
'params': {'pads': [0, 1, 2, 3],
@@ -165,7 +167,7 @@ class TestMappers:
'expected_output': {'converter_name': 'nn.Pad',
'converted_params': {'paddings': ((0, 2), (1, 3)),
'mode': '{UNSUPPORTED: \"edge\"}\"UNKNOWN\"'},
'converted_settings': dict()}
'converted_settings': Setting()}
}, {
'input': {'op_name': 'onnx::ReduceMean',
'params': {'keepdims': 0,
@@ -196,14 +198,14 @@ class TestMappers:
'weights': dict()},
'expected_output': {'converter_name': 'nn.ReLU6',
'converted_params': dict(),
'converted_settings': dict()}
'converted_settings': Setting()}
}, {
'input': {'op_name': 'onnx::Clip',
'params': dict(),
'weights': dict()},
'expected_output': {'converter_name': 'nn.ReLU',
'converted_params': dict(),
'converted_settings': dict()}
'converted_settings': Setting()}
}, {
'input': {'op_name': 'onnx::Clip',
'params': {'max': 3,
@@ -211,13 +213,13 @@ class TestMappers:
'weights': dict()},
'expected_output': {'converter_name': None,
'converted_params': dict(),
'converted_settings': dict()}
'converted_settings': Setting()}
}])
def test_mapper(self, params):
"""Test mapper function."""
mapper = ONNXToMindSporeMapper()
converter_name, converted_params, converted_settings = \
converter_name, converted_params, converted_settings, _ = \
mapper.convert(params['input']['op_name'], params['input']['params'], params['input']['weights'])
assert params['expected_output']['converter_name'] == converter_name
assert params['expected_output']['converted_params'] == converted_params
assert params['expected_output']['converted_settings'] == converted_settings
assert isinstance(converted_settings, Setting)
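Under the new contract, `convert` returns a fourth value that this test discards; a standalone usage sketch mirroring the Relu case above (the fourth element's content comes from gen_setting.py and is assumed opaque here):

name, params, settings, _ = ONNXToMindSporeMapper().convert(
    "onnx::Relu", dict(), dict())
# name -> "nn.ReLU"; params -> {}; settings is a Setting instance.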
