Add ops and nn mappers Replace import method in onnx_utils Support multi args in statement generation Sub graph search path bug fix Add shape check in onnx_utilstags/v1.1.0
| @@ -13,14 +13,42 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Hierarchical tree module.""" | |||
| import re | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| from .hierarchical_tree import HierarchicalTree | |||
| from ..third_party_graph.onnx_graph_node import OnnxGraphNode | |||
| __all__ = [ | |||
| "HierarchicalTreeFactory" | |||
| ] | |||
| def _tf_model_node_name_reformat(node: OnnxGraphNode, node_name): | |||
| """ | |||
| Rename the node name by combining scope name and its original name. | |||
| Args: | |||
| node (OnnxGraphNode): OnnxGraphNode instance. | |||
| node_name (str): node name saved in Graph. | |||
| Returns: | |||
| str, re-formatted node name. | |||
| """ | |||
| scope_name = node.scope_name | |||
| new_name = None | |||
| parent = "" | |||
| regex = r"(?P<parent>.+/)(?P<op>\w+)" | |||
| match = re.match(regex, scope_name) | |||
| parent = match.group("parent") | |||
| node_name = '$' + node_name + '$' | |||
| if scope_name: | |||
| new_name = parent + node_name | |||
| if new_name: | |||
| return new_name | |||
| return node_name | |||
| class HierarchicalTreeFactory: | |||
| """Hierarchical tree factory.""" | |||
| @@ -36,6 +64,7 @@ class HierarchicalTreeFactory: | |||
| HierarchicalTree, tree. | |||
| """ | |||
| tree = HierarchicalTree() | |||
| node_scope_name = dict() | |||
| for _, node_name in enumerate(graph.nodes_in_topological_order): | |||
| node_inst = graph.get_node(node_name) | |||
| node_input = graph.get_input_shape(node_name) | |||
| @@ -44,6 +73,13 @@ class HierarchicalTreeFactory: | |||
| err_msg = f"This model is not supported now. " \ | |||
| f"Cannot find {node_name}'s input shape." | |||
| log.error(err_msg) | |||
| if isinstance(node_inst, OnnxGraphNode): | |||
| node_name_with_scope = _tf_model_node_name_reformat( | |||
| node_inst, node_name) | |||
| node_scope_name[node_name] = node_name_with_scope | |||
| node_name = node_name_with_scope | |||
| tree.insert(node_inst, node_name, node_input, node_output) | |||
| if node_scope_name: | |||
| return tree, node_scope_name | |||
| return tree | |||
| @@ -27,6 +27,7 @@ from mindinsight.mindconverter.common.log import logger as log | |||
| from .name_mgr import ModuleNameMgr, GlobalVarNameMgr | |||
| from ..mapper.base import Mapper | |||
| from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode | |||
| from ..third_party_graph.onnx_graph_node import OnnxGraphNode | |||
| from ..constant import SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT, CodeFormatConfig | |||
| from ..constant import NEW_LINE, SECOND_LEVEL_INDENT | |||
| from ..constant import NodeType | |||
| @@ -59,6 +60,8 @@ class HierarchicalTree(Tree): | |||
| # Manage variable name in a module. | |||
| self._vars_mgr_in_module = dict() | |||
| self._module_vars = dict() | |||
| # scope name mapping record for easy node searching | |||
| self._scope_name_map = dict() | |||
| @property | |||
| def tree_identifier(self): | |||
| @@ -70,13 +73,23 @@ class HierarchicalTree(Tree): | |||
| """ | |||
| return self.identifier | |||
| def insert(self, node: PyTorchGraphNode, node_name: str, input_shape, output_shape): | |||
| def get_node(self, nid): | |||
| """Override get_node method to support tf ver. generated scope.""" | |||
| if nid is None or not self.contains(nid): | |||
| if self._scope_name_map and nid in self._scope_name_map: | |||
| nid = self._scope_name_map.get(nid) | |||
| else: | |||
| return None | |||
| return self._nodes[nid] | |||
| def insert(self, node: Union[PyTorchGraphNode, OnnxGraphNode], | |||
| node_name: str, input_shape, output_shape): | |||
| """ | |||
| Insert node into hierarchical tree. | |||
| Args: | |||
| node_name (str): Node name. | |||
| node (PyTorchGraphNode): Node to be inserted. | |||
| node (Union[PyTorchGraphNode, OnnxGraphNode]): Node to be inserted. | |||
| output_shape (tuple): Output tensor shape. | |||
| input_shape (tuple): Input tensor shape. | |||
| @@ -102,7 +115,12 @@ class HierarchicalTree(Tree): | |||
| if not self.contains(identifier): | |||
| # Insert node into tree. | |||
| tgt_node = node if idx == len(scopes) - 1 else PyTorchGraphNode() | |||
| if isinstance(node, OnnxGraphNode): | |||
| tgt_node = node if idx == len( | |||
| scopes) - 1 else OnnxGraphNode() | |||
| else: | |||
| tgt_node = node if idx == len( | |||
| scopes) - 1 else PyTorchGraphNode() | |||
| tgt_node.successor_nodes = node.successor_nodes | |||
| tgt_node.precursor_nodes = node.precursor_nodes | |||
| tgt_node.node_type = (NodeType.OPERATION if idx == len(scopes) - 1 | |||
| @@ -154,7 +172,8 @@ class HierarchicalTree(Tree): | |||
| def save_source_files(self, out_folder: str, mapper: Mapper, | |||
| model_name: str, | |||
| report_folder: str = None) -> NoReturn: | |||
| report_folder: str = None, | |||
| scope_name_map: dict = None) -> NoReturn: | |||
| """ | |||
| Save source codes to target folder. | |||
| @@ -165,6 +184,8 @@ class HierarchicalTree(Tree): | |||
| out_folder (str): Output folder. | |||
| """ | |||
| if scope_name_map: | |||
| self._scope_name_map = scope_name_map | |||
| try: | |||
| self._adjust_structure() | |||
| code_fragments = self._generate_codes(mapper) | |||
| @@ -217,7 +238,8 @@ class HierarchicalTree(Tree): | |||
| Node, node. | |||
| """ | |||
| if module_key in self._merged_module_args: | |||
| node = self._clear_unused_args(node, self._merged_module_args[module_key]) | |||
| node = self._clear_unused_args( | |||
| node, self._merged_module_args[module_key]) | |||
| else: | |||
| node.data.clear_args_of_declaration() | |||
| return node | |||
| @@ -341,12 +363,21 @@ class HierarchicalTree(Tree): | |||
| nd_inst = self._preprocess_node_args(nd_inst, module_key) | |||
| # 4. Post-process child node args. | |||
| for _, scsr_nd_name in enumerate(nd_inst.successors(self.tree_identifier)): | |||
| self._postprocess_node_args(self.get_node(scsr_nd_name), module_key) | |||
| self._postprocess_node_args( | |||
| self.get_node(scsr_nd_name), module_key) | |||
| # 5. Generate code. | |||
| snippets.add(func(nd_inst, nd_inst.data.module_name, module_key)) | |||
| snippets.add( | |||
| func(nd_inst, nd_inst.data.module_name, module_key)) | |||
| code_blocks.extend(snippets) | |||
| if self._scope_name_map: # from tf. conversion | |||
| c_blocks = [] | |||
| for s in code_blocks: | |||
| s = s.replace('$', '') | |||
| c_blocks.append(s) | |||
| code_blocks = c_blocks | |||
| formatted_code, _ = FormatCode("".join(code_blocks), | |||
| style_config=CodeFormatConfig.PEP8.value) | |||
| report_generator = ReportGenerator() | |||
| @@ -469,8 +500,16 @@ class HierarchicalTree(Tree): | |||
| # Generate code statement. | |||
| init, construct = self._generate_stat(nd_inst, node, idx) | |||
| construct_block.append(construct) | |||
| init_block.append(init) | |||
| # support multiple construct and init block returns: | |||
| if isinstance(construct, list): | |||
| construct_block += construct | |||
| else: | |||
| construct_block.append(construct) | |||
| if isinstance(init, list): | |||
| init_block += init | |||
| else: | |||
| init_block.append(init) | |||
| class_construct = f"{NEW_LINE}{FIRST_LEVEL_INDENT}def construct(self, x):" \ | |||
| f"{NEW_LINE}{SECOND_LEVEL_INDENT}" | |||
| @@ -507,7 +546,8 @@ class HierarchicalTree(Tree): | |||
| if idx != 0: | |||
| # Get previous node output variable name. | |||
| ipt_args_in_construct = self._get_previous_opt_var(cur_nd_inst, pre_nd_inst) | |||
| ipt_args_in_construct = self._get_previous_opt_var( | |||
| cur_nd_inst, pre_nd_inst) | |||
| if idx != len(pre_nd_inst.successors(self.tree_identifier)) - 1: | |||
| # Set opt variable name. | |||
| opt_arg_in_construct = cur_nd_inst.data.opt_var_name | |||
| @@ -652,7 +692,8 @@ class HierarchicalTree(Tree): | |||
| nd_inst.data.variable_name = self._module_vars[module_key][idx] | |||
| else: | |||
| variable_name = nd_inst.data.op_name or nd_inst.data.module_name | |||
| variable_name = self._vars_mgr_in_module[module_key].get_name(variable_name) | |||
| variable_name = self._vars_mgr_in_module[module_key].get_name( | |||
| variable_name) | |||
| nd_inst.data.variable_name = variable_name | |||
| # Generation of params must behind variable assigment. | |||
| @@ -662,7 +703,8 @@ class HierarchicalTree(Tree): | |||
| module_settings.update(nd_inst.data.settings_in_code) | |||
| if not created: | |||
| self._module_vars[module_key].append(nd_inst.data.variable_name) | |||
| self._module_vars[module_key].append( | |||
| nd_inst.data.variable_name) | |||
| node.data.args_in_code = module_args | |||
| @@ -727,5 +769,8 @@ class HierarchicalTree(Tree): | |||
| Returns: | |||
| str, imported module. | |||
| """ | |||
| return f"from mindspore import nn{NEW_LINE}" \ | |||
| return f"import numpy as np{NEW_LINE}" \ | |||
| f"import mindspore{NEW_LINE}" \ | |||
| f"from mindspore import nn{NEW_LINE}" \ | |||
| f"from mindspore import Tensor{NEW_LINE}" \ | |||
| f"from mindspore.ops import operations as P{NEW_LINE * 3}" | |||
| @@ -104,9 +104,11 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||
| return None, dict(), dict() | |||
| try: | |||
| converter_name = op_name_converter(params=params, weights=weights, op_name=op_name) | |||
| converter_name = op_name_converter( | |||
| params=params, weights=weights, op_name=op_name) | |||
| converted_params = params_converter(params=params, weights=weights) | |||
| converted_weights = weights_converter(weights=weights) if weights else dict() | |||
| converted_weights = weights_converter( | |||
| weights=weights) if weights else dict() | |||
| converted_params.update(converted_weights) | |||
| converted_settings = settings_converter(params=params) | |||
| except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e: | |||
| @@ -28,9 +28,9 @@ class BatchNormMapper(ONNXToMindSporeMapper): | |||
| def _convert_params(**kwargs): | |||
| params = kwargs['params'] | |||
| return { | |||
| 'num_features': params['output_shape'][1], | |||
| 'eps': params['epsilon'], | |||
| 'momentum': params['momentum'] | |||
| 'num_features': params.get('output_shape')[1], | |||
| 'eps': params.get('epsilon', 1e-5), | |||
| 'momentum': params.get('momentum', 0.9) | |||
| } | |||
| @staticmethod | |||
| @@ -13,24 +13,33 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| import re | |||
| import numpy as np | |||
| from ...base import ONNXToMindSporeMapper | |||
| class ConvMapper(ONNXToMindSporeMapper): | |||
| """Conv2d mapper.""" | |||
| def _convert_padding(**kwargs): | |||
| """Convert padding.""" | |||
| params = kwargs['params'] | |||
| if not params.get('pads'): | |||
| return '\"same\"', 0 | |||
| if sum(params['pads']) == 0: | |||
| return '\"valid\"', 0 | |||
| pads_onnx = params['pads'] | |||
| half_index = len(pads_onnx) // 2 | |||
| padding = [] | |||
| for num_begin, num_end in zip(pads_onnx[:half_index], pads_onnx[half_index:]): | |||
| padding += [num_begin, num_end] | |||
| return '\"pad\"', tuple(padding) | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| weight = kwargs['weights']['weight'].numpy() | |||
| dim = weight.ndim - 2 | |||
| return f"nn.Conv{dim}d" | |||
| class ConvMapper(ONNXToMindSporeMapper): | |||
| """Conv2d mapper.""" | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| def convert_params_torch(**kwargs): | |||
| """Convert params from PyTorch to MindSpore""" | |||
| weights = kwargs['weights'] | |||
| params = kwargs['params'] | |||
| weight = weights['weight'].numpy() | |||
| weight = np.transpose(weight, list(range(2, weight.ndim)) + [1, 0]) | |||
| if isinstance(params['dilations'], list): | |||
| @@ -49,7 +58,7 @@ class ConvMapper(ONNXToMindSporeMapper): | |||
| kernel_size = kernel_size[0] | |||
| else: | |||
| kernel_size = tuple(kernel_size) | |||
| pad_mode, padding = ConvMapper._convert_padding(params=params) | |||
| pad_mode, padding = _convert_padding(params=params) | |||
| return { | |||
| 'in_channels': in_channels, | |||
| 'out_channels': out_channels, | |||
| @@ -61,21 +70,74 @@ class ConvMapper(ONNXToMindSporeMapper): | |||
| 'group': params['group']} | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| def convert_params_tf(**kwargs): | |||
| """Convert params from Tensorflow to MindSpore""" | |||
| weights = kwargs['weights'] | |||
| params = kwargs['params'] | |||
| # regex to find Conv weight | |||
| regex = r".+\/Conv2D\/ReadVariableOp:0$" | |||
| weight = None | |||
| for w_name, w in weights.items(): | |||
| if re.match(regex, w_name): | |||
| weight = w | |||
| break | |||
| if weight is None: | |||
| raise ValueError("Conv. Mapper cannot get the weight.") | |||
| # tmp tf translated ver. mapping | |||
| if isinstance(params.get('dilations'), list): | |||
| dilation = tuple(params.get('dilations')) | |||
| else: | |||
| dilation = params.get('dilations') | |||
| if isinstance(params.get('strides'), list): | |||
| stride = tuple(params.get('strides')) | |||
| else: | |||
| stride = params.get('strides') | |||
| kernel_size = params.get('kernel_shape') | |||
| in_channels = weight.shape[1] | |||
| out_channels = weight.shape[0] | |||
| if len(kernel_size) == 1: | |||
| kernel_size = kernel_size[0] | |||
| else: | |||
| kernel_size = tuple(kernel_size) | |||
| pad_mode, padding = _convert_padding(params=params) | |||
| return { | |||
| 'in_channels': in_channels, | |||
| 'out_channels': out_channels, | |||
| 'kernel_size': kernel_size, | |||
| 'stride': stride, | |||
| 'padding': padding, | |||
| 'pad_mode': pad_mode, | |||
| 'dilation': dilation, | |||
| 'group': params.get('group', 1)} | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| if not kwargs['weights'].get('weight'): # is from tf | |||
| kernel_size = kwargs['params'].get('kernel_shape') | |||
| dim = len(kernel_size) | |||
| return f"nn.Conv{dim}d" | |||
| weight = kwargs['weights']['weight'].numpy() | |||
| dim = weight.ndim - 2 | |||
| return f"nn.Conv{dim}d" | |||
| @staticmethod | |||
| def _convert_padding(**kwargs): | |||
| """Convert padding.""" | |||
| def _convert_params(**kwargs): | |||
| weights = kwargs['weights'] | |||
| params = kwargs['params'] | |||
| if sum(params['pads']) == 0: | |||
| return '\"valid\"', 0 | |||
| pads_onnx = params['pads'] | |||
| half_index = len(pads_onnx) // 2 | |||
| padding = [] | |||
| for num_begin, num_end in zip(pads_onnx[:half_index], pads_onnx[half_index:]): | |||
| padding += [num_begin, num_end] | |||
| return '\"pad\"', tuple(padding) | |||
| if not weights.get('weight'): # is from tf | |||
| return ConvMapper.convert_params_tf(params=params, weights=weights) | |||
| return ConvMapper.convert_params_torch(params=params, weights=weights) | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| @@ -0,0 +1,36 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| class MatMulMapper(ONNXToMindSporeMapper): | |||
| """MatMul mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "nn.MatMul" | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return dict() | |||
| @@ -16,6 +16,25 @@ | |||
| from ...base import ONNXToMindSporeMapper | |||
| def _padding_format_convert(padding: list): | |||
| """Convert Onnx padding format to Mindspore""" | |||
| num = len(padding) | |||
| if num % 2 != 0: | |||
| raise ValueError(f"Padding list should be even length but got {num}") | |||
| low = 0 | |||
| mid = num // 2 | |||
| lst = [] | |||
| ms_pad_front = low | |||
| ms_pad_back = mid | |||
| while ms_pad_front < mid and ms_pad_back < num: | |||
| lst.append((padding[ms_pad_front], padding[ms_pad_back])) | |||
| ms_pad_front += 1 | |||
| ms_pad_back += 1 | |||
| return tuple(lst) | |||
| class PadMapper(ONNXToMindSporeMapper): | |||
| """Pad mapper.""" | |||
| @@ -26,16 +45,24 @@ class PadMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| params = kwargs['params'] | |||
| if params['mode'] == 'constant': | |||
| mode = params.get('mode', 'constant') | |||
| if mode == 'constant' and params.get('value') is None: | |||
| if params.get('pads'): | |||
| pads_onnx = params.get('pads') | |||
| if isinstance(pads_onnx, list): | |||
| paddings = _padding_format_convert(pads_onnx) | |||
| return {'paddings': paddings, | |||
| 'mode': '\"CONSTANT\"'} | |||
| if mode == 'constant': | |||
| if params['value'] == 0: | |||
| mode = '\"CONSTANT\"' | |||
| else: | |||
| msg = "{UNSUPPORTED: value is NOT 0}\"CONSTANT\"" | |||
| mode = msg | |||
| elif params['mode'] == 'reflect': | |||
| elif mode == 'reflect': | |||
| mode = '\"REFLECT\"' | |||
| else: | |||
| msg = f"{{UNSUPPORTED: \"{params['mode']}\"}}\"UNKNOWN\"" | |||
| msg = f"{{UNSUPPORTED: \"{mode}\"}}\"UNKNOWN\"" | |||
| mode = msg | |||
| pads_onnx = params['pads'] | |||
| half_index = len(pads_onnx) // 2 | |||
| @@ -0,0 +1,40 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| class SoftmaxMapper(ONNXToMindSporeMapper): | |||
| """Softmax mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "nn.Softmax" | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| params = kwargs.get('params') | |||
| converted_params = {} | |||
| if params.get('axis'): | |||
| converted_params['axis'] = params.get('axis') | |||
| return converted_params | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return dict() | |||
| @@ -0,0 +1,43 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| class TransposeMapper(ONNXToMindSporeMapper): | |||
| """Transpose mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "P.Transpose" | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| converted_params = {} | |||
| params = kwargs.get('params') | |||
| perm = params.get('perm') | |||
| if perm and isinstance(perm, list): | |||
| perm = tuple(perm) | |||
| converted_params['input_perm'] = perm | |||
| return {'values': converted_params} | |||
| @@ -11,5 +11,8 @@ | |||
| "onnx::Pad": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.pad_mapper.PadMapper", | |||
| "onnx::ReduceMean": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.reduce_mean_mapper.ReduceMeanMapper", | |||
| "onnx::Concat": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.concat_mapper.ConcatMapper", | |||
| "onnx::Clip": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.relu_mapper.ReLUMapper" | |||
| "onnx::Clip": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.relu_mapper.ReLUMapper", | |||
| "onnx::Transpose": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.transpose_mapper.TransposeMapper", | |||
| "onnx::MatMul": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.mat_mul_mapper.MatMulMapper", | |||
| "onnx::Softmax": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.softmax_mapper.SoftmaxMapper" | |||
| } | |||
| @@ -98,7 +98,7 @@ class MergedONNXNode(BaseNode): | |||
| """Define merged onnx node.""" | |||
| def __init__(self, name, module_name, ori_nodes): | |||
| super(MergedONNXNode, self).__init__(name, module_name) | |||
| super(MergedONNXNode, self).__init__(node_name=name, op_type=module_name) | |||
| self.nodes = ori_nodes | |||
| def get_name(self): | |||
| @@ -16,13 +16,17 @@ | |||
| from .base import Graph | |||
| from .pytorch_graph import PyTorchGraph | |||
| from .pytorch_graph_node import PyTorchGraphNode | |||
| from .onnx_graph import OnnxGraph | |||
| from .onnx_graph_node import OnnxGraphNode | |||
| class GraphFactory: | |||
| """Graph factory.""" | |||
| @classmethod | |||
| def init(cls, graph_path: str, sample_shape: tuple, checkpoint: str = None): | |||
| def init(cls, graph_path: str, | |||
| input_nodes: str, output_nodes: str, | |||
| sample_shape: tuple): | |||
| """ | |||
| Init an instance of graph. | |||
| @@ -34,8 +38,9 @@ class GraphFactory: | |||
| Returns: | |||
| Graph, graph instance. | |||
| """ | |||
| if checkpoint: | |||
| pass | |||
| if all([input_nodes, output_nodes]): | |||
| return OnnxGraph.load(model_path=graph_path, input_nodes=input_nodes, | |||
| output_nodes=output_nodes, sample_shape=sample_shape) | |||
| return PyTorchGraph.load(model_path=graph_path, sample_shape=sample_shape) | |||
| @@ -26,7 +26,7 @@ class GraphParser(metaclass=abc.ABCMeta): | |||
| @classmethod | |||
| @abc.abstractmethod | |||
| def parse(cls, model_path: str): | |||
| def parse(cls, model_path: str, **kwargs): | |||
| """Parse graph into readable format.""" | |||
| @@ -54,18 +54,19 @@ class BaseGraph(metaclass=abc.ABCMeta): | |||
| @staticmethod | |||
| @abc.abstractmethod | |||
| def load_graph(graph_path: str): | |||
| def load_graph(graph_path: str, **kwargs): | |||
| """Load graph file.""" | |||
| @classmethod | |||
| @abc.abstractmethod | |||
| def load(cls, model_path: str, sample_shape: tuple = None, | |||
| checkpoint: str = None): | |||
| checkpoint: str = None, **kwargs): | |||
| """Factory method to initialize an graph object.""" | |||
| def __new__(cls, *args, **kwargs): | |||
| """Control the create action of graph.""" | |||
| model_param = args[0] if args else kwargs.get(cls._REQUIRED_PARAM_OF_MODEL) | |||
| model_param = args[0] if args else kwargs.get( | |||
| cls._REQUIRED_PARAM_OF_MODEL) | |||
| if not model_param: | |||
| error = ValueError(f"`{cls._REQUIRED_PARAM_OF_MODEL}` " | |||
| f"can not be None.") | |||
| @@ -229,12 +230,12 @@ class Graph(BaseGraph, abc.ABC): | |||
| raise NotImplementedError | |||
| @staticmethod | |||
| def load_graph(graph_path: str): | |||
| def load_graph(graph_path: str, **kwargs): | |||
| raise NotImplementedError | |||
| @classmethod | |||
| def load(cls, model_path: str, sample_shape: tuple = None, | |||
| checkpoint: str = None) -> BaseGraph: | |||
| checkpoint: str = None, **kwargs) -> BaseGraph: | |||
| """ | |||
| Load third party graph, metadata and checkpoint. | |||
| @@ -245,12 +246,19 @@ class Graph(BaseGraph, abc.ABC): | |||
| model_path (str): Graph or model file path. | |||
| sample_shape (tuple): Input shape of the model. | |||
| checkpoint (str): Checkpoint file path. | |||
| input_nodes (list[str]): list of input nodes' name | |||
| output_nodes (list[str]): list of output nodes' name | |||
| Returns: | |||
| cls, graph instance. | |||
| """ | |||
| src_graph = cls.load_graph(graph_path=model_path) | |||
| ckpt = cls.load_checkpoint(ckpt_path=checkpoint) if checkpoint else None | |||
| tf_input_nodes = kwargs.get('input_nodes') | |||
| tf_output_nodes = kwargs.get('output_nodes') | |||
| src_graph = cls.load_graph(graph_path=model_path, | |||
| input_nodes=tf_input_nodes, | |||
| output_nodes=tf_output_nodes) | |||
| ckpt = cls.load_checkpoint( | |||
| ckpt_path=checkpoint) if checkpoint else None | |||
| if ckpt is not None: | |||
| # Create an instance of TensorflowGraph. | |||
| @@ -258,7 +266,7 @@ class Graph(BaseGraph, abc.ABC): | |||
| checkpoint=ckpt) | |||
| # Create an instance of PyTorchGraph. | |||
| return cls(model=src_graph, sample_shape=sample_shape) | |||
| return cls(src_graph, sample_shape=sample_shape) | |||
| class GraphNode(abc.ABC): | |||
| @@ -22,7 +22,7 @@ class PyTorchGraphParser(GraphParser): | |||
| """Define pytorch graph parser.""" | |||
| @classmethod | |||
| def parse(cls, model_path: str): | |||
| def parse(cls, model_path: str, **kwargs): | |||
| """ | |||
| Parser pytorch graph. | |||
| @@ -61,3 +61,52 @@ class PyTorchGraphParser(GraphParser): | |||
| raise Exception(error_msg) | |||
| return model | |||
| class TFGraphParser(GraphParser): | |||
| """Define TF graph parser.""" | |||
| @classmethod | |||
| def parse(cls, model_path: str, **kwargs): | |||
| """ | |||
| Parse TF Computational Graph File (.pb) | |||
| Args: | |||
| model_path (str): Model file path. | |||
| Returns: | |||
| object, ONNX model. | |||
| """ | |||
| from .onnx_utils import convert_tf_graph_to_onnx | |||
| tf_input_nodes = kwargs.get('input_nodes') | |||
| tf_output_nodes = kwargs.get('output_nodes') | |||
| if not os.path.exists(model_path): | |||
| error = FileNotFoundError("`model_path` must be assigned with " | |||
| "an existed file path.") | |||
| log.error(str(error)) | |||
| log.exception(error) | |||
| raise error | |||
| try: | |||
| model = convert_tf_graph_to_onnx(model_path, | |||
| model_inputs=tf_input_nodes, | |||
| model_outputs=tf_output_nodes, | |||
| ) # need pass more args | |||
| except ModuleNotFoundError: | |||
| error_msg = \ | |||
| "Cannot find model scripts in system path, " \ | |||
| "set `--project_path` to the path of model scripts folder correctly." | |||
| error = ModuleNotFoundError(error_msg) | |||
| log.error(error_msg) | |||
| log.exception(error) | |||
| raise error | |||
| except Exception as e: | |||
| error_msg = "Error occurs in loading model, make sure model.pb correct." | |||
| log.error(error_msg) | |||
| log.exception(e) | |||
| raise Exception(error_msg) | |||
| return model | |||
| @@ -0,0 +1,207 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Define ONNX graph.""" | |||
| from typing import Dict, NoReturn | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| from .base import Graph | |||
| from .input_node import InputNode | |||
| from .onnx_graph_node import OnnxGraphNode | |||
| from .graph_parser import TFGraphParser | |||
| from .onnx_utils import OnnxDataLoader | |||
| NONE_SCOPE_OP = { | |||
| "onnx::Add": "Add", | |||
| "onnx::Flatten": "Flatten", | |||
| "onnx::Concat": "Concat", | |||
| "onnx::Squeeze": "Squeeze", | |||
| "onnx::Unsqueeze": "Unsqueeze", | |||
| } | |||
| def normalize_node_name(node): | |||
| """ | |||
| Rename the node name by removing :0 | |||
| Args: | |||
| node (Node, str): ONNX node instance or node name string. | |||
| Returns: | |||
| str, normalized node name. | |||
| """ | |||
| if isinstance(node, str): | |||
| return node.split(':')[0] | |||
| return node.name.split(':')[0] | |||
| class OnnxGraph(Graph): | |||
| """ | |||
| Define ONNX graph. | |||
| Args: | |||
| model (onnx.ModelProto): Onnx defined model proto. | |||
| sample_shape (tuple): Input shape of the model. | |||
| """ | |||
| def __init__(self, model, sample_shape: tuple = None): | |||
| super(OnnxGraph, self).__init__(model=model) | |||
| self.build(sample_shape) | |||
| def _extract_shape(self, shape): | |||
| """ | |||
| Extract shape from string-type shape. | |||
| Args: | |||
| shape (str): Shape value in string-type. | |||
| Returns: | |||
| list, shape. | |||
| """ | |||
| if "," not in shape: | |||
| return [] | |||
| shape_arr = [] | |||
| for s in shape.split(","): | |||
| s = s.strip() | |||
| if not s: | |||
| return [] | |||
| if ":" in s: | |||
| s = s.split(":")[0] | |||
| s = s.replace("!", "") | |||
| if not s.isdigit(): | |||
| return [] | |||
| shape_arr.append(int(s)) | |||
| return shape_arr | |||
| def _build_connection(self, src, tgt) -> NoReturn: | |||
| """ | |||
| Build connection between source node and target node. | |||
| Args: | |||
| src (str): Source node name. | |||
| tgt (str): Target node name. | |||
| """ | |||
| # If src and tgt are the same node, src not in node_collection or | |||
| # tgt not in node_collection, then skip this edge. | |||
| src = normalize_node_name(src) | |||
| tgt = normalize_node_name(tgt) | |||
| if src == tgt or src not in self._nodes_collection or tgt not in self._nodes_collection: | |||
| if src.split(':')[0] not in self._nodes_collection: | |||
| log.warning( | |||
| "Graph construct a self-loop node %s. Ignored.", src) | |||
| return | |||
| if tgt not in self._nodes_collection[src.split(':')[0]].successor_nodes: | |||
| self._nodes_collection[src.split( | |||
| ':')[0]].successor_nodes.append(tgt) | |||
| if src not in self._nodes_collection[tgt].precursor_nodes: | |||
| self._nodes_collection[tgt.split( | |||
| ':')[0]].precursor_nodes.append(src) | |||
| def build(self, input_shape=None): | |||
| """ | |||
| Build graph tree. | |||
| Args: | |||
| input_shape (tuple): Input shape of model. Default: None | |||
| """ | |||
| model_data = OnnxDataLoader(self.model, graph_input_shape=input_shape) | |||
| from ..sub_graph_searcher import generate_scope_name | |||
| scope_name_list = generate_scope_name(model_data) | |||
| self._shape_dict = model_data.normalize_dict_key( | |||
| model_data.node_output_shape_dict) | |||
| for ind, (node_name, node) in enumerate(model_data.nodes_dict.items()): | |||
| node_weight = {} | |||
| node.scope_name = scope_name_list[ind] | |||
| inputs = node.input_name_list | |||
| # check each input from node or tensors | |||
| for i in inputs: | |||
| if i in model_data.tensor_name_set: | |||
| tensor = model_data.tensors_dict[i] | |||
| t_name = tensor.name | |||
| t_value = tensor.to_array() | |||
| node_weight[t_name] = t_value | |||
| self._nodes_collection[node_name] = OnnxGraphNode( | |||
| node, node_weight) | |||
| self._nodes_record[node_name] = node_name | |||
| for node_input in node.input_name_list: | |||
| self._build_connection(node_input, node_name) | |||
| super(OnnxGraph, self).build(input_shape=input_shape) | |||
| self._collect_input_shape_of_each_node( | |||
| input_shape) # diff than pyTorch | |||
| def _collect_input_shape_of_each_node(self, input_shape): | |||
| """ | |||
| Collect input tensor shape of each node. | |||
| Args: | |||
| input_shape (tuple): Input shape. | |||
| """ | |||
| input_node = InputNode(input_shape) | |||
| input_node_name = "{}InputNode" | |||
| for node_name, node in self._nodes_collection.items(): | |||
| if node_name in self._input_nodes: | |||
| ipt_nd_name = input_node_name.format(input_node.scope_name) | |||
| input_node.set_scope_name(node.scope_name) | |||
| node.precursor_nodes.insert(0, ipt_nd_name) | |||
| input_node.set_successor_nodes(node_name) | |||
| self._shape_dict[ipt_nd_name] = input_node.output_shape | |||
| ipt_shape = [] | |||
| for p_nd in node.precursor_nodes: | |||
| shp = self._shape_dict.get(p_nd) | |||
| ipt_shape.append(tuple(shp)) | |||
| self._input_shape[node_name] = ipt_shape[0] if len( | |||
| ipt_shape) == 1 else ipt_shape | |||
| def sub_graph_merging(self): | |||
| raise NotImplementedError() | |||
| @staticmethod | |||
| def load_checkpoint(ckpt_path: str) -> Dict: | |||
| raise NotImplementedError() | |||
| @staticmethod | |||
| def load_metadata(**kwargs): | |||
| raise NotImplementedError() | |||
| @staticmethod | |||
| def load_graph(graph_path: str, **kwargs): | |||
| """ | |||
| Load graph. | |||
| Note: | |||
| The input/output nodes are optional for | |||
| tf saved model format. But required for .pb & .ckpt | |||
| Args: | |||
| graph_path (str): Graph path. | |||
| tf_input_nodes (str): input nodes of tf graph | |||
| tf_output_nodes (str): output nodes of tf graph | |||
| Returns: | |||
| object, ONNX model. | |||
| """ | |||
| tf_input_nodes = kwargs.get('input_nodes') | |||
| tf_output_nodes = kwargs.get('output_nodes') | |||
| onnx_model = TFGraphParser.parse(graph_path, | |||
| input_nodes=tf_input_nodes, | |||
| output_nodes=tf_output_nodes) | |||
| return onnx_model | |||
| @@ -0,0 +1,378 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Define ONNX graph node.""" | |||
| from copy import deepcopy | |||
| from .base import GraphNode | |||
| from ..constant import NodeType, SEPARATOR_IN_SCOPE, \ | |||
| SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, SEPARATOR_IN_ONNX_OP | |||
| from ..mapper.base import Mapper | |||
| class OnnxGraphNode(GraphNode): | |||
| """ | |||
| ONNX Graph Node. | |||
| Args: | |||
| node (OnnxNode): OnnxNode Object. | |||
| weight (dict): Dictionary records weight and bias. | |||
| """ | |||
| _type_frozen = False | |||
| _module_name_frozen = False | |||
| def __init__(self, node=None, weight=None): | |||
| super(OnnxGraphNode, self).__init__(node=node) | |||
| self._op_params = self._get_raw_params(node.raw_node) if node else None | |||
| self._op_name = "onnx::" + node.op_type if node else None | |||
| self._scope_name = node.scope_name if node else None | |||
| self._opt_var_name = None | |||
| self._variable_name = self._extract_var_name(self._scope_name) | |||
| self._module_name = None | |||
| self._weight = weight | |||
| def clear_args_of_declaration(self): | |||
| """Clear `self._args_in_code`.""" | |||
| self._args_in_code = dict() | |||
| def _get_arg_name(self, arg): | |||
| """ | |||
| Get arg name. | |||
| Args: | |||
| arg (str): Generate arg name. | |||
| Returns: | |||
| str, arg name in function or class declaration. | |||
| """ | |||
| return f"{arg}_{self._variable_name}" | |||
| @property | |||
| def hash_key(self): | |||
| """ | |||
| Return unique hash key of current node. | |||
| Returns: | |||
| str, hash key. | |||
| """ | |||
| if self._node_type not in {NodeType.CLASS.value, | |||
| NodeType.FUNC.value, | |||
| NodeType.MODULE.value}: | |||
| self._hash_key = self._op_name.lower() | |||
| return self._hash_key | |||
| @hash_key.setter | |||
| def hash_key(self, h): | |||
| """ | |||
| Setter of hash key. | |||
| Args: | |||
| h (str): Key. | |||
| """ | |||
| self._hash_key = h | |||
| @property | |||
| def variable_name(self): | |||
| """ | |||
| Variable name. | |||
| Returns: | |||
| str, variable name declared in init. | |||
| """ | |||
| return self._variable_name | |||
| @variable_name.setter | |||
| def variable_name(self, v): | |||
| """ | |||
| Setter of variable name. | |||
| Args: | |||
| v (str): Variable name. | |||
| """ | |||
| self._variable_name = v | |||
| @property | |||
| def module_name(self): | |||
| """ | |||
| Module name. | |||
| Returns: | |||
| str, module name. | |||
| """ | |||
| if not self._module_name_frozen: | |||
| module_name = self.tag | |||
| return module_name | |||
| return self._module_name | |||
| def _froze_module_name(self, m): | |||
| """ | |||
| Once module_name is set, then it's unchangeable. | |||
| Args: | |||
| m (str): Module name. | |||
| """ | |||
| if not self._module_name_frozen: | |||
| self._module_name = m | |||
| self._module_name_frozen = True | |||
| @property | |||
| def op_name(self): | |||
| """ | |||
| Op name in onnx. | |||
| Returns: | |||
| str, op name | |||
| """ | |||
| return self._op_name | |||
| @property | |||
| def real_name(self): | |||
| return | |||
| def add_input_and_output_shape(self, input_shape, output_shape): | |||
| """ | |||
| Add the node input shape. | |||
| Args: | |||
| output_shape (tuple): Output tensor shape. | |||
| input_shape (tuple): Input tensor shape. | |||
| """ | |||
| self._ipt_shape = input_shape | |||
| self._opt_shape = output_shape | |||
| def _add_tensor_args_to_code(self, op_name: str, t_identifier: str, declare, args): | |||
| """ | |||
| Add nn used tensors to args in init and construct blocks. | |||
| Args: | |||
| op_name (str): Add the tensor to args if the current node has this | |||
| op_name. | |||
| t_identifier (str): The unique strinf appeared in the target tensor | |||
| name. | |||
| declare_s (str): Declare statement generated in to_code(). | |||
| init_s (str): init statement generated in to_code(). | |||
| Returns: | |||
| declare_list list, multiple declare statements. | |||
| input_args list, multiple input args generated statements. | |||
| """ | |||
| if not self._op_name == op_name: | |||
| return declare, args | |||
| declare_list = [] | |||
| tensor = None | |||
| # find target tensor | |||
| for t_name, t_value in self._weight.items(): | |||
| if t_identifier in t_name: | |||
| tensor = t_value | |||
| break | |||
| if tensor is None: | |||
| return declare, args | |||
| declare_list.append(declare) | |||
| declare_t = f"self.{self._variable_name}_w = Tensor(" \ | |||
| f"np.random.uniform(0, 1, {str(tensor.shape)}), mindspore.float32)" | |||
| declare_list.append(declare_t) | |||
| args += f", self.{self._variable_name}_w" | |||
| return declare_list, args | |||
| def to_code(self, ipt_args_in_construct: str, output_var: str): | |||
| """ | |||
| Generate statements. | |||
| Args: | |||
| ipt_args_in_construct (str): Args of input. | |||
| output_var (str): Output variable name in construct. | |||
| Returns: | |||
| Union[str, str], declare in init and call in construct. | |||
| """ | |||
| operator = self.op_in_ms or self.module_name | |||
| self._opt_var_name = output_var | |||
| args = self.args_in_code | |||
| settings = self.settings_in_code | |||
| if self._node_type == NodeType.OPERATION.value and not self.convert_successful(): | |||
| args.update({"input_shape": self.input_shape, | |||
| "output_shape": self.output_shape}) | |||
| if self._node_type == NodeType.OPERATION.value: | |||
| expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}" | |||
| for k, v in args.items()]) | |||
| ipt_args_settings_in_construct = \ | |||
| self._generate_ipt_args_settings_in_construct( | |||
| ipt_args_in_construct, | |||
| settings) | |||
| else: | |||
| # When it's type is module, class or func, | |||
| # it's not necessary to replace var. | |||
| expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}" | |||
| for k, v in args.items()]) | |||
| ipt_args_settings_in_construct = ipt_args_in_construct | |||
| declare = f"self.{self._variable_name} = {operator}({expr})" | |||
| # Extra Tensor generator for nn.MatMul | |||
| declare, ipt_args_settings_in_construct = self._add_tensor_args_to_code( | |||
| 'onnx::MatMul', 'MatMul', declare, ipt_args_settings_in_construct) | |||
| # Extra Tensor generator for onnx::BiasAdd | |||
| declare, ipt_args_settings_in_construct = self._add_tensor_args_to_code( | |||
| 'onnx::MatMul', 'BiasAdd', declare, ipt_args_settings_in_construct) | |||
| call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_settings_in_construct})" | |||
| return declare, call | |||
| @staticmethod | |||
| def _generate_ipt_args_settings_in_construct(ipt_args_in_construct, settings): | |||
| """ | |||
| Generate input with args and settings in construct. | |||
| Args: | |||
| ipt_args_in_construct(str): Input args in construct. | |||
| settings(dict): Settings in operator. | |||
| Returns: | |||
| str, args of each node in generated construct statement. | |||
| """ | |||
| if settings.get('input_type'): | |||
| input_type = settings['input_type'] | |||
| if input_type == InputType.TENSOR.value: | |||
| ipt_args_settings_in_construct = ipt_args_in_construct | |||
| elif input_type == InputType.LIST.value: | |||
| ipt_args_settings_in_construct = f"({ipt_args_in_construct})" | |||
| else: | |||
| raise NodeInputTypeNotSupport( | |||
| f"Input type[{input_type}] is not supported now.") | |||
| else: | |||
| ipt_args_settings_in_construct = ipt_args_in_construct | |||
| if settings.get('values'): | |||
| settings_value = settings['values'] | |||
| if settings_value: | |||
| settings_in_construct = ', '.join( | |||
| [f"{setting_val}" for _, setting_val in settings_value.items()]) | |||
| ipt_args_settings_in_construct = ', '.join( | |||
| (ipt_args_settings_in_construct, settings_in_construct)) | |||
| return ipt_args_settings_in_construct | |||
| def to_ir(self): | |||
| """No need to implement for now.""" | |||
| raise NotImplementedError | |||
| def _get_raw_params(self, node): | |||
| """ | |||
| Get params in onnx. | |||
| Note: parameters are attributes in node. | |||
| Args: | |||
| node (onnx.NodeProto): Onnx defined node proto. | |||
| Returns: | |||
| dict, raw params. | |||
| """ | |||
| import onnx | |||
| raw_params = dict() | |||
| if not node: | |||
| return raw_params | |||
| for attribute in node.attribute: | |||
| name = attribute.name | |||
| value = onnx.helper.get_attribute_value(attribute) | |||
| raw_params[name] = value | |||
| return raw_params | |||
| def replace_with_arg(self, src_arg, tgt_arg): | |||
| """ | |||
| Replace actual parameter with formal parameter. | |||
| Args: | |||
| src_arg (str): Original arg name. | |||
| tgt_arg (str): Target arg name. | |||
| """ | |||
| self._args_in_code[src_arg] = tgt_arg | |||
| @staticmethod | |||
| def _extract_var_name(scope_name: str): | |||
| """Extract variable name from scope name.""" | |||
| if not scope_name: | |||
| return None | |||
| var = scope_name.split(SEPARATOR_IN_SCOPE)[-1].lower() | |||
| var = var.replace(LEFT_BUCKET, SEPARATOR_BTW_NAME_AND_ID).replace( | |||
| RIGHT_BUCKET, "") | |||
| return var | |||
| def param_transform(self, mapper: Mapper): | |||
| """ | |||
| Transform torch params into mindspore. | |||
| Args: | |||
| mapper (Mapper): Mapper of params. | |||
| """ | |||
| if self._node_type != NodeType.OPERATION.value: | |||
| args = deepcopy(self._args_in_code) | |||
| self._args_in_code = dict() | |||
| for arg, value in args.items(): | |||
| self._args_in_code[self._get_arg_name(arg)] = value | |||
| return None, None | |||
| if not self.transformed: | |||
| _, _, _ = super(OnnxGraphNode, self).param_transform(mapper) | |||
| for arg, value in self._params_in_ms.items(): | |||
| self._args_in_code[self._get_arg_name(arg)] = value | |||
| for arg, value in self._settings_in_ms.items(): | |||
| self._settings_in_code[arg] = value | |||
| self.transformed = True | |||
| return self._op_in_ms, self._params_in_ms, self._settings_in_ms | |||
| def froze_node_type_and_module_name(self, node_type, module_name): | |||
| """ | |||
| Froze node type and module name. | |||
| After node_type is frozen, then the `module_name` | |||
| will be affected when `node_type` is `class`. | |||
| Thus, this line must be placed before `nd_inst.data.module_name`. | |||
| Args: | |||
| module_name: Modified module name. | |||
| node_type (str): Node type, class of func. | |||
| """ | |||
| if not self._type_frozen: | |||
| self._node_type = node_type | |||
| self._type_frozen = True | |||
| if not self._module_name_frozen: | |||
| self._froze_module_name(module_name) | |||
| def convert_successful(self): | |||
| """ | |||
| Whether convert successfully. | |||
| Returns: | |||
| bool, true or false. | |||
| """ | |||
| if self._op_in_ms and SEPARATOR_IN_ONNX_OP not in self._op_in_ms: | |||
| return True | |||
| return False | |||
| @@ -15,7 +15,9 @@ | |||
| """Define ONNX related operations.""" | |||
| import re | |||
| import abc | |||
| from importlib import import_module | |||
| from collections import OrderedDict | |||
| from typing import Union | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| @@ -39,10 +41,15 @@ def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=None | |||
| Returns: | |||
| onnx.ModelProto, onnx defined model proto. | |||
| """ | |||
| import tensorflow as tf | |||
| from tf2onnx.tfonnx import process_tf_graph | |||
| from tf2onnx import constants, utils, optimizer | |||
| from tf2onnx import tf_loader | |||
| tf = import_module('tensorflow') | |||
| tf2onnx = import_module("tf2onnx") | |||
| tfonnx = getattr(tf2onnx, "tfonnx") | |||
| process_tf_graph = getattr(tfonnx, "process_tf_graph") | |||
| constants = getattr(tf2onnx, "constants") | |||
| utils = getattr(tf2onnx, "utils") | |||
| optimizer = getattr(tf2onnx, "optimizer") | |||
| tf_loader = getattr(tf2onnx, "tf_loader") | |||
| target = ",".join(constants.DEFAULT_TARGET) | |||
| shape_override = None | |||
| @@ -80,13 +87,9 @@ class OnnxTensor: | |||
| """ | |||
| Define Onnx Tensor structure for convenience. | |||
| Note: | |||
| parameter from_nodes and to_nodes. | |||
| Args: | |||
| raw_tensor (onnx.TensorProto): onnx.TensorProto instance. | |||
| """ | |||
| import onnx | |||
| def __init__(self, raw_tensor): | |||
| self.raw_tensor = raw_tensor | |||
| @@ -97,7 +100,8 @@ class OnnxTensor: | |||
| self.to_nodes = [] | |||
| def to_array(self): | |||
| """Convert binary data to np.array""" | |||
| onnx = import_module("onnx") | |||
| # Convert binary data to np.array | |||
| return onnx.numpy_helper.to_array(self.raw_tensor) | |||
| @@ -136,7 +140,6 @@ class ParamsAttribute: | |||
| for attribute in attrs: | |||
| self.attribute_name_list.append(attribute.name) | |||
| type_num = attribute.type | |||
| # get attribute value by determining its type | |||
| # Can Convert to np.array if needed | |||
| if type_num == ONNX_TYPE_INTS: | |||
| @@ -219,8 +222,8 @@ class OnnxNode(BaseNode): | |||
| self.raw_node = raw_node | |||
| self.params = ParamsAttribute(raw_node.attribute, raw_node) | |||
| self.scope_name = None | |||
| self.input_name_list = raw_node.input | |||
| self.output_name_list = raw_node.output | |||
| self.input_name_list = getattr(raw_node, 'input') | |||
| self.output_name_list = getattr(raw_node, 'output') | |||
| class OnnxDataLoader: | |||
| @@ -238,11 +241,11 @@ class OnnxDataLoader: | |||
| Default: True | |||
| """ | |||
| def __init__(self, onnx_model, infer_shape=True): | |||
| def __init__(self, onnx_model, graph_input_shape: Union[tuple, list] = None, infer_shape=True): | |||
| self.model = onnx_model | |||
| self.graph = onnx_model.graph | |||
| self.nodes = onnx_model.graph.node | |||
| self.graph_input_shape = graph_input_shape | |||
| # args for init | |||
| self._is_infer_shape = infer_shape | |||
| @@ -251,9 +254,7 @@ class OnnxDataLoader: | |||
| self.nodes_dict = OrderedDict() # {node_name: OnnxNode} NO INPUT NODE | |||
| self.tensors_dict = {} # {tensor_name: OnnxTensor} | |||
| self.weight_dict = {} # {tensor_name: OnnxTensor} NOT USED | |||
| self.bias_dict = {} # {tensor_name: OnnxTensor} NOT USED | |||
| # {node_name : (type, dim)} NO INPUT & OUTPUT NODE! | |||
| # {node_name : (type, dim)} NO INPUT & OUTPUT NODE | |||
| self.value_info_dict = {} | |||
| self.tensor_name_set = set() # [str] | |||
| @@ -265,6 +266,20 @@ class OnnxDataLoader: | |||
| def _check_initialization(self): | |||
| """Define conditions checked before init.""" | |||
| if all([self.model, self.graph, self.nodes]): | |||
| if self.graph_input_shape is None: # do not check | |||
| return True | |||
| onnx = import_module("onnx") | |||
| # check input shape eligible | |||
| input_node = getattr(self.graph, 'input')[0] | |||
| type_str = onnx.helper.printable_type(input_node.type) | |||
| regex = r".*(unk.+)x(?P<h>\d+)x(?P<w>\d+)x(?P<c>\d+)" | |||
| match = re.match(regex, type_str) | |||
| h = int(match.group('h')) | |||
| w = int(match.group('w')) | |||
| c = int(match.group('c')) | |||
| if [h, w, c] != list(self.graph_input_shape)[1:4]: | |||
| raise ValueError( | |||
| f"Shape given should be (N, {h}, {w}, {c}) but got {self.graph_input_shape}") | |||
| return True | |||
| return False | |||
| @@ -276,12 +291,12 @@ class OnnxDataLoader: | |||
| The method will be replaced by self-implemented | |||
| in future development. | |||
| """ | |||
| import onnx | |||
| onnx = import_module("onnx") | |||
| self.inferred_model = onnx.shape_inference.infer_shapes(self.model) | |||
| def _parse_value_info(self): # no input node & output node | |||
| """Parse onnx defined value_info class attributes.""" | |||
| import onnx | |||
| """Parse onnx defined value_info class attribtues""" | |||
| onnx = import_module("onnx") | |||
| def _parse_value_info_re(i): | |||
| """ | |||
| @@ -341,7 +356,11 @@ class OnnxDataLoader: | |||
| # replace unknown shape by '-1' | |||
| for s in shape_list: | |||
| if 'unk' in s: | |||
| s = '-1' | |||
| if self.graph_input_shape is not None: | |||
| s = self.graph_input_shape[0] | |||
| else: | |||
| s = '1' | |||
| # convert str to int | |||
| s = int(s) | |||
| lst.append(s) | |||