From 2efb39bfee510e389b19d5e28e861be5517db0d8 Mon Sep 17 00:00:00 2001 From: liuchongming Date: Wed, 25 Nov 2020 12:34:16 +0800 Subject: [PATCH] Refactor graph node and hierarchical tree. 1. Remove code generation related variables from graph node. 2. Create CodeFragment instance to manage code generation. 3. Add code gen setting instance to mapper. --- .../graph_based_converter/common/__init__.py | 15 ++ .../common/code_fragment.py | 218 ++++++++++++++++++ .../graph_based_converter/common/utils.py | 29 +++ .../hierarchical_tree/__init__.py | 9 +- .../hierarchical_tree/hierarchical_tree.py | 117 +++++----- .../hierarchical_tree/name_mgr.py | 4 + .../graph_based_converter/mapper/base.py | 10 +- .../mapper/gen_setting.py | 34 +++ .../mapper/impl/nn/batch_norm_mapper.py | 3 +- .../mapper/impl/nn/conv_mapper.py | 4 +- .../mapper/impl/nn/dense_mapper.py | 3 +- .../mapper/impl/nn/flatten_mapper.py | 3 +- .../mapper/impl/nn/global_pool_mapper.py | 6 +- .../mapper/impl/nn/mat_mul_mapper.py | 11 +- .../mapper/impl/nn/pad_mapper.py | 3 +- .../mapper/impl/nn/pool_mapper.py | 3 +- .../mapper/impl/nn/relu_mapper.py | 3 +- .../mapper/impl/nn/softmax_mapper.py | 3 +- .../mapper/impl/ops/add_mapper.py | 11 +- .../mapper/impl/ops/concat_mapper.py | 3 +- .../mapper/impl/ops/reduce_mean_mapper.py | 3 +- .../mapper/impl/ops/transpose_mapper.py | 3 +- .../third_party_graph/base.py | 182 +++++++-------- .../third_party_graph/graph_parser.py | 1 - .../third_party_graph/input_node.py | 26 +-- .../third_party_graph/onnx_graph.py | 1 - .../third_party_graph/onnx_graph_node.py | 215 +++-------------- .../third_party_graph/onnx_utils.py | 6 +- .../third_party_graph/pytorch_graph.py | 4 +- .../third_party_graph/pytorch_graph_node.py | 175 ++------------ .../test_hierarchical_tree.py | 2 +- .../mapper/test_mapper.py | 40 ++-- 32 files changed, 584 insertions(+), 566 deletions(-) create mode 100644 mindinsight/mindconverter/graph_based_converter/common/__init__.py create mode 100644 mindinsight/mindconverter/graph_based_converter/common/code_fragment.py create mode 100644 mindinsight/mindconverter/graph_based_converter/common/utils.py create mode 100644 mindinsight/mindconverter/graph_based_converter/mapper/gen_setting.py diff --git a/mindinsight/mindconverter/graph_based_converter/common/__init__.py b/mindinsight/mindconverter/graph_based_converter/common/__init__.py new file mode 100644 index 00000000..5abd50b8 --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/common/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Common instance and utils of graph based converter.""" diff --git a/mindinsight/mindconverter/graph_based_converter/common/code_fragment.py b/mindinsight/mindconverter/graph_based_converter/common/code_fragment.py new file mode 100644 index 00000000..705e10ce --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/common/code_fragment.py @@ -0,0 +1,218 @@ +# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Define CodeLine object.""" +import abc + + +class TrainableParams: + """Trainable parameters.""" + + def __init__(self, shape, dtype, reference): + self.param_name = None + self.shape = shape + self.dtype = dtype + self.reference = reference # Weight name in global npy. + + +class CodeSetting: + """Code generation settings.""" + + def __init__(self): + self.output_vars_suffix = [] + self.operation_input_type = None # Construct input type, tensor or list. + self.operation_extra_input = dict() # `values` in original setting dict. + self.operation_extra_tensor = None # For `MatMul`, `BiasAdd` op, need a tensor + + +class Fragment(abc.ABC): + """ + Define comment attributes of code generation. + + Args: + operation (str): Operation name in MindSpore. + actual_args (dict): Actual arg values. + settings (namedTuple): Code generation setting. + """ + + def __init__(self, operation, actual_args, input_shape, output_shape, settings=None): + self._operation = operation + self._input_shape = input_shape + self._output_shape = output_shape + self._declared_variable_name = None + self._output_var_name = list() # Output variable name(could be multi-opt). + self._operation_inputs = list() # Index indices the order of input. + self._operation_extra_inputs = settings + self._code_setting = settings + self._formal_args_list = dict() + self._actual_args_list = actual_args # Key is the param_key, value is the corresponding value. + self._node_type = "" + + @property + def code_setting(self): + return self._code_setting + + @property + def node_type(self): + """Node type getter.""" + return self._node_type + + @node_type.setter + def node_type(self, t): + """Node type setter.""" + self._node_type = t + + @property + def operation_extra_inputs(self): + """Getter of extra operation inputs.""" + return self._operation_extra_inputs + + @property + def declared_var_name(self): + """Declared variable name getter.""" + return self._declared_variable_name + + @declared_var_name.setter + def declared_var_name(self, var): + """Setter of declared variable name.""" + self._declared_variable_name = var + + @property + def output_var_name(self) -> str: + """Getter of output variable name.""" + return ", ".join(self._output_var_name) + + @output_var_name.setter + def output_var_name(self, opt_vars): + """ + Output variable name setter. + + Args: + opt_vars (list[str]): Output variable name. + """ + self._output_var_name = opt_vars + + @property + def operation_inputs(self): + """ + Operation getter. + + Returns: + list[Fragment], list of inputs. + """ + return self._operation_inputs + + def update_operation_inputs(self, ipt): + """ + Update operation inputs. + + Args: + ipt (Fragment): Where input comes from. + """ + self._operation_inputs.append(ipt) + + @property + def operation(self): + """ + Operation getter. + + Returns: + str, operation name to be initialized. + """ + return self._operation + + @operation.setter + def operation(self, op: str): + """ + Operation setter. + + Args: + op (str): Operation name. + """ + self._operation = op + + @property + def actual_args(self) -> dict: + """Getter of actual args.""" + return self._actual_args_list + + @property + def formal_args(self) -> dict: + """Get formal args.""" + return self._formal_args_list + + def update_formal_args(self, formal_args: dict): + """ + Update formal args. + + Args: + formal_args (dict): To be updated args. + """ + return self._formal_args_list.update(formal_args) + + @property + def input_shape(self): + return self._input_shape + + @property + def output_shape(self): + return self._output_shape + + +class CodeFragment(Fragment): + """ + Manage the variables related with code generation. + + For single operation type node, the variables in `CodeLine` stands for: + ```python + class Module(nn.Cell): + def __init__ (self, ...): + super(Module, self).__init__() + self. = (, + ) + self. = Tensor(, + dtype=) + + def construct(self, x, ...): + = self.() + ... + return output + ``` + + Args: + operation (str): Operation name in MindSpore. + actual_args (dict): Actual arg values. + settings (namedTuple): Code generation setting. + """ + + def __init__(self, operation, actual_args, settings, input_shape, output_shape, + trainable_params=None): + super(CodeFragment, self).__init__(operation=operation, actual_args=actual_args, + input_shape=input_shape, output_shape=output_shape, + settings=settings) + self._trainable_params = dict() # External weights, like Matmul. + self._init_trainable_params = trainable_params # Can put into operation init method, like Conv2d. + + @property + def trainable_params(self): + return self._trainable_params + + +class ModuleFragment(Fragment): + """Manage module type code variables.""" + + def __init__(self, operation, actual_args, settings, input_shape, output_shape): + super(ModuleFragment, self).__init__(operation=operation, actual_args=actual_args, + input_shape=input_shape, output_shape=output_shape, + settings=settings) diff --git a/mindinsight/mindconverter/graph_based_converter/common/utils.py b/mindinsight/mindconverter/graph_based_converter/common/utils.py new file mode 100644 index 00000000..5fcbd82e --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/common/utils.py @@ -0,0 +1,29 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Define common utils.""" +from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP + + +def is_converted(operation: str): + """ + Whether convert successful. + + Args: + operation (str): Operation name. + + Returns: + bool, true or false. + """ + return operation and SEPARATOR_IN_ONNX_OP not in operation diff --git a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py index 44613454..e06e781b 100644 --- a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py +++ b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py @@ -14,6 +14,7 @@ # ============================================================================== """Hierarchical tree module.""" import re + from mindinsight.mindconverter.common.log import logger as log from .hierarchical_tree import HierarchicalTree from ..third_party_graph.onnx_graph_node import OnnxGraphNode @@ -36,7 +37,6 @@ def _tf_model_node_name_reformat(node: OnnxGraphNode, node_name): """ scope_name = node.scope_name new_name = None - parent = "" regex = r"(?P.+/)(?P\w+)" match = re.match(regex, scope_name) parent = match.group("parent") @@ -74,12 +74,13 @@ class HierarchicalTreeFactory: f"Cannot find {node_name}'s input shape." log.error(err_msg) if isinstance(node_inst, OnnxGraphNode): - node_name_with_scope = _tf_model_node_name_reformat( - node_inst, node_name) + node_name_with_scope = _tf_model_node_name_reformat(node_inst, node_name) node_scope_name[node_name] = node_name_with_scope node_name = node_name_with_scope - tree.insert(node_inst, node_name, node_input, node_output) + node_inst.add_input_and_output_shape(node_input, node_output) + tree.insert(node_inst, node_name) + if node_scope_name: return tree, node_scope_name return tree diff --git a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py index e67fc33a..a4802ec9 100644 --- a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py +++ b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py @@ -25,17 +25,18 @@ from treelib import Tree, Node from mindinsight.mindconverter.common.log import logger as log from .name_mgr import ModuleNameMgr, GlobalVarNameMgr +from ..common.utils import is_converted from ..mapper.base import Mapper from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode from ..third_party_graph.onnx_graph_node import OnnxGraphNode -from ..constant import SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT, CodeFormatConfig +from ..constant import SEPARATOR_IN_SCOPE +from ..constant import CodeFormatConfig +from ..constant import SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT from ..constant import NEW_LINE, SECOND_LEVEL_INDENT from ..constant import NodeType from ..report_generator import ReportGenerator from ...common.exceptions import NodeTypeNotSupport -GLOBAL_VAR_NAME_MGR = GlobalVarNameMgr() - class HierarchicalTree(Tree): """Define hierarchical tree.""" @@ -46,6 +47,8 @@ class HierarchicalTree(Tree): _root_created = False ROOT_LEVEL = 0 + GLOBAL_VAR_NAME_MGR = GlobalVarNameMgr() + def __init__(self): super(HierarchicalTree, self).__init__() self._hierarchical_order = dict() @@ -62,6 +65,7 @@ class HierarchicalTree(Tree): self._module_vars = dict() # scope name mapping record for easy node searching self._scope_name_map = dict() + self.code_fragment_recorder = dict() @property def tree_identifier(self): @@ -82,19 +86,15 @@ class HierarchicalTree(Tree): return None return self._nodes[nid] - def insert(self, node: Union[PyTorchGraphNode, OnnxGraphNode], - node_name: str, input_shape, output_shape): + def insert(self, node: Union[PyTorchGraphNode, OnnxGraphNode], node_name: str): """ Insert node into hierarchical tree. Args: node_name (str): Node name. node (Union[PyTorchGraphNode, OnnxGraphNode]): Node to be inserted. - output_shape (tuple): Output tensor shape. - input_shape (tuple): Input tensor shape. """ - node.add_input_and_output_shape(input_shape, output_shape) scopes = node_name.split(SEPARATOR_IN_SCOPE) for idx, scope in enumerate(scopes): parent = SEPARATOR_IN_SCOPE.join(scopes[:idx]) @@ -125,10 +125,9 @@ class HierarchicalTree(Tree): tgt_node.precursor_nodes = node.precursor_nodes tgt_node.node_type = (NodeType.OPERATION if idx == len(scopes) - 1 else NodeType.MODULE).value - tgt_node.tag = scope.split(SEPARATOR_BTW_NAME_AND_ID)[0] tgt_node.variable_name = self._get_var_name(identifier) self.create_node( - tag=tgt_node.tag, + tag=scope.split(SEPARATOR_BTW_NAME_AND_ID)[0], identifier=identifier, parent=parent, data=tgt_node @@ -276,8 +275,7 @@ class HierarchicalTree(Tree): node.data.replace_with_arg(arg, arg) return node - @staticmethod - def _clear_unused_args(node, used_args): + def _clear_unused_args(self, node, used_args): """ Clear unused args. @@ -290,7 +288,9 @@ class HierarchicalTree(Tree): """ args_in_code = list(node.data.args_in_code.keys()) for arg in args_in_code: - ori_arg = arg.replace(f"_{node.data.variable_name}", "") + ori_arg = arg.replace( + f"_{self.code_fragment_recorder[node.identifier].declared_var_name}", "" + ) if ori_arg not in used_args: node.data.args_in_code.pop(arg) return node @@ -323,6 +323,8 @@ class HierarchicalTree(Tree): # 1. Generate args for each node in this level. if node.data.node_type == NodeType.MODULE.value: self._create_module_args_and_vars(node, mapper) + if depth == depths[-1]: + self.code_fragment_recorder[node.identifier] = node.data.param_transform(mapper, "") # Module merging based on all nodes. self._module_merging() @@ -345,30 +347,29 @@ class HierarchicalTree(Tree): # then assign the created module name to current node, # and delete unused args. module_name = self._created_module[module_key] - nd_inst.data.froze_node_type_and_module_name(node_type, - module_name) + self.code_fragment_recorder[nd_inst.identifier].operation = module_name + self.code_fragment_recorder[nd_inst.identifier].node_type = node_type self._preprocess_node_args(nd_inst, module_key) continue - module_name = nd_inst.data.module_name + module_name = nd_inst.tag + if node_type == NodeType.CLASS.value: module_name = f"{module_name[0].upper()}{module_name[1:]}" # After node_type and module_name is frozen, # then it's unchangeable. module_name = self._module_mgr.get_name(module_name) - nd_inst.data.froze_node_type_and_module_name(node_type, - module_name) + self.code_fragment_recorder[nd_inst.identifier].operation = module_name + self.code_fragment_recorder[nd_inst.identifier].node_type = node_type # 3. Pre-process node args. nd_inst = self._preprocess_node_args(nd_inst, module_key) # 4. Post-process child node args. for _, scsr_nd_name in enumerate(nd_inst.successors(self.tree_identifier)): - self._postprocess_node_args( - self.get_node(scsr_nd_name), module_key) + self._postprocess_node_args(self.get_node(scsr_nd_name), module_key) # 5. Generate code. - snippets.add( - func(nd_inst, nd_inst.data.module_name, module_key)) + snippets.add(func(nd_inst, self.code_fragment_recorder[nd_inst.identifier].operation, module_key)) code_blocks.extend(snippets) @@ -437,7 +438,7 @@ class HierarchicalTree(Tree): module_list = [] for node_name in node.successors(self.tree_identifier): c_nd = self.get_node(node_name) - operator = c_nd.data.op_in_ms or c_nd.data.module_name + operator = self.code_fragment_recorder[c_nd.identifier].operation if c_nd.data.node_type != NodeType.OPERATION.value: hash_key = c_nd.data.hash_key or self.hash_key(c_nd) @@ -445,14 +446,16 @@ class HierarchicalTree(Tree): operator = self._created_module[hash_key] args = c_nd.data.args_in_code - if c_nd.data.node_type == NodeType.OPERATION.value and \ - not c_nd.data.convert_successful(): + if c_nd.data.node_type == NodeType.OPERATION.value and not is_converted( + self.code_fragment_recorder[c_nd.identifier].operation): args.update({"input_shape": c_nd.data.input_shape, "output_shape": c_nd.data.output_shape}) # Generate code statement. - expr = ", ".join([f"{k.replace(f'_{c_nd.data.variable_name}', '')}={v}" - for k, v in args.items()]) + expr = ", ".join( + [f"{k.replace(f'_{self.code_fragment_recorder[c_nd.identifier].declared_var_name}', '')}={v}" + for k, v in args.items()] + ) code_line = f"{operator}({expr})" module_list.append(code_line) @@ -547,14 +550,16 @@ class HierarchicalTree(Tree): if idx != 0: # Get previous node output variable name. - ipt_args_in_construct = self._get_previous_opt_var( - cur_nd_inst, pre_nd_inst) + ipt_args_in_construct = self._get_previous_opt_var(cur_nd_inst, pre_nd_inst) if idx != len(pre_nd_inst.successors(self.tree_identifier)) - 1: # Set opt variable name. - opt_arg_in_construct = cur_nd_inst.data.opt_var_name + opt_arg_in_construct = f"{self.code_fragment_recorder[cur_nd_inst.identifier].declared_var_name}_opt" declare, call = cur_nd_inst.data.to_code(ipt_args_in_construct=ipt_args_in_construct, - output_var=opt_arg_in_construct) + variable_name=self.code_fragment_recorder[ + cur_nd_inst.identifier].declared_var_name, + output_var=opt_arg_in_construct, + code_fragment=self.code_fragment_recorder[cur_nd_inst.identifier]) return declare, call @@ -588,7 +593,9 @@ class HierarchicalTree(Tree): if e not in pre_nd.successors(self.tree_identifier): while True: if p_nd.identifier in pre_nd.successors(self.tree_identifier): - ipt_lst.append(p_nd.data.opt_var_name) + ipt_lst.append( + f"{self.code_fragment_recorder[p_nd.identifier].declared_var_name}_opt" + ) break pre_nd_name = p_nd.predecessor(self.tree_identifier) if not pre_nd_name: @@ -597,7 +604,9 @@ class HierarchicalTree(Tree): p_nd = self.get_node(pre_nd_name) continue - ipt_lst.append(p_nd.data.opt_var_name) + ipt_lst.append( + f"{self.code_fragment_recorder[p_nd.identifier].declared_var_name}_opt" + ) return ipt_lst def _get_previous_opt_var(self, cur_nd, pre_nd): @@ -619,12 +628,11 @@ class HierarchicalTree(Tree): cur_nd = self.get_node(p_nd[0]) return ", ".join(self._find_all_previous_opt_var_(cur_nd, pre_nd)) - def hash_key(self, node, depth: int = 0): + def hash_key(self, node): """ Generate hash key for each node. Args: - depth (int): Recursion depth. node (Node): Node. Returns: @@ -633,13 +641,17 @@ class HierarchicalTree(Tree): scsr_topo_order = [] for s in node.successors(self.tree_identifier): cur_nd = self.get_node(s) - if cur_nd.data.hash_key: - scsr_topo_order.append(f"{cur_nd.data.hash_key}[{depth}]") - continue if cur_nd.data.node_type in {NodeType.MODULE.value, NodeType.FUNC.value, NodeType.CLASS.value}: - scsr_topo_order.append(self.hash_key(cur_nd, depth + 1)) + if cur_nd.data.hash_key: + scsr_topo_order.append(f"({cur_nd.data.hash_key})") + continue + + raise ValueError("Current node doesn't have hash key.") + + if cur_nd.data.hash_key: + scsr_topo_order.append(cur_nd.data.hash_key) continue unique_key = "->".join(scsr_topo_order) node.data.hash_key = unique_key @@ -675,12 +687,11 @@ class HierarchicalTree(Tree): """ # All args and value pair in current node module. module_args = dict() - module_settings = dict() module_key = self.hash_key(node) created = False if module_key not in self._vars_mgr_in_module: - self._vars_mgr_in_module[module_key] = GLOBAL_VAR_NAME_MGR + self._vars_mgr_in_module[module_key] = self.GLOBAL_VAR_NAME_MGR self._module_vars[module_key] = [] else: created = True @@ -688,33 +699,29 @@ class HierarchicalTree(Tree): # Sub-modules in the module could have arg name conflicts. for idx, successor_name in enumerate(node.successors(self.tree_identifier)): nd_inst = self.get_node(successor_name) - # Generate variable name here, then - # to generate args. + # Generation of params must behind variable assigment. if created: - nd_inst.data.variable_name = self._module_vars[module_key][idx] + variable_name = self._module_vars[module_key][idx] else: - variable_name = nd_inst.data.op_name or nd_inst.data.module_name - variable_name = self._vars_mgr_in_module[module_key].get_name( - variable_name) - nd_inst.data.variable_name = variable_name + variable_name = nd_inst.data.op_name or nd_inst.tag + variable_name = self._vars_mgr_in_module[module_key].get_name(variable_name) - # Generation of params must behind variable assigment. - nd_inst.data.param_transform(mapper) + code_fragment = nd_inst.data.param_transform(mapper, variable_name) + code_fragment.declared_var_name = variable_name + self.code_fragment_recorder[nd_inst.identifier] = code_fragment module_args.update(nd_inst.data.args_in_code) - module_settings.update(nd_inst.data.settings_in_code) if not created: - self._module_vars[module_key].append( - nd_inst.data.variable_name) + self._module_vars[module_key].append(variable_name) node.data.args_in_code = module_args # Collect module args of `module_key`. if module_key not in self._merged_module: - self._merged_module[module_key] = [node.data.args_in_code] + self._merged_module[module_key] = [deepcopy(node.data.args_in_code)] else: - self._merged_module[module_key].append(node.data.args_in_code) + self._merged_module[module_key].append(deepcopy(node.data.args_in_code)) @staticmethod def _create_operation_args(node, mapper): diff --git a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py index 62d8aab4..2811f762 100644 --- a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py +++ b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py @@ -63,6 +63,10 @@ START_IDX = 0 class GlobalVarNameMgr: """Global variable name mgr.""" + def __init__(self): + global_op_namespace.clear() + global_var_namespace.clear() + @staticmethod def _get_name(name): """Deal with op name.""" diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/base.py b/mindinsight/mindconverter/graph_based_converter/mapper/base.py index c1dba8bc..e0871852 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/base.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/base.py @@ -87,7 +87,7 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): module_name = TABLE.get(op_name) if not module_name: - return None, dict(), dict() + return None, dict(), None, dict() pos = module_name.rfind(".") try: @@ -101,7 +101,7 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): # If mapper can not be found, then skip it. err_msg = f"Converting {op_name} failed, see {str(e)}" log.error(err_msg) - return None, dict(), dict() + return None, dict(), None, dict() try: converter_name = op_name_converter( @@ -110,13 +110,13 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): converted_weights = weights_converter( weights=weights) if weights else dict() converted_params.update(converted_weights) - converted_settings = settings_converter(params=params) + converted_settings = settings_converter(params=params, weights=weights) except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e: err_msg = f"Converting {op_name} failed, see {str(e)}" log.error(err_msg) - return None, dict(), dict() + return None, dict(), None, dict() - return converter_name, converted_params, converted_settings + return converter_name, converted_params, converted_settings, converted_weights @staticmethod def _operation_name_in_ms(*args, **kwargs): diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/gen_setting.py b/mindinsight/mindconverter/graph_based_converter/mapper/gen_setting.py new file mode 100644 index 00000000..08a5f98c --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/mapper/gen_setting.py @@ -0,0 +1,34 @@ +# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Operation mapping setting.""" +from collections import namedtuple +import numpy as np + +from mindinsight.mindconverter.graph_based_converter.constant import InputType + +Tensor = namedtuple("Tensor", ["shape", "dtype", "reference"]) + +Setting = namedtuple("Setting", ["opt_vars_suffix", + "op_ipt_type", + "op_extra_input", + "op_extra_tensor"]) +Setting.__new__.__defaults__ = ("_opt", InputType.TENSOR.value, dict(), None) + + +def get_dtype(tensor: np.ndarray): + """Get tensor dtype.""" + if tensor.dtype == np.float16: + return "mindspore.float16" + return "mindspore.float32" diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py index 2f8d0e57..5d02b9b2 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py @@ -14,6 +14,7 @@ # ============================================================================== """Mapper module.""" from ...base import ONNXToMindSporeMapper +from ...gen_setting import Setting class BatchNormMapper(ONNXToMindSporeMapper): @@ -39,4 +40,4 @@ class BatchNormMapper(ONNXToMindSporeMapper): @staticmethod def _convert_settings(**kwargs): - return dict() + return Setting() diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py index 6513a7c8..9a9a2200 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py @@ -16,6 +16,7 @@ import re import numpy as np from ...base import ONNXToMindSporeMapper +from ...gen_setting import Setting def _convert_padding(**kwargs): @@ -35,6 +36,7 @@ def _convert_padding(**kwargs): class ConvMapper(ONNXToMindSporeMapper): """Conv2d mapper.""" + @staticmethod def convert_params_torch(**kwargs): """Convert params from PyTorch to MindSpore""" @@ -148,4 +150,4 @@ class ConvMapper(ONNXToMindSporeMapper): @staticmethod def _convert_settings(**kwargs): - return dict() + return Setting() diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py index 18f68716..2f4eb387 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py @@ -14,6 +14,7 @@ # ============================================================================== """Mapper module.""" from ...base import ONNXToMindSporeMapper +from ...gen_setting import Setting class DenseMapper(ONNXToMindSporeMapper): @@ -41,4 +42,4 @@ class DenseMapper(ONNXToMindSporeMapper): @staticmethod def _convert_settings(**kwargs): - return dict() + return Setting() diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/flatten_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/flatten_mapper.py index 679ea812..024cf499 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/flatten_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/flatten_mapper.py @@ -14,6 +14,7 @@ # ============================================================================== """Mapper module.""" from ...base import ONNXToMindSporeMapper +from ...gen_setting import Setting class FlattenMapper(ONNXToMindSporeMapper): @@ -33,4 +34,4 @@ class FlattenMapper(ONNXToMindSporeMapper): @staticmethod def _convert_settings(**kwargs): - return dict() + return Setting() diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/global_pool_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/global_pool_mapper.py index c29a971d..29bc8550 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/global_pool_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/global_pool_mapper.py @@ -14,6 +14,7 @@ # ============================================================================== """Mapper module.""" from ...base import ONNXToMindSporeMapper +from ...gen_setting import Setting class GlobalPoolMapper(ONNXToMindSporeMapper): @@ -25,8 +26,7 @@ class GlobalPoolMapper(ONNXToMindSporeMapper): op_name = 'nn.AvgPool{}d' else: op_name = 'nn.MaxPool{}d' - dim = 1 if len(kwargs['params']['input_shape']) == 3\ - else 2 + dim = 1 if len(kwargs['params']['input_shape']) == 3 else 2 return op_name.format(dim) @staticmethod @@ -49,4 +49,4 @@ class GlobalPoolMapper(ONNXToMindSporeMapper): @staticmethod def _convert_settings(**kwargs): - return dict() + return Setting() diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py index 2d2782cf..603a4fd8 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py @@ -14,6 +14,7 @@ # ============================================================================== """Mapper module.""" from ...base import ONNXToMindSporeMapper +from ...gen_setting import Setting, Tensor, get_dtype class MatMulMapper(ONNXToMindSporeMapper): @@ -33,4 +34,12 @@ class MatMulMapper(ONNXToMindSporeMapper): @staticmethod def _convert_settings(**kwargs): - return dict() + weights = kwargs.get("weights") + if not weights: + return Setting() + tensor, ref = None, "" + for t_name, t_value in weights.items(): + tensor = t_value + ref = t_name + return Setting(op_extra_tensor=Tensor(shape=tensor.shape, + dtype=get_dtype(tensor), reference=ref)) diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py index e0ff4225..4c3e0715 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py @@ -14,6 +14,7 @@ # ============================================================================== """Mapper module.""" from ...base import ONNXToMindSporeMapper +from ...gen_setting import Setting def _padding_format_convert(padding: list): @@ -77,4 +78,4 @@ class PadMapper(ONNXToMindSporeMapper): @staticmethod def _convert_settings(**kwargs): - return dict() + return Setting() diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py index 1c248a75..b33ed715 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py @@ -14,6 +14,7 @@ # ============================================================================== """Mapper module.""" from ...base import ONNXToMindSporeMapper +from ...gen_setting import Setting class PoolMapper(ONNXToMindSporeMapper): @@ -49,4 +50,4 @@ class PoolMapper(ONNXToMindSporeMapper): @staticmethod def _convert_settings(**kwargs): - return dict() + return Setting() diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/relu_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/relu_mapper.py index b5a24717..e89052ca 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/relu_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/relu_mapper.py @@ -14,6 +14,7 @@ # ============================================================================== """Mapper module.""" from ...base import ONNXToMindSporeMapper +from ...gen_setting import Setting class ReLUMapper(ONNXToMindSporeMapper): @@ -45,4 +46,4 @@ class ReLUMapper(ONNXToMindSporeMapper): @staticmethod def _convert_settings(**kwargs): - return dict() + return Setting() diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/softmax_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/softmax_mapper.py index fbbe781a..be029109 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/softmax_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/softmax_mapper.py @@ -14,6 +14,7 @@ # ============================================================================== """Mapper module.""" from ...base import ONNXToMindSporeMapper +from ...gen_setting import Setting class SoftmaxMapper(ONNXToMindSporeMapper): @@ -37,4 +38,4 @@ class SoftmaxMapper(ONNXToMindSporeMapper): @staticmethod def _convert_settings(**kwargs): - return dict() + return Setting() diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py index 7b6dea75..83808984 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py @@ -14,6 +14,7 @@ # ============================================================================== """Mapper module.""" from ...base import ONNXToMindSporeMapper +from ...gen_setting import Setting, Tensor, get_dtype class AddMapper(ONNXToMindSporeMapper): @@ -33,4 +34,12 @@ class AddMapper(ONNXToMindSporeMapper): @staticmethod def _convert_settings(**kwargs): - return dict() + weights = kwargs.get("weights") + if not weights: + return Setting() + tensor, ref = None, "" + for t_name, t_value in weights.items(): + tensor = t_value + ref = t_name + return Setting(op_extra_tensor=Tensor(shape=tensor.shape, + dtype=get_dtype(tensor), reference=ref)) diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/concat_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/concat_mapper.py index b0a32a9e..eb1205f9 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/concat_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/concat_mapper.py @@ -15,6 +15,7 @@ """Mapper module.""" from mindinsight.mindconverter.graph_based_converter.constant import InputType from ...base import ONNXToMindSporeMapper +from ...gen_setting import Setting class ConcatMapper(ONNXToMindSporeMapper): @@ -36,4 +37,4 @@ class ConcatMapper(ONNXToMindSporeMapper): @staticmethod def _convert_settings(**kwargs): input_type = InputType.LIST.value - return {'input_type': input_type} + return Setting(op_ipt_type=input_type) diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reduce_mean_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reduce_mean_mapper.py index 68623457..239d07a6 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reduce_mean_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reduce_mean_mapper.py @@ -14,6 +14,7 @@ # ============================================================================== """Mapper module.""" from ...base import ONNXToMindSporeMapper +from ...gen_setting import Setting class ReduceMeanMapper(ONNXToMindSporeMapper): @@ -40,4 +41,4 @@ class ReduceMeanMapper(ONNXToMindSporeMapper): axis = params['axes'][0] if len(params['axes']) == 1 else tuple(params['axes']) else: axis = tuple() - return {'values': {'axis': axis}} + return Setting(op_extra_input={'axis': axis}) diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/transpose_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/transpose_mapper.py index cb51153c..d294d9d1 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/transpose_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/transpose_mapper.py @@ -14,6 +14,7 @@ # ============================================================================== """Mapper module.""" from ...base import ONNXToMindSporeMapper +from ...gen_setting import Setting class TransposeMapper(ONNXToMindSporeMapper): @@ -40,4 +41,4 @@ class TransposeMapper(ONNXToMindSporeMapper): perm = tuple(perm) converted_params['input_perm'] = perm - return {'values': converted_params} + return Setting(op_extra_input=converted_params) diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py index 811f1f8a..34e63569 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py @@ -15,10 +15,13 @@ """Define graph entity.""" import abc from collections import OrderedDict +from copy import deepcopy from mindinsight.mindconverter.common.log import logger as log -from ..constant import SEPARATOR_IN_ONNX_OP +from ..common.code_fragment import CodeFragment +from ..constant import NodeType, InputType from ..mapper.base import Mapper +from ...common.exceptions import NodeInputTypeNotSupport class GraphParser(metaclass=abc.ABCMeta): @@ -287,26 +290,10 @@ class GraphNode(abc.ABC): self._op_params = dict() self._scope_name = None self._op_shape = None - # Operation in mindspore. - self._op_in_ms = None - # Params in mindspore. - self._params_in_ms = dict() - # Settings in mindspore. - self._settings_in_ms = dict() # Node type of current node, e.g. class, module, operation. self._node_type = None - # Tag name on tree. - self._tag_on_tree = None # Function, class or operation needed args. self._args_in_code = dict() - # Operation needed settings. - self._settings_in_code = dict() - # Variable name declared in init block. - self._variable_name = None - # Output variable name declared in construct block. - self._opt_var_name = None - # Function or class name in code. - self._module_name = None # Unique key of node. self._hash_key = None # Input shape of current op. @@ -317,37 +304,18 @@ class GraphNode(abc.ABC): self._weight = None @property - def opt_var_name(self): + def weight(self): + return self._weight + + @staticmethod + def get_opt_var_name(variable_name): """ Output variable name. Returns: str, variable name. """ - return f"{self.variable_name}_opt" - - @opt_var_name.setter - def opt_var_name(self, v): - """ - Set variable name. - - Args: - v (str): Name. - - """ - self._opt_var_name = v - - @property - def op_in_ms(self): - """ - Operation in mindspore. - - Returns: - str, operation name. - """ - if self._op_in_ms and SEPARATOR_IN_ONNX_OP in self._op_in_ms: - return self._op_in_ms.replace(SEPARATOR_IN_ONNX_OP, ".") - return self._op_in_ms + return f"{variable_name}_opt" @property def args_in_code(self): @@ -370,27 +338,6 @@ class GraphNode(abc.ABC): """ self._args_in_code = args - @property - def settings_in_code(self): - """ - Settings in code. - - Returns: - dict, settings. - """ - return self._settings_in_code - - @settings_in_code.setter - def settings_in_code(self, settings): - """ - Settings in code. - - Args: - settings(dict): Settings. - - """ - self._settings_in_code = settings - @property def input_shape(self): """ @@ -411,16 +358,6 @@ class GraphNode(abc.ABC): """ return self._opt_shape - @property - def tag(self): - """Tag on hierarchical tree.""" - return self._tag_on_tree - - @tag.setter - def tag(self, t): - """Tag on hierarchical tree.""" - self._tag_on_tree = t - def is_empty(self): """ Whether is empty. @@ -536,7 +473,7 @@ class GraphNode(abc.ABC): """Replace actual parameter with formal parameter.""" @abc.abstractmethod - def _get_arg_name(self, arg): + def _get_arg_name(self, arg, variable_name): """Get arg name for func or class.""" @abc.abstractmethod @@ -553,13 +490,8 @@ class GraphNode(abc.ABC): def real_name(self, **kwargs): """Setter of `real_name`.""" - @property - @abc.abstractmethod - def variable_name(self): - """Getter of `variable_name`.""" - @abc.abstractmethod - def to_code(self, ipt_args_in_construct: str, output_var: str): + def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: str, code_fragment): """Graph node to MindSpore code.""" @abc.abstractmethod @@ -570,40 +502,86 @@ class GraphNode(abc.ABC): def add_input_and_output_shape(self, input_shape, output_shape): """Add the node input shape.""" - @abc.abstractmethod - def froze_node_type_and_module_name(self, node_type, module_name): - """Make node_type can not be changed.""" + @staticmethod + def _generate_ipt_args_settings_in_construct(ipt_args_in_construct, settings): + """ + Generate input with args and settings in construct. - @abc.abstractmethod - def convert_successful(self): - """Whether convert successful.""" + Args: + ipt_args_in_construct (str): Input args in construct. + settings (Setting): Settings in operator. + + Returns: + str, args of each node in generated construct statement. + """ + if settings and settings.op_ipt_type: + input_type = settings.op_ipt_type + if input_type == InputType.TENSOR.value: + ipt_args_settings_in_construct = ipt_args_in_construct + elif input_type == InputType.LIST.value: + ipt_args_settings_in_construct = f"({ipt_args_in_construct})" + else: + raise NodeInputTypeNotSupport(f"Input type[{input_type}] is not supported now.") + else: + ipt_args_settings_in_construct = ipt_args_in_construct + + if settings and settings.op_extra_input: + settings_value = settings.op_extra_input + if settings_value: + settings_in_construct = ', '.join([f"{setting_val}" for _, setting_val in settings_value.items()]) + ipt_args_settings_in_construct = ', '.join((ipt_args_settings_in_construct, settings_in_construct)) - def param_transform(self, mapper: Mapper): + return ipt_args_settings_in_construct + + def param_transform(self, mapper: Mapper, variable_name): """ - Transform param in pytorch operation into mindspore. + Transform param in PyTorch operation into MindSpore. Args: + variable_name (str): Variable name. mapper (ONNXToMindSporeMapper): Mapper between onnx operation - and mindspore. + and MindSpore. Returns: dict, transformed params. """ - import copy - params = copy.deepcopy(self._op_params) + if self._node_type != NodeType.OPERATION.value: + args = deepcopy(self._args_in_code) + self._args_in_code = dict() + for arg, value in args.items(): + self._args_in_code[self._get_arg_name(arg, variable_name)] = value + return CodeFragment(operation="", actual_args=args, settings=None, + input_shape=self.input_shape, output_shape=self.output_shape) + + if self.transformed: + raise ValueError("Already transformed.") + + params = deepcopy(self._op_params) params.update({"input_shape": self.input_shape, "output_shape": self.output_shape}) - op_name_in_mindspore, ms_params, ms_settings = mapper.convert(op_name=self.op_name, - params=params, - weights=self._weight) - if op_name_in_mindspore: - self._op_in_ms = op_name_in_mindspore - self._params_in_ms = ms_params - self._settings_in_ms = ms_settings + ms_op, ms_params, ms_settings, ms_weights = mapper.convert(op_name=self.op_name, + params=params, + weights=self._weight) + + if ms_op: + code_fragment = CodeFragment(operation=ms_op, + actual_args=ms_params, + settings=ms_settings, + input_shape=self.input_shape, + output_shape=self.output_shape, + trainable_params=ms_weights) else: - self._op_in_ms = self._op_name - self._params_in_ms = self._op_params - self._settings_in_ms = dict() + code_fragment = CodeFragment(operation=self._op_name, + actual_args=self._op_params, + settings=None, + input_shape=self.input_shape, + output_shape=self.output_shape, + trainable_params=self._weight) + + for arg, value in code_fragment.actual_args.items(): + self._args_in_code[self._get_arg_name(arg, variable_name)] = value + + self.transformed = True - return self._op_in_ms, self._params_in_ms, self._settings_in_ms + return code_fragment diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py index 5a6644cf..9ec692c3 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py @@ -38,7 +38,6 @@ class PyTorchGraphParser(GraphParser): error = FileNotFoundError("`model_path` must be assigned with " "an existed file path.") log.error(str(error)) - log.exception(error) raise error try: diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py index be93ad03..e92d7035 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py @@ -21,24 +21,18 @@ from ..constant import SEPARATOR_IN_SCOPE, NodeType class InputNode(GraphNode): """ - Pytorch Input Node. + PyTorch Input Node. Args: input_shape: Input shape of module. """ - def convert_successful(self): - """ - Whether convert successful. - - Returns: - bool, true or false. - """ - return False + def _get_arg_name(self, arg, variable_name): + raise NotImplementedError() - def froze_node_type_and_module_name(self, node_type, module_name): - pass + def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: str, code_fragment): + raise NotImplementedError() def _get_raw_params(self, node): pass @@ -56,9 +50,6 @@ class InputNode(GraphNode): def replace_with_arg(self, src_arg, tgt_arg): pass - def _get_arg_name(self, arg): - pass - def add_input_and_output_shape(self, input_shape, output_shape): pass @@ -116,15 +107,8 @@ class InputNode(GraphNode): def real_name(self): return - @property - def variable_name(self): - return - def to_ir(self): """ No need to implement for now. """ raise NotImplementedError() - - def to_code(self, ipt_args_in_construct: str, output_var: str): - raise NotImplementedError() diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py index 27e52c68..f4befbc8 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py @@ -22,7 +22,6 @@ from .onnx_graph_node import OnnxGraphNode from .graph_parser import TFGraphParser from .onnx_utils import OnnxDataLoader - NONE_SCOPE_OP = { "onnx::Add": "Add", "onnx::Flatten": "Flatten", diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph_node.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph_node.py index 4b4383e0..62cab356 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph_node.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph_node.py @@ -13,14 +13,13 @@ # limitations under the License. # ============================================================================== """Define ONNX graph node.""" +from importlib import import_module -from copy import deepcopy from .base import GraphNode +from ..common.utils import is_converted -from ..constant import NodeType, SEPARATOR_IN_SCOPE, \ - SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, SEPARATOR_IN_ONNX_OP, InputType -from ..mapper.base import Mapper -from ...common.exceptions import NodeInputTypeNotSupport +from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \ + SEPARATOR_IN_ONNX_OP class OnnxGraphNode(GraphNode): @@ -39,16 +38,13 @@ class OnnxGraphNode(GraphNode): self._op_params = self._get_raw_params(node.raw_node) if node else None self._op_name = "onnx::" + node.op_type if node else None self._scope_name = node.scope_name if node else None - self._opt_var_name = None - self._variable_name = self._extract_var_name(self._scope_name) - self._module_name = None self._weight = weight def clear_args_of_declaration(self): """Clear `self._args_in_code`.""" self._args_in_code = dict() - def _get_arg_name(self, arg): + def _get_arg_name(self, arg, variable_name): """ Get arg name. @@ -58,7 +54,7 @@ class OnnxGraphNode(GraphNode): Returns: str, arg name in function or class declaration. """ - return f"{arg}_{self._variable_name}" + return f"{arg}_{variable_name}" @property def hash_key(self): @@ -84,51 +80,6 @@ class OnnxGraphNode(GraphNode): """ self._hash_key = h - @property - def variable_name(self): - """ - Variable name. - - Returns: - str, variable name declared in init. - """ - return self._variable_name - - @variable_name.setter - def variable_name(self, v): - """ - Setter of variable name. - - Args: - v (str): Variable name. - """ - self._variable_name = v - - @property - def module_name(self): - """ - Module name. - - Returns: - str, module name. - """ - if not self._module_name_frozen: - module_name = self.tag - return module_name - - return self._module_name - - def _froze_module_name(self, m): - """ - Once module_name is set, then it's unchangeable. - - Args: - m (str): Module name. - """ - if not self._module_name_frozen: - self._module_name = m - self._module_name_frozen = True - @property def op_name(self): """ @@ -154,15 +105,13 @@ class OnnxGraphNode(GraphNode): self._ipt_shape = input_shape self._opt_shape = output_shape - def _add_tensor_args_to_code(self, op_name: str, t_identifier: str, declare, args): + def _add_tensor_args_to_code(self, op_name: str, settings, declare, args, variable_name): """ Add nn used tensors to args in init and construct blocks. Args: op_name (str): Add the tensor to args if the current node has this - op_name. - t_identifier (str): The unique string appeared in the target tensor - name. + op_name. declare (str): Declare statement generated in to_code(). args (str): Args statement generated in to_code(). @@ -172,103 +121,68 @@ class OnnxGraphNode(GraphNode): """ if not self._op_name == op_name: return declare, args - declare_list = [] - tensor = None - # find target tensor - for t_name, t_value in self._weight.items(): - if t_identifier in t_name: - tensor = t_value - break - if tensor is None: + if not settings or not settings.op_extra_tensor: return declare, args - declare_list.append(declare) - declare_t = f"self.{self._variable_name}_w = Tensor(" \ - f"np.random.uniform(0, 1, {str(tensor.shape)}), mindspore.float32)" + declare_list = [declare] + declare_t = f"self.{variable_name}_w = Tensor(" \ + f"np.random.uniform(0, 1, {str(settings.op_extra_tensor.shape)}), " \ + f"{settings.op_extra_tensor.dtype})" declare_list.append(declare_t) - args += f", self.{self._variable_name}_w" + args += f", self.{variable_name}_w" return declare_list, args - def to_code(self, ipt_args_in_construct: str, output_var: str): + def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: str, + code_fragment): """ Generate statements. Args: + variable_name (str): Variable name. ipt_args_in_construct (str): Args of input. output_var (str): Output variable name in construct. + code_fragment (CodeFragment): CodeFragment instance. Returns: Union[str, str], declare in init and call in construct. """ - operator = self.op_in_ms or self.module_name - self._opt_var_name = output_var + operator = code_fragment.operation args = self.args_in_code - settings = self.settings_in_code - if self._node_type == NodeType.OPERATION.value and not self.convert_successful(): + settings = code_fragment.code_setting + + if self._node_type == NodeType.OPERATION.value and not is_converted(code_fragment.operation): args.update({"input_shape": self.input_shape, "output_shape": self.output_shape}) if self._node_type == NodeType.OPERATION.value: - expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}" + expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}" for k, v in args.items()]) - ipt_args_settings_in_construct = \ - self._generate_ipt_args_settings_in_construct( - ipt_args_in_construct, - settings) + ipt_args_settings_in_construct = self._generate_ipt_args_settings_in_construct( + ipt_args_in_construct, settings) else: # When it's type is module, class or func, # it's not necessary to replace var. - expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}" + expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}" for k, v in args.items()]) ipt_args_settings_in_construct = ipt_args_in_construct - declare = f"self.{self._variable_name} = {operator}({expr})" + + if SEPARATOR_IN_ONNX_OP in operator: + operator = operator.replace(SEPARATOR_IN_ONNX_OP, ".") + + declare = f"self.{variable_name} = {operator}({expr})" # Extra Tensor generator for nn.MatMul declare, ipt_args_settings_in_construct = self._add_tensor_args_to_code( - 'onnx::MatMul', 'MatMul', declare, ipt_args_settings_in_construct) + 'onnx::MatMul', settings, declare, ipt_args_settings_in_construct, variable_name) # Extra Tensor generator for onnx::Add declare, ipt_args_settings_in_construct = self._add_tensor_args_to_code( - 'onnx::Add', 'BiasAdd', declare, ipt_args_settings_in_construct) + 'onnx::Add', settings, declare, ipt_args_settings_in_construct, variable_name) - call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_settings_in_construct})" + call = f"{output_var} = self.{variable_name}({ipt_args_settings_in_construct})" return declare, call - @staticmethod - def _generate_ipt_args_settings_in_construct(ipt_args_in_construct, settings): - """ - Generate input with args and settings in construct. - - Args: - ipt_args_in_construct(str): Input args in construct. - settings(dict): Settings in operator. - - Returns: - str, args of each node in generated construct statement. - """ - if settings.get('input_type'): - input_type = settings['input_type'] - if input_type == InputType.TENSOR.value: - ipt_args_settings_in_construct = ipt_args_in_construct - elif input_type == InputType.LIST.value: - ipt_args_settings_in_construct = f"({ipt_args_in_construct})" - else: - raise NodeInputTypeNotSupport( - f"Input type[{input_type}] is not supported now.") - else: - ipt_args_settings_in_construct = ipt_args_in_construct - - if settings.get('values'): - settings_value = settings['values'] - if settings_value: - settings_in_construct = ', '.join( - [f"{setting_val}" for _, setting_val in settings_value.items()]) - ipt_args_settings_in_construct = ', '.join( - (ipt_args_settings_in_construct, settings_in_construct)) - - return ipt_args_settings_in_construct - def to_ir(self): """No need to implement for now.""" raise NotImplementedError @@ -284,7 +198,7 @@ class OnnxGraphNode(GraphNode): Returns: dict, raw params. """ - import onnx + onnx = import_module("onnx") raw_params = dict() @@ -318,62 +232,3 @@ class OnnxGraphNode(GraphNode): var = var.replace(LEFT_BUCKET, SEPARATOR_BTW_NAME_AND_ID).replace( RIGHT_BUCKET, "") return var - - def param_transform(self, mapper: Mapper): - """ - Transform tensorflow params into mindspore. - - Args: - mapper (Mapper): Mapper of params. - - """ - if self._node_type != NodeType.OPERATION.value: - args = deepcopy(self._args_in_code) - self._args_in_code = dict() - for arg, value in args.items(): - self._args_in_code[self._get_arg_name(arg)] = value - return None, None - - if not self.transformed: - _, _, _ = super(OnnxGraphNode, self).param_transform(mapper) - - for arg, value in self._params_in_ms.items(): - self._args_in_code[self._get_arg_name(arg)] = value - - for arg, value in self._settings_in_ms.items(): - self._settings_in_code[arg] = value - - self.transformed = True - - return self._op_in_ms, self._params_in_ms, self._settings_in_ms - - def froze_node_type_and_module_name(self, node_type, module_name): - """ - Froze node type and module name. - - After node_type is frozen, then the `module_name` - will be affected when `node_type` is `class`. - Thus, this line must be placed before `nd_inst.data.module_name`. - - Args: - module_name: Modified module name. - node_type (str): Node type, class of func. - - """ - if not self._type_frozen: - self._node_type = node_type - self._type_frozen = True - - if not self._module_name_frozen: - self._froze_module_name(module_name) - - def convert_successful(self): - """ - Whether convert successfully. - - Returns: - bool, true or false. - """ - if self._op_in_ms and SEPARATOR_IN_ONNX_OP not in self._op_in_ms: - return True - return False diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py index 00349046..348bc658 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py @@ -87,7 +87,8 @@ def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=None inputs_as_nchw=None ) opt_map = getattr(optimizer.back_to_back_optimizer, '_func_map') - opt_map.pop(('Conv', 'BatchNormalization')) + if ('Conv', 'BatchNormalization') in opt_map: + opt_map.pop(('Conv', 'BatchNormalization')) onnx_graph = optimizer.optimize_graph(g) model_proto = onnx_graph.make_model("converted from {}".format(model_path)) @@ -228,8 +229,7 @@ class OnnxNode(BaseNode): """ def __init__(self, raw_node): - super(OnnxNode, self).__init__( - node_name=raw_node.name, op_type=raw_node.op_type) + super(OnnxNode, self).__init__(node_name=raw_node.name, op_type=raw_node.op_type) self.raw_node = raw_node self.params = ParamsAttribute(raw_node.attribute, raw_node) self.scope_name = None diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py index 0e4e6164..ae75de5c 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py @@ -99,8 +99,8 @@ class PyTorchGraph(Graph): for item in input_shape: if not isinstance(item, int): - err_msg = f"Only support model with one input now, " \ - f"and each shape value in `input_shape` should be int." + err_msg = "Only support model with one input now, " \ + "and each shape value in `input_shape` should be int." log.error(err_msg) raise ValueError(err_msg) diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py index 4d6a9428..4153a9ca 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py @@ -13,14 +13,11 @@ # limitations under the License. # ============================================================================== """Define PyTorch graph node.""" -from copy import deepcopy - from .base import GraphNode +from ..common.utils import is_converted from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \ - SEPARATOR_IN_ONNX_OP, InputType -from ..mapper.base import Mapper -from ...common.exceptions import NodeInputTypeNotSupport + SEPARATOR_IN_ONNX_OP class PyTorchGraphNode(GraphNode): @@ -40,9 +37,6 @@ class PyTorchGraphNode(GraphNode): self._op_params = self._get_raw_params(node) self._op_name = node.kind() if node else None self._scope_name = node.scopeName() if node else None - self._opt_var_name = None - self._variable_name = self._extract_var_name(self._scope_name) - self._module_name = None self._weight = weight def clear_args_of_declaration(self): @@ -51,7 +45,7 @@ class PyTorchGraphNode(GraphNode): """ self._args_in_code = dict() - def _get_arg_name(self, arg): + def _get_arg_name(self, arg, variable_name): """ Get arg name. @@ -61,7 +55,7 @@ class PyTorchGraphNode(GraphNode): Returns: str, arg name in function or class declaration. """ - return f"{arg}_{self._variable_name}" + return f"{arg}_{variable_name}" @property def hash_key(self): @@ -88,53 +82,6 @@ class PyTorchGraphNode(GraphNode): """ self._hash_key = h - @property - def variable_name(self): - """ - Variable name. - - Returns: - str, variable name declared in init. - """ - return self._variable_name - - @variable_name.setter - def variable_name(self, v): - """ - Setter of variable name. - - Args: - v (str): Variable name. - - """ - self._variable_name = v - - @property - def module_name(self): - """ - Module name. - - Returns: - str, module name. - """ - if not self._module_name_frozen: - module_name = self.tag - return module_name - - return self._module_name - - def _froze_module_name(self, m): - """ - Once module_name is set, then it's unchangeable. - - Args: - m (str): Module name. - - """ - if not self._module_name_frozen: - self._module_name = m - self._module_name_frozen = True - @property def op_name(self): """ @@ -172,72 +119,47 @@ class PyTorchGraphNode(GraphNode): self._ipt_shape = input_shape self._opt_shape = output_shape - def to_code(self, ipt_args_in_construct: str, output_var: str): + def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: str, code_fragment): """ Generate statements. Args: + variable_name (str): Variable name. ipt_args_in_construct (str): Args of input. output_var (str): Output variable name in construct. + code_fragment (CodeFragment): CodeFragment instance. Returns: Union[str, str], declare in init and call in construct. """ - operator = self.op_in_ms or self.module_name - self._opt_var_name = output_var + operator = code_fragment.operation args = self.args_in_code - settings = self.settings_in_code + settings = code_fragment.code_setting - if self._node_type == NodeType.OPERATION.value and not self.convert_successful(): + if self._node_type == NodeType.OPERATION.value and not is_converted(code_fragment.operation): args.update({"input_shape": self.input_shape, "output_shape": self.output_shape}) if self._node_type == NodeType.OPERATION.value: - expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}" + expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}" for k, v in args.items()]) - ipt_args_settings_in_construct = self._generate_ipt_args_settings_in_construct(ipt_args_in_construct, - settings) + ipt_args_settings_in_construct = self._generate_ipt_args_settings_in_construct( + ipt_args_in_construct, settings) else: # When it's type is module, class or func, # it's not necessary to replace var. - expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}" + expr = ", ".join([f"{k.replace(f'_{variable_name}', '')}={v}" for k, v in args.items()]) ipt_args_settings_in_construct = ipt_args_in_construct - declare = f"self.{self._variable_name} = {operator}({expr})" - call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_settings_in_construct})" - - return declare, call - - @staticmethod - def _generate_ipt_args_settings_in_construct(ipt_args_in_construct, settings): - """ - Generate input with args and settings in construct. - - Args: - ipt_args_in_construct(str): input args in construct. - settings(dict): settings in operator. - - """ - if settings.get('input_type'): - input_type = settings['input_type'] - if input_type == InputType.TENSOR.value: - ipt_args_settings_in_construct = ipt_args_in_construct - elif input_type == InputType.LIST.value: - ipt_args_settings_in_construct = f"({ipt_args_in_construct})" - else: - raise NodeInputTypeNotSupport(f"Input type[{input_type}] is not supported now.") - else: - ipt_args_settings_in_construct = ipt_args_in_construct + if SEPARATOR_IN_ONNX_OP in operator: + operator = operator.replace(SEPARATOR_IN_ONNX_OP, ".") - if settings.get('values'): - settings_value = settings['values'] - if settings_value: - settings_in_construct = ', '.join([f"{setting_val}" for _, setting_val in settings_value.items()]) - ipt_args_settings_in_construct = ', '.join((ipt_args_settings_in_construct, settings_in_construct)) + declare = f"self.{variable_name} = {operator}({expr})" + call = f"{output_var} = self.{variable_name}({ipt_args_settings_in_construct})" - return ipt_args_settings_in_construct + return declare, call def to_ir(self): """ @@ -288,62 +210,3 @@ class PyTorchGraphNode(GraphNode): var = var.replace(LEFT_BUCKET, SEPARATOR_BTW_NAME_AND_ID).replace( RIGHT_BUCKET, "") return var - - def param_transform(self, mapper: Mapper): - """ - Transform torch params into mindspore. - - Args: - mapper (Mapper): Mapper of params. - - """ - if self._node_type != NodeType.OPERATION.value: - args = deepcopy(self._args_in_code) - self._args_in_code = dict() - for arg, value in args.items(): - self._args_in_code[self._get_arg_name(arg)] = value - return None, None, None - - if not self.transformed: - _, _, _ = super(PyTorchGraphNode, self).param_transform(mapper) - - for arg, value in self._params_in_ms.items(): - self._args_in_code[self._get_arg_name(arg)] = value - - for arg, value in self._settings_in_ms.items(): - self._settings_in_code[arg] = value - - self.transformed = True - - return self._op_in_ms, self._params_in_ms, self._settings_in_ms - - def froze_node_type_and_module_name(self, node_type, module_name): - """ - Froze node type and module name. - - After node_type is frozen, then the `module_name` - will be affected when `node_type` is `class`. - Thus, this line must be placed before `nd_inst.data.module_name`. - - Args: - module_name: Modified module name. - node_type (str): Node type, class of func. - - """ - if not self._type_frozen: - self._node_type = node_type - self._type_frozen = True - - if not self._module_name_frozen: - self._froze_module_name(module_name) - - def convert_successful(self): - """ - Whether convert successfully. - - Returns: - bool, true or false. - """ - if self._op_in_ms and SEPARATOR_IN_ONNX_OP not in self._op_in_ms: - return True - return False diff --git a/tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_hierarchical_tree.py b/tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_hierarchical_tree.py index 7f5987fb..45775f86 100644 --- a/tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_hierarchical_tree.py +++ b/tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_hierarchical_tree.py @@ -42,7 +42,7 @@ class TestHierarchicalTree: get_raw_params.return_value = [] tree = HierarchicalTree() pt_node = PyTorchGraphNode() - tree.insert(pt_node, 'ResNet', (1, 3, 224, 224), (1, 64, 112, 112)) + tree.insert(pt_node, 'ResNet') assert tree.root == 'ResNet' def test_remove(self): diff --git a/tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py b/tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py index 8f8d21d3..1dd6aa51 100644 --- a/tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py +++ b/tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py @@ -17,11 +17,13 @@ import numpy as np import pytest from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper +from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting from tests.utils import mindspore class TestMappers: """Test Mappers.""" + @pytest.mark.parametrize('params', [{ 'input': {'op_name': 'onnx::Conv', 'params': {'dilations': [1, 1], @@ -38,7 +40,7 @@ class TestMappers: 'pad_mode': '\"pad\"', 'dilation': (1, 1), 'group': 1}, - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::Conv', 'params': {'dilations': [1, 1], @@ -55,7 +57,7 @@ class TestMappers: 'pad_mode': '\"valid\"', 'dilation': (1, 1), 'group': 1}, - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::Gemm', 'params': dict(), @@ -65,7 +67,7 @@ class TestMappers: 'converted_params': {'in_channels': 3, 'out_channels': 10, 'has_bias': True}, - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::BatchNormalization', 'params': {'epsilon': 1e-5, @@ -76,14 +78,14 @@ class TestMappers: 'converted_params': {'num_features': 6, 'eps': 1e-5, 'momentum': 0.9}, - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::Relu', 'params': dict(), 'weights': dict()}, 'expected_output': {'converter_name': 'nn.ReLU', 'converted_params': dict(), - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::MaxPool', 'params': {'kernel_shape': [3, 3], @@ -94,7 +96,7 @@ class TestMappers: 'converted_params': {'kernel_size': (3, 3), 'stride': (2, 2), 'pad_mode': '"same"'}, - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::AveragePool', 'params': {'kernel_shape': [3, 3], @@ -105,7 +107,7 @@ class TestMappers: 'converted_params': {'kernel_size': (3, 3), 'stride': (2, 2), 'pad_mode': '"same"'}, - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::GlobalAveragePool', 'params': {'input_shape': (1, 3, 10, 10), @@ -113,21 +115,21 @@ class TestMappers: 'weights': ''}, 'expected_output': {'converter_name': 'nn.AvgPool2d', 'converted_params': {'kernel_size': (10, 10)}, - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::Flatten', 'params': dict(), 'weights': dict()}, 'expected_output': {'converter_name': 'nn.Flatten', 'converted_params': dict(), - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::Add', 'params': dict(), 'weights': dict()}, 'expected_output': {'converter_name': 'P.TensorAdd', 'converted_params': dict(), - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::Pad', 'params': {'pads': [0, 1, 2, 3], @@ -137,7 +139,7 @@ class TestMappers: 'expected_output': {'converter_name': 'nn.Pad', 'converted_params': {'paddings': ((0, 2), (1, 3)), 'mode': '\"CONSTANT\"'}, - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::Pad', 'params': {'pads': [0, 1, 2, 3], @@ -146,7 +148,7 @@ class TestMappers: 'expected_output': {'converter_name': 'nn.Pad', 'converted_params': {'paddings': ((0, 2), (1, 3)), 'mode': '\"REFLECT\"'}, - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::Pad', 'params': {'pads': [0, 1, 2, 3], @@ -156,7 +158,7 @@ class TestMappers: 'expected_output': {'converter_name': 'nn.Pad', 'converted_params': {'paddings': ((0, 2), (1, 3)), 'mode': '{UNSUPPORTED: value is NOT 0}\"CONSTANT\"'}, - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::Pad', 'params': {'pads': [0, 1, 2, 3], @@ -165,7 +167,7 @@ class TestMappers: 'expected_output': {'converter_name': 'nn.Pad', 'converted_params': {'paddings': ((0, 2), (1, 3)), 'mode': '{UNSUPPORTED: \"edge\"}\"UNKNOWN\"'}, - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::ReduceMean', 'params': {'keepdims': 0, @@ -196,14 +198,14 @@ class TestMappers: 'weights': dict()}, 'expected_output': {'converter_name': 'nn.ReLU6', 'converted_params': dict(), - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::Clip', 'params': dict(), 'weights': dict()}, 'expected_output': {'converter_name': 'nn.ReLU', 'converted_params': dict(), - 'converted_settings': dict()} + 'converted_settings': Setting()} }, { 'input': {'op_name': 'onnx::Clip', 'params': {'max': 3, @@ -211,13 +213,13 @@ class TestMappers: 'weights': dict()}, 'expected_output': {'converter_name': None, 'converted_params': dict(), - 'converted_settings': dict()} + 'converted_settings': Setting()} }]) def test_mapper(self, params): """Test mapper function.""" mapper = ONNXToMindSporeMapper() - converter_name, converted_params, converted_settings = \ + converter_name, converted_params, converted_settings, _ = \ mapper.convert(params['input']['op_name'], params['input']['params'], params['input']['weights']) assert params['expected_output']['converter_name'] == converter_name assert params['expected_output']['converted_params'] == converted_params - assert params['expected_output']['converted_settings'] == converted_settings + assert isinstance(converted_settings, Setting)