| @@ -78,6 +78,10 @@ TESTS*.xml | |||
| # vscode settings | |||
| .vscode | |||
| # OS files | |||
| *.DS_Store | |||
| package-lock.json | |||
| build/* | |||
| @@ -0,0 +1,18 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Graph based scripts converter definition.""" | |||
| from .framework import graph_based_converter | |||
| __all__ = ["graph_based_converter"] | |||
| @@ -0,0 +1,42 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Constant definition.""" | |||
| from enum import Enum, unique | |||
# Separator between namespace and op type in onnx op names, e.g. "onnx::conv".
SEPARATOR_IN_ONNX_OP = "::"
# Separator between scope levels in a scoped node name, e.g. "a/b/c".
SEPARATOR_IN_SCOPE = "/"
# Separator between a scope name and its numeric id suffix.
SEPARATOR_BTW_NAME_AND_ID = "_"
LINK_IN_SCOPE = "-"
LEFT_BUCKET = "["
RIGHT_BUCKET = "]"
BLANK_SYM = " "
# Indentation units used when emitting generated python code.
FIRST_LEVEL_INDENT = BLANK_SYM * 4
SECOND_LEVEL_INDENT = BLANK_SYM * 8
NEW_LINE = "\n"
@unique
class CodeFormatConfig(Enum):
    """Style configs accepted by the code formatter (yapf)."""

    PEP8 = "pep8"
@unique
class NodeType(Enum):
    """Role of a node inside the hierarchical tree."""

    MODULE = "module"
    OPERATION = "operation"
    CLASS = "class"
    FUNC = "func"
    INPUT = "DataInput"
| @@ -0,0 +1,101 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Graph based scripts converter workflow.""" | |||
import argparse
import functools
import os
from importlib.util import find_spec

import mindinsight
from .mapper import ONNXToMindSporeMapper
# Restrict files/dirs created by this process to owner access only:
# R|W|X == 0o7, so the umask masks out every group and other bit.
permissions = os.R_OK | os.W_OK | os.X_OK
os.umask(permissions << 3 | permissions)

# CLI definition for the graph based converter entry point.
parser = argparse.ArgumentParser(
    prog="MindConverter",
    description="Graph based MindConverter CLI entry point (version: {})".format(
        mindinsight.__version__)
)
parser.add_argument("--graph", type=str, required=True,
                    help="Third party framework's graph path.")
parser.add_argument("--sample_shape", nargs='+', type=int, required=True,
                    help="Input shape of the model.")
parser.add_argument("--ckpt", type=str, required=False,
                    help="Third party framework's checkpoint path.")
parser.add_argument("--output", type=str, required=True,
                    help="Generated scripts output folder path.")
parser.add_argument("--report", type=str, required=False,
                    help="Generated reports output folder path.")
def torch_installation_validation(func):
    """
    Decorate a converter entry point with a PyTorch availability check.

    Args:
        func (type): Function to wrap.

    Returns:
        type, wrapped function that raises ModuleNotFoundError when
        PyTorch is not installed.
    """

    @functools.wraps(func)
    def _f(graph_path: str, sample_shape: tuple,
           output_folder: str, report_folder: str = None,
           checkpoint_path: str = None):
        # Check whether pytorch is installed.
        if not find_spec("torch"):
            # Fixed message wording: "vision ... consisted" -> "version
            # ... consistent".
            raise ModuleNotFoundError("PyTorch is required when using graph based "
                                      "scripts converter, and PyTorch version must "
                                      "be consistent with model generation runtime.")
        # Propagate the wrapped function's return value; the original
        # implementation silently discarded it.
        return func(graph_path=graph_path, sample_shape=sample_shape,
                    output_folder=output_folder, report_folder=report_folder,
                    checkpoint_path=checkpoint_path)

    return _f
@torch_installation_validation
def graph_based_converter(graph_path: str, sample_shape: tuple,
                          output_folder: str, report_folder: str = None,
                          checkpoint_path: str = None):
    """
    Convert a third party model graph into MindSpore scripts.

    Args:
        graph_path (str): Graph file path.
        sample_shape (tuple): Input shape of the model.
        output_folder (str): Output folder.
        report_folder (str): Report output folder path.
        checkpoint_path (str): Checkpoint file path.
    """
    # Imported lazily so that argument errors can be reported even when
    # the optional graph backends are unavailable at import time.
    from .third_party_graph import GraphFactory

    graph = GraphFactory.init(graph_path, sample_shape=sample_shape,
                              checkpoint=checkpoint_path)
    tree = graph.to_hierarchical_tree()
    tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper,
                           report_folder=report_folder)
if __name__ == '__main__':
    # parse_known_args tolerates extra flags forwarded by outer tooling.
    args, _ = parser.parse_known_args()
    graph_based_converter(graph_path=args.graph,
                          sample_shape=args.sample_shape,
                          output_folder=args.output,
                          report_folder=args.report,
                          checkpoint_path=args.ckpt)
| @@ -0,0 +1,20 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Hierarchical tree module.""" | |||
| from .hierarchical_tree import HierarchicalTree | |||
| __all__ = [ | |||
| "HierarchicalTree" | |||
| ] | |||
| @@ -0,0 +1,687 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Define hierarchical tree.""" | |||
| import os | |||
| from copy import deepcopy | |||
| from typing import NoReturn, Union | |||
| from queue import Queue | |||
| from yapf.yapflib.yapf_api import FormatCode | |||
| from treelib import Tree, Node | |||
| from .name_mgr import ModuleNameMgr, GlobalVarNameMgr | |||
| from ..mapper.base import Mapper | |||
| from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode | |||
| from ..constant import SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT, CodeFormatConfig | |||
| from ..constant import NEW_LINE, SECOND_LEVEL_INDENT | |||
| from ..constant import NodeType | |||
| GLOBAL_VAR_NAME_MGR = GlobalVarNameMgr() | |||
class HierarchicalTree(Tree):
    """Hierarchical tree managing the model structure during conversion."""

    # Whether the root node has been created. NOTE(review): this is a
    # class attribute, shared across instances — confirm the tree is used
    # as a single instance per conversion.
    _root_created = False
    ROOT_LEVEL = 0

    def __init__(self):
        super(HierarchicalTree, self).__init__()
        # Mapping of depth -> list of node names at that depth (BFS order).
        self._hierarchical_order = dict()
        # Manage mapping of unique key and module name.
        self._merged_module = dict()
        # Manage mapping of unique key and module args.
        self._merged_module_args = dict()
        # Record creation of module with unique key.
        self._created_module = dict()
        # Manage module name to used.
        self._module_mgr = ModuleNameMgr()
        # Manage variable name in a module (module key -> name manager).
        self._args_mgr_in_module = dict()
        # Module key -> ordered list of child variable names.
        self._module_vars = dict()
    @property
    def tree_identifier(self):
        """
        Return identifier of the tree (delegates to treelib's identifier).

        Returns:
            tree, id of tree.
        """
        return self.identifier
    def insert(self, node: PyTorchGraphNode, node_name: str, input_shape, output_shape):
        """
        Insert node into hierarchical tree.

        Every scope segment of `node_name` becomes one tree level; missing
        intermediate levels are created as placeholder MODULE nodes.

        Args:
            node (PyTorchGraphNode): Node to be inserted.
            node_name (str): Scoped node name, e.g. "a/b/c".
            input_shape (tuple): Input tensor shape.
            output_shape (tuple): Output tensor shape.
        """
        node.add_input_and_output_shape(input_shape, output_shape)
        scopes = node_name.split(SEPARATOR_IN_SCOPE)
        for idx, scope in enumerate(scopes):
            parent = SEPARATOR_IN_SCOPE.join(scopes[:idx])
            identifier = SEPARATOR_IN_SCOPE.join(scopes[:idx + 1])
            # NOTE(review): the conditional looks inverted (joins separator
            # onto an empty parent) — confirm against upstream behavior.
            try_parent = f"{parent}{SEPARATOR_IN_SCOPE}{scope}" \
                if not parent else scope
            if self.contains(try_parent):
                # Whether current node existed.
                parent = try_parent
            if not parent and not self._root_created:
                # If no root node, then create it and mark it.
                parent = None
                self._root_created = True
            elif not parent and self._root_created:
                # Already have root node, skip it.
                continue
            if not self.contains(identifier):
                # Insert node into tree. The leaf scope carries the real
                # operation node; intermediate scopes get placeholders.
                tgt_node = node if idx == len(scopes) - 1 else PyTorchGraphNode()
                tgt_node.successor_nodes = node.successor_nodes
                tgt_node.precursor_nodes = node.precursor_nodes
                tgt_node.node_type = (NodeType.OPERATION if idx == len(scopes) - 1
                                      else NodeType.MODULE).value
                tgt_node.tag = scope.split(SEPARATOR_BTW_NAME_AND_ID)[0]
                tgt_node.variable_name = self._get_var_name(identifier)
                self.create_node(
                    tag=tgt_node.tag,
                    identifier=identifier,
                    parent=parent,
                    data=tgt_node
                )
| def remove(self, node: Node, keep_sub=False): | |||
| """ | |||
| Remove node into hierarchical tree. | |||
| Args: | |||
| node (Node): Node to be removed. | |||
| keep_sub (bool): Whether keep sub-tree. | |||
| """ | |||
| if not keep_sub: | |||
| self.remove_node(node.identifier) | |||
| return | |||
    def shrink(self, node: Node):
        """
        Shrink sub-tree into one node: replace a node with its only child.

        Args:
            node (Node): Node to be shrunk away (must have exactly one
                child; see sub_graph_merging).
        """
        node_name = node.identifier
        parent_node = self[node.predecessor(self.tree_identifier)]
        # Keep successors of parent.
        brothers = deepcopy(parent_node.successors(self.tree_identifier))
        child = node.successors(self.tree_identifier)[0]
        self.move_node(source=child,
                       destination=node.predecessor(self.tree_identifier))
        self.remove(node)
        # Splice the child into the slot the removed node occupied so the
        # original sibling order is preserved.
        brothers[brothers.index(node_name)] = child
        parent_node.set_successors(brothers, tree_id=self.tree_identifier)
| def save_source_files(self, out_folder: str, mapper: Mapper, | |||
| report_folder: str = None) -> NoReturn: | |||
| """ | |||
| Save source codes to target folder. | |||
| Args: | |||
| report_folder (str): Report folder. | |||
| mapper (Mapper): Mapper of third party framework and mindspore. | |||
| out_folder (str): Output folder. | |||
| """ | |||
| self._adjust_structure() | |||
| code_fragments = self._generate_codes(mapper) | |||
| out_folder = os.path.abspath(out_folder) | |||
| if not report_folder: | |||
| report_folder = out_folder | |||
| else: | |||
| report_folder = os.path.abspath(report_folder) | |||
| if not os.path.exists(out_folder): | |||
| os.makedirs(out_folder) | |||
| if not os.path.exists(report_folder): | |||
| os.makedirs(report_folder) | |||
| for file_name in code_fragments: | |||
| code, report = code_fragments[file_name] | |||
| with open(os.path.join(os.path.abspath(out_folder), | |||
| f"{file_name}.py"), "w") as file: | |||
| file.write(code) | |||
| with open(os.path.join(report_folder, | |||
| f"report_of_{file_name}.txt"), "w") as rpt_f: | |||
| rpt_f.write(report) | |||
| def _preprocess_node_args(self, node, module_key): | |||
| """ | |||
| Remove unused args. | |||
| Args: | |||
| node (Node): Node instance. | |||
| module_key (str): Nodule key. | |||
| Returns: | |||
| Node, node. | |||
| """ | |||
| if module_key in self._merged_module_args: | |||
| node = self._clear_unused_args(node, self._merged_module_args[module_key]) | |||
| else: | |||
| node.data.clear_args_of_declaration() | |||
| return node | |||
    def _postprocess_node_args(self, node, precursor_module_key):
        """
        Post process args in node after its parent module was merged.

        Args:
            node (Node): Node instance.
            precursor_module_key (str): Parent node module hash key.

        Returns:
            Node, processed node.
        """
        if node.data.node_type == NodeType.MODULE.value:
            # If current node is class or function, then
            # remove unused args in __init__.
            cur_module_key = node.data.hash_key or self.hash_key(node)
            if cur_module_key in self._merged_module_args:
                node = self._clear_unused_args(node,
                                               self._merged_module_args[cur_module_key])
        if precursor_module_key in self._merged_module_args:
            # If parent node is in `_merged_module_args`, then
            # replace current node args with arg name declared
            # in _merged_module_args.
            for arg in node.data.args_in_code.keys():
                if arg in self._merged_module_args[precursor_module_key]:
                    node.data.replace_with_arg(arg)
        return node
| @staticmethod | |||
| def _clear_unused_args(node, used_args): | |||
| """ | |||
| Clear unused args. | |||
| Args: | |||
| node (Node): Node. | |||
| used_args (list): Args list. | |||
| Returns: | |||
| Node, node instance. | |||
| """ | |||
| args_in_code = list(node.data.args_in_code.keys()) | |||
| for arg in args_in_code: | |||
| if arg not in used_args: | |||
| node.data.args_in_code.pop(arg) | |||
| return node | |||
    def _generate_codes(self, mapper):
        """
        Generate code files.

        - 1. Generate args.
        - 2. Merge module.
        - 3. Pre-process node args.
        - 4. Post-process child node args.
        - 5. Generate class/func code.
        - 6. Merge code snippets.

        Args:
            mapper (Mapper): Mapper of third party operation and mindspore.

        Returns:
            Dict, file name mapped to (code, report) pairs.
        """
        code_blocks = [self._get_imported_module()]
        # Walk the tree bottom-up so child modules are generated before
        # the parents that reference them.
        depths = sorted(list(self._hierarchical_order.keys()), reverse=True)
        for depth in depths:
            node_collection = self._hierarchical_order[depth]
            for node_name in node_collection:
                # Traverse nodes in topological order.
                node = self.get_node(node_name)
                # 1. Generate args for each node in this level.
                if node.data.node_type == NodeType.MODULE.value:
                    self._create_module_args_and_vars(node, mapper)
            # 2. Get nodes can be merged.
            self._module_merging(node_collection)
            # A set de-duplicates snippets of structurally identical modules.
            snippets = set()
            for node_name in node_collection:
                nd_inst = self.get_node(node_name)
                if nd_inst.data.node_type != NodeType.MODULE.value:
                    continue
                # Generate hash key for node.
                module_key = self.hash_key(nd_inst)
                # Get code generation func.
                func, node_type = self._fetch_func_and_type(nd_inst)
                if module_key in self._created_module:
                    # If the module has already been created,
                    # then assign the created module name to current node,
                    # and delete unused args.
                    module_name = self._created_module[module_key]
                    nd_inst.data.froze_node_type_and_module_name(node_type,
                                                                 module_name)
                    self._preprocess_node_args(nd_inst, module_key)
                    continue
                module_name = nd_inst.data.module_name
                if node_type == NodeType.CLASS.value:
                    # Generated class names are capitalized.
                    module_name = f"{module_name[0].upper()}{module_name[1:]}"
                # After node_type and module_name is frozen,
                # then it's unchangeable.
                module_name = self._module_mgr.get_name(module_name)
                nd_inst.data.froze_node_type_and_module_name(node_type,
                                                             module_name)
                # 3. Pre-process node args.
                nd_inst = self._preprocess_node_args(nd_inst, module_key)
                # 4. Post-process child node args.
                for scsr_nd_name in nd_inst.successors(self.tree_identifier):
                    self._postprocess_node_args(self.get_node(scsr_nd_name),
                                                module_key)
                # 5. Generate code.
                snippets.add(func(nd_inst, nd_inst.data.module_name, module_key))
            code_blocks.extend(snippets)
        # 6. Merge snippets and auto-format with yapf (PEP8 style).
        formatted_code, _ = FormatCode("".join(code_blocks),
                                       style_config=CodeFormatConfig.PEP8.value)
        return {"model": (formatted_code, "No report content.")}
    def _fetch_func_and_type(self, node) -> Union[object, str]:
        """
        Choose the code-generation strategy for a module node.

        Args:
            node (Node): Node.

        Returns:
            Union[object, str], a (generation function, node type value)
            pair. Modules whose children are all container types become
            plain functions; modules containing operations become classes.
        """

        def _is_func():
            """
            The correct thought is to check whether have more than one
            path in this block.
            """
            nonlocal node
            tgt_type = {NodeType.MODULE.value,
                        NodeType.FUNC.value, NodeType.CLASS.value}
            md_type_lst = [self.get_node(child).data.node_type
                           for child in node.successors(self.tree_identifier)]
            # Empty difference means every child is a container type.
            diff_set = set(md_type_lst) - tgt_type
            return not diff_set

        if _is_func():
            return self._generate_func_snippet, NodeType.FUNC.value
        return self._generate_class_snippet, NodeType.CLASS.value
    def _generate_func_snippet(self, node, func_name, func_key):
        """
        Generate a factory-function snippet returning an nn.SequentialCell.

        Args:
            node (Node): Node inst.
            func_name (str): Name of the generated function.
            func_key (str): Hash key of the module.

        Returns:
            str, code snippet.
        """
        definition = ""
        if func_key.lower() in self._merged_module_args and \
                self._merged_module_args[func_key.lower()]:
            # Varying args become the function's parameters.
            definition = ", ".join(self._merged_module_args[func_key.lower()])
        module_list = []
        for node_name in node.successors(self.tree_identifier):
            c_nd = self.get_node(node_name)
            operator = c_nd.data.op_in_ms or c_nd.data.module_name
            if c_nd.data.node_type != NodeType.OPERATION.value:
                # Reuse the already-generated module name when this child
                # is a previously created sub-module.
                hash_key = c_nd.data.hash_key or self.hash_key(c_nd)
                if hash_key in self._created_module:
                    operator = self._created_module[hash_key]
            args = c_nd.data.args_in_code
            if c_nd.data.node_type == NodeType.OPERATION.value and \
                    not c_nd.data.convert_successful():
                # Unconvertible ops keep their shapes in the emitted call so
                # the user can finish the conversion manually.
                args.update({"input_shape": c_nd.data.input_shape,
                             "output_shape": c_nd.data.output_shape})
            # Generate code statement.
            expr = ", ".join([f"{k}={v}" for k, v in args.items()])
            code_line = f"{operator}({expr})"
            module_list.append(code_line)
        body = f",{NEW_LINE}{SECOND_LEVEL_INDENT}".join(module_list)
        snippet = f"{FIRST_LEVEL_INDENT}module_list = [{NEW_LINE}" \
                  f"{SECOND_LEVEL_INDENT}{body}{NEW_LINE}" \
                  f"{FIRST_LEVEL_INDENT}]{NEW_LINE}" \
                  f"{FIRST_LEVEL_INDENT}return nn.SequentialCell(*module_list)"
        definition = f"def {func_name}({definition}):{NEW_LINE}"
        # Mark the structure has been created.
        self._created_module[func_key.lower()] = func_name
        return f"{definition}{snippet}{NEW_LINE * 3}"
    def _generate_class_snippet(self, node, class_name, class_key):
        """
        Generate an nn.Cell subclass code snippet.

        Args:
            node (Node): Node.
            class_name (str): Name of the generated class.
            class_key (str): Hash key of the module.

        Returns:
            str, code snippet.
        """
        super_call = f"super({class_name}, self).__init__()"
        if class_key.lower() in self._merged_module_args and \
                self._merged_module_args[class_key.lower()]:
            # Varying args become __init__ parameters.
            args = f"{', '.join(self._merged_module_args[class_key.lower()])}"
            class_init = f"{FIRST_LEVEL_INDENT}def __init__(self, " \
                         f"{args}):" \
                         f"{NEW_LINE}{SECOND_LEVEL_INDENT}" \
                         f"{super_call}{NEW_LINE}{SECOND_LEVEL_INDENT}"
        else:
            class_init = f"{FIRST_LEVEL_INDENT}def __init__(self):{NEW_LINE}{SECOND_LEVEL_INDENT}" \
                         f"{super_call}{NEW_LINE}{SECOND_LEVEL_INDENT}"
        init_block = []
        construct_block = []
        for idx, node_name in enumerate(node.successors(self.tree_identifier)):
            nd_inst = self.get_node(node_name)
            # Generate code statement.
            init, construct = self._generate_stat(nd_inst, node, idx)
            construct_block.append(construct)
            init_block.append(init)
        class_construct = f"{NEW_LINE}{FIRST_LEVEL_INDENT}def construct(self, x):{NEW_LINE}{SECOND_LEVEL_INDENT}"
        init_body = f"{NEW_LINE}{SECOND_LEVEL_INDENT}".join(init_block)
        csrt_body = f"{NEW_LINE}{SECOND_LEVEL_INDENT}".join(construct_block)
        csrt_rtn = f"{NEW_LINE}{SECOND_LEVEL_INDENT}return output{NEW_LINE}"
        cls_definition = f"class {class_name}(nn.Cell):{NEW_LINE * 2}"
        # Mark the structure has been created.
        self._created_module[class_key.lower()] = class_name
        return f"{cls_definition}" \
               f"{class_init}" \
               f"{init_body}{NEW_LINE}" \
               f"{class_construct}" \
               f"{csrt_body}{csrt_rtn}{NEW_LINE * 2}"
    def _generate_stat(self, cur_nd_inst, pre_nd_inst, idx):
        """
        Generate declaration and call statements for one child node.

        Args:
            cur_nd_inst (Node): Current node instance.
            pre_nd_inst (Node): Precursor node instance.
            idx (int): Index of cur node.

        Returns:
            Tuple[str, str], declare in init and call in construct.
        """
        ipt_args_in_construct = "x"
        opt_arg_in_construct = "output"
        if idx != 0:
            # Get previous node output variable name.
            ipt_args_in_construct = self._get_previous_opt_var(cur_nd_inst, pre_nd_inst)
        if idx != len(pre_nd_inst.successors(self.tree_identifier)) - 1:
            # Intermediate results use the node's own variable name; only
            # the last statement assigns to `output`.
            opt_arg_in_construct = cur_nd_inst.data.opt_var_name
        declare, call = cur_nd_inst.data.to_code(ipt_args_in_construct=ipt_args_in_construct,
                                                 output_var=opt_arg_in_construct)
        return declare, call
| @staticmethod | |||
| def _get_var_name(s): | |||
| """ | |||
| Get variable name using scope name. | |||
| Args: | |||
| s (str): String. | |||
| Returns: | |||
| str, variable name. | |||
| """ | |||
| return s.split(SEPARATOR_IN_SCOPE)[-1].lower().split(SEPARATOR_BTW_NAME_AND_ID)[0] | |||
    def _get_previous_opt_var(self, cur_nd, pre_nd):
        """
        Collect the output variable names feeding the current node.

        Args:
            cur_nd (Node): Current node.
            pre_nd (Node): Precursor (parent) node.

        Returns:
            str, comma separated needed var names.
        """
        ipt_lst = []
        if cur_nd.data.node_type == NodeType.OPERATION.value:
            for e in cur_nd.data.precursor_nodes:
                p_nd = self.get_node(e)
                if e not in pre_nd.successors(self.tree_identifier):
                    # The precursor lives in another scope: climb up its
                    # ancestors until one is a sibling of the current node,
                    # falling back to the module input "x" at the root.
                    while True:
                        if p_nd.identifier in pre_nd.successors(self.tree_identifier):
                            ipt_lst.append(p_nd.data.opt_var_name)
                            break
                        pre_nd_name = p_nd.predecessor(self.tree_identifier)
                        if not pre_nd_name:
                            ipt_lst.append("x")
                            break
                        p_nd = self.get_node(pre_nd_name)
                    continue
                ipt_lst.append(p_nd.data.opt_var_name)
        else:
            # Non-operation nodes take the output of the previous sibling
            # in declaration order.
            idx = pre_nd.successors(self.tree_identifier).index(cur_nd.identifier) - 1
            p_nd = self.get_node(pre_nd.successors(self.tree_identifier)[idx])
            ipt_lst.append(p_nd.data.opt_var_name)
        return ", ".join(ipt_lst)
    def hash_key(self, node):
        """
        Generate a structural hash key for a node.

        The key is the "->" joined sequence of child keys, so modules with
        identical child structure share a key and can be merged.

        Args:
            node (Node): Node.

        Returns:
            str, hash key.
        """
        scsr_topo_order = []
        for s in node.successors(self.tree_identifier):
            cur_nd = self.get_node(s)
            if cur_nd.data.hash_key:
                scsr_topo_order.append(cur_nd.data.hash_key)
                continue
            if cur_nd.data.node_type in {NodeType.MODULE.value,
                                         NodeType.FUNC.value,
                                         NodeType.CLASS.value}:
                scsr_topo_order.append(self.hash_key(cur_nd))
                continue
            # NOTE(review): operation children without a precomputed
            # hash_key contribute nothing to the key — confirm that
            # PyTorchGraphNode always sets hash_key for operations.
        unique_key = "->".join(scsr_topo_order)
        node.data.hash_key = unique_key
        return unique_key
| def _module_merging(self, nodes): | |||
| """ | |||
| Generate sub-module and corresponding params. | |||
| Args: | |||
| nodes (List[str]): Nodes name. | |||
| """ | |||
| merged_module = dict() | |||
| merged_module_args = dict() | |||
| for node_name in nodes: | |||
| nd_inst = self.get_node(node_name) | |||
| if nd_inst.data.node_type != NodeType.MODULE.value: | |||
| continue | |||
| module_key = self.hash_key(nd_inst) | |||
| if module_key not in merged_module: | |||
| merged_module[module_key] = [nd_inst.data.args_in_code] | |||
| else: | |||
| merged_module[module_key].append(nd_inst.data.args_in_code) | |||
| for module_key, module_args in merged_module.items(): | |||
| if module_key not in merged_module_args: | |||
| merged_module_args[module_key] = [] | |||
| # Take first element's args as base. | |||
| keys = module_args[0].keys() | |||
| for key in keys: | |||
| for i in range(1, len(module_args)): | |||
| if module_args[0][key] != module_args[i][key]: | |||
| merged_module_args[module_key].append(key) | |||
| break | |||
| self._merged_module.update(merged_module) | |||
| self._merged_module_args.update(merged_module_args) | |||
    def _create_module_args_and_vars(self, node, mapper):
        """
        Create module args and child variable names.

        Args:
            node (Node): Node on tree.
            mapper (Mapper): Mapper of params.
        """
        module_args = dict()
        module_key = self.hash_key(node)
        created = False
        if module_key not in self._args_mgr_in_module:
            # NOTE(review): every module key shares the single global name
            # manager instance GLOBAL_VAR_NAME_MGR — confirm this aliasing
            # is intended rather than one manager per module.
            self._args_mgr_in_module[module_key] = GLOBAL_VAR_NAME_MGR
            self._module_vars[module_key] = []
        else:
            created = True
        for idx, successor_name in enumerate(node.successors(self.tree_identifier)):
            nd_inst = self.get_node(successor_name)
            # Generate variable name here, then
            # to generate args.
            if created:
                # Reuse the variable names recorded when this module
                # structure was first seen.
                nd_inst.data.variable_name = self._module_vars[module_key][idx]
            else:
                variable_name = nd_inst.data.op_name or nd_inst.data.module_name
                variable_name = self._args_mgr_in_module[module_key].get_name(variable_name)
                nd_inst.data.variable_name = variable_name
            if nd_inst.data.node_type == NodeType.OPERATION.value:
                # Generation of params must behind variable assigment.
                nd_inst.data.param_transform(mapper)
            module_args.update(nd_inst.data.args_in_code)
            if not created:
                self._module_vars[module_key].append(nd_inst.data.variable_name)
        node.data.args_in_code = module_args
    @staticmethod
    def _create_operation_args(node, mapper):
        """
        Create operation args by transforming third-party params.

        Args:
            node (Node): Node on tree.
            mapper (Mapper): Mapper of params.
        """
        node.data.param_transform(mapper)
| def update_hierarchical_order(self) -> NoReturn: | |||
| """ | |||
| Update hierarchical order. | |||
| """ | |||
| hierarchical_order = dict() | |||
| queue = Queue() | |||
| queue.put(item=(self.root, self.ROOT_LEVEL), block=False) | |||
| while not queue.empty(): | |||
| node_name, cur_level = queue.get(block=False) | |||
| node_inst = self[node_name] | |||
| if cur_level not in hierarchical_order: | |||
| hierarchical_order[cur_level] = [] | |||
| hierarchical_order[cur_level].append(node_name) | |||
| for successor_name in node_inst.successors(self.tree_identifier): | |||
| queue.put(item=(successor_name, cur_level + 1), block=False) | |||
| self._hierarchical_order = hierarchical_order | |||
    def sub_graph_merging(self) -> NoReturn:
        """
        Shrink subtree: collapse single-child nodes with falsy data.

        Works bottom-up so chains of such nodes are fully flattened.
        """
        self.update_hierarchical_order()
        depths = sorted(list(self._hierarchical_order.keys()), reverse=True)
        for depth in depths:
            for node_name in self._hierarchical_order[depth]:
                node_inst = self[node_name]
                # Only nodes whose data evaluates falsy and that have a
                # single child are shrunk; other nodes are kept.
                if not node_inst.data and len(node_inst.successors(self.tree_identifier)) == 1:
                    self.shrink(node_inst)
    def _adjust_structure(self) -> NoReturn:
        """
        Adjust tree structure to generate source code: merge single-child
        sub-graphs, then refresh the hierarchical order.
        """
        self.sub_graph_merging()
        self.update_hierarchical_order()
| @staticmethod | |||
| def _get_imported_module(): | |||
| """ | |||
| Generate imported module header. | |||
| Returns: | |||
| str, imported module. | |||
| """ | |||
| return f"from mindspore import nn{NEW_LINE}" \ | |||
| f"from mindspore.ops import operations as P{NEW_LINE * 3}" | |||
| @@ -0,0 +1,98 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Name manager.""" | |||
| import abc | |||
class NameMgr(abc.ABC):
    """Base manager that hands out unique, suffixed names."""

    # Placeholder appended once per occurrence of a name; the list length
    # therefore counts how many times the name has been requested.
    PLACEHOLDER = 1

    def __init__(self):
        # Name -> list of placeholders (occurrence counter).
        self.record = dict()
        # Names in the order they were handed out.
        self.topo_order = []

    def get_name(self, old_name):
        """
        Get module/variable name.

        The first request returns the name unchanged; each later request
        appends an increasing numeric suffix ("x", "x1", "x2", ...).

        Args:
            old_name (str): Name.

        Returns:
            str, module name.
        """
        occurrences = self.record.setdefault(old_name, [])
        occurrences.append(self.PLACEHOLDER)
        suffix = "" if len(occurrences) == 1 else str(len(occurrences) - 1)
        new_name = f"{old_name}{suffix}"
        self.topo_order.append(new_name)
        return new_name
class ModuleNameMgr(NameMgr):
    """Module name manager.

    Uses NameMgr's suffix-numbering behavior unchanged; each instance
    keeps its own independent `record`/`topo_order` state.
    """
class VariableNameMgrInModule(NameMgr):
    """Variable name mgr for a module.

    Same numbering behavior as NameMgr, scoped to one module instance.
    """
# Module-wide registry mapping lowercased op type -> last assigned index.
# Shared by every GlobalVarNameMgr instance (state is global, not per-instance).
global_op_namespace = dict()
# Index assigned to a newly seen operation type.
START_IDX = 0
class GlobalVarNameMgr:
    """Global variable name mgr.

    Counters live in the module-level `global_op_namespace`, so names are
    unique across all instances of this class.
    """

    @staticmethod
    def _get_name(name):
        """Strip a namespace prefix such as `onnx::` from an op name."""
        parts = name.split("::")
        return parts[1] if len(parts) > 1 else name

    def get_name(self, op_type):
        """
        Get module/variable name.

        If the module already existed, then add a suffix to it.

        conv1 onnx::conv

        Args:
            op_type (str): Operator type in onnx.

        Returns:
            str, module name.
        """
        key = op_type.lower()
        if key in global_op_namespace:
            global_op_namespace[key] += 1
            suffix = str(global_op_namespace[key] - 1)
        else:
            global_op_namespace[key] = START_IDX
            suffix = ""
        return f"{self._get_name(key)}{suffix}"
| @@ -0,0 +1,20 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from .base import ONNXToMindSporeMapper | |||
| __all__ = [ | |||
| "ONNXToMindSporeMapper" | |||
| ] | |||
| @@ -0,0 +1,118 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| import abc | |||
| import importlib | |||
| import json | |||
| import os | |||
| from typing import Dict | |||
# File name of the ONNX-to-MindSpore operation mapping table.
CONFIG_JSON = "onnx_to_ms.json"
# Absolute path of the table, resolved next to this module.
OPERATION_TABLE = os.path.join(
    os.path.abspath(os.path.dirname(__file__)),
    CONFIG_JSON
)

# Load mapping table which key is operation name in ONNX and
# value is corresponding module path.
# Explicit encoding: JSON is UTF-8 by spec, independent of the locale.
with open(OPERATION_TABLE, encoding="utf-8") as file:
    TABLE = json.load(file)

# Define global func name.
# Names of the callbacks every mapper class must expose.
GET_OP_NAME = "_operation_name_in_ms"
GET_OP_PARAMS = "_convert_params"
GET_OP_WEIGHTS = "_convert_trained_weights"
class Mapper(metaclass=abc.ABCMeta):
    """Mapper between third-party-operation and MindSpore.

    Subclasses implement the three static hooks below; `convert` drives
    them to translate one operation.
    """

    @staticmethod
    @abc.abstractmethod
    def _operation_name_in_ms():
        """Corresponding operation name in mindspore."""

    @staticmethod
    @abc.abstractmethod
    def _convert_params(params):
        """Convert third party operation's param into MindSpore operation."""

    @staticmethod
    @abc.abstractmethod
    def _convert_trained_weights(weights):
        """Convert third party operation's weights into MindSpore operation."""

    @classmethod
    @abc.abstractmethod
    def convert(cls, op_name: str, params: Dict, weights: Dict = None):
        """Convert third party operation's param into MindSpore operation."""
class ONNXToMindSporeMapper(Mapper, abc.ABC):
    """ONNX operation to MindSpore."""

    @classmethod
    def convert(cls, op_name: str, params: Dict, weights: Dict = None):
        """
        Convert third party operation's param into MindSpore operation.

        Args:
            op_name (str): Operation name in ONNX.
            params (dict): Params in onnx.
            weights (dict): Weights in onnx.

        Returns:
            Tuple[str, dict], operation name and params; (None, {}) when
            the operation has no registered mapper or conversion fails.
        """
        # TABLE is only read here; no `global` statement is needed.
        module_name = TABLE.get(op_name)

        if not module_name:
            return None, dict()

        pos = module_name.rfind(".")
        try:
            converter = getattr(importlib.import_module(module_name[:pos]),
                                module_name[pos + 1:])
            op_name_converter = getattr(converter, GET_OP_NAME)
            params_converter = getattr(converter, GET_OP_PARAMS)
            weights_converter = getattr(converter, GET_OP_WEIGHTS)
        except (ModuleNotFoundError, AttributeError) as e:
            # If mapper can not be found, or it lacks one of the required
            # converter callbacks, then skip it gracefully.
            print(f"Converting {op_name} failed, see {e}")
            return None, dict()

        try:
            converter_name = op_name_converter()
            converted_params = params_converter(params)
            converted_weights = weights_converter(weights) if weights else dict()
            converted_params.update(converted_weights)
        except (AttributeError, KeyError):
            # KeyError covers mappers indexing params they did not receive.
            print(f"Converting {op_name} failed.")
            return None, dict()

        return converter_name, converted_params

    @staticmethod
    def _operation_name_in_ms():
        raise NotImplementedError

    @staticmethod
    def _convert_params(params):
        raise NotImplementedError

    @staticmethod
    def _convert_trained_weights(weights):
        raise NotImplementedError
| @@ -0,0 +1,15 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Implemented mapper module.""" | |||
| @@ -0,0 +1,15 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Implemented mapper.""" | |||
| @@ -0,0 +1,39 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
class GlobalAvgPoolMapper(ONNXToMindSporeMapper):
    """AvgPool mapper."""

    @staticmethod
    def _operation_name_in_ms():
        return "nn.AvgPool2d"

    @staticmethod
    def _convert_params(params):
        # Kernel spans the whole spatial area: input H/W divided by
        # output H/W (axes 2 and 3 of the NCHW shapes).
        ipt_shape = params['input_shape']
        opt_shape = params['output_shape']
        return {
            'kernel_size': [ipt_shape[2] // opt_shape[2],
                            ipt_shape[3] // opt_shape[3]]
        }

    @staticmethod
    def _convert_trained_weights(weights):
        # Pooling has no trainable weights; argument kept for the mapper API.
        _ = weights
        return {}
| @@ -0,0 +1,38 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
class BatchNormMapper(ONNXToMindSporeMapper):
    """BatchNorm mapper."""

    @staticmethod
    def _operation_name_in_ms():
        return "nn.BatchNorm2d"

    @staticmethod
    def _convert_params(params):
        converted = dict()
        # Channel count comes from axis 1 of the NCHW input shape.
        converted['num_features'] = params['input_shape'][1]
        converted['eps'] = params['epsilon']
        converted['momentum'] = params['momentum']
        return converted

    @staticmethod
    def _convert_trained_weights(weights):
        # Weight migration is not implemented yet.
        _ = weights
        return {}
| @@ -0,0 +1,41 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
class Conv2dMapper(ONNXToMindSporeMapper):
    """Conv2d mapper."""

    @staticmethod
    def _operation_name_in_ms():
        return "nn.Conv2d"

    @staticmethod
    def _convert_params(params):
        ipt_shape = params['input_shape']
        opt_shape = params['output_shape']
        # Only the leading element of strides/pads/dilations is used —
        # presumably symmetric values are assumed; TODO confirm.
        return {
            'in_channels': ipt_shape[1],
            'out_channels': opt_shape[1],
            'kernel_size': params['kernel_shape'],
            'stride': params['strides'][0],
            'pad': params['pads'][0],
            'dilation': params['dilations'][0],
            'group': params['group'],
        }

    @staticmethod
    def _convert_trained_weights(weights):
        # Weight migration is not implemented yet.
        _ = weights
        return {}
| @@ -0,0 +1,37 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
class DenseMapper(ONNXToMindSporeMapper):
    """Dense mapper."""

    @staticmethod
    def _operation_name_in_ms():
        return "nn.Dense"

    @staticmethod
    def _convert_params(params):
        # Channel counts come from axis 1 of the input/output shapes.
        in_channels = params['input_shape'][1]
        out_channels = params['output_shape'][1]
        return {'in_channels': in_channels, 'out_channels': out_channels}

    @staticmethod
    def _convert_trained_weights(weights):
        # Weight migration is not implemented yet.
        _ = weights
        return {}
| @@ -0,0 +1,36 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
class FlattenMapper(ONNXToMindSporeMapper):
    """Flatten mapper."""

    @staticmethod
    def _operation_name_in_ms():
        return "nn.Flatten"

    @staticmethod
    def _convert_params(params):
        # nn.Flatten takes no constructor arguments; `params` is accepted
        # only to satisfy the mapper interface.
        _ = params
        return {}

    @staticmethod
    def _convert_trained_weights(weights):
        # Flatten has no trainable weights.
        _ = weights
        return {}
| @@ -0,0 +1,37 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
class MaxPoolMapper(ONNXToMindSporeMapper):
    """MaxPool mapper."""

    @staticmethod
    def _operation_name_in_ms():
        return "nn.MaxPool2d"

    @staticmethod
    def _convert_params(params):
        converted = dict()
        converted['kernel_size'] = params['kernel_shape']
        converted['stride'] = params['strides']
        return converted

    @staticmethod
    def _convert_trained_weights(weights):
        # Pooling has no trainable weights.
        _ = weights
        return {}
| @@ -0,0 +1,36 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
class ReLUMapper(ONNXToMindSporeMapper):
    """ReLU mapper."""

    @staticmethod
    def _operation_name_in_ms():
        return "nn.ReLU"

    @staticmethod
    def _convert_params(params):
        # nn.ReLU takes no constructor arguments.
        _ = params
        return {}

    @staticmethod
    def _convert_trained_weights(weights):
        # ReLU has no trainable weights.
        _ = weights
        return {}
| @@ -0,0 +1,15 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Implemented mapper.""" | |||
| @@ -0,0 +1,36 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
class AddMapper(ONNXToMindSporeMapper):
    """Add mapper."""

    @staticmethod
    def _operation_name_in_ms():
        return "P.TensorAdd"

    @staticmethod
    def _convert_params(params):
        # Element-wise add takes no constructor arguments.
        _ = params
        return {}

    @staticmethod
    def _convert_trained_weights(weights):
        # Add has no trainable weights.
        _ = weights
        return {}
| @@ -0,0 +1,10 @@ | |||
| { | |||
| "onnx::Conv": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.conv2d_mapper.Conv2dMapper", | |||
| "onnx::Gemm": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.dense_mapper.DenseMapper", | |||
| "onnx::BatchNormalization": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.batch_norm_mapper.BatchNormMapper", | |||
| "onnx::Relu": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.relu_mapper.ReLUMapper", | |||
| "onnx::MaxPool": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.max_pool_mapper.MaxPoolMapper", | |||
| "onnx::GlobalAveragePool": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.avg_pool_mapper.GlobalAvgPoolMapper", | |||
| "onnx::Flatten": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.flatten_mapper.FlattenMapper", | |||
| "onnx::Add": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.add_mapper.AddMapper" | |||
| } | |||
| @@ -0,0 +1,46 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Graph associated definition module.""" | |||
| from .base import Graph | |||
| from .pytorch_graph import PyTorchGraph | |||
| from .pytorch_graph_node import PyTorchGraphNode | |||
class GraphFactory:
    """Graph factory."""

    @classmethod
    def init(cls, graph_path: str, sample_shape: tuple, checkpoint: str = None):
        """
        Init an instance of graph.

        Args:
            graph_path (str): Graph or model file path.
            sample_shape (tuple): Input shape of the model.
            checkpoint (str): Checkpoint file path.
                NOTE(review): currently ignored — the graph is loaded
                without it; confirm whether it should be forwarded to
                `PyTorchGraph.load`.

        Returns:
            Graph, graph instance.
        """
        if checkpoint:
            pass
        return PyTorchGraph.load(model_path=graph_path, sample_shape=sample_shape)
| __all__ = [ | |||
| "GraphFactory", | |||
| "PyTorchGraphNode", | |||
| ] | |||
| @@ -0,0 +1,547 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Define graph entity.""" | |||
| import abc | |||
| from collections import OrderedDict | |||
| from typing import Dict, Union, Any | |||
| from torch.nn import Module | |||
| from ..constant import SEPARATOR_IN_ONNX_OP | |||
| from ..mapper.base import Mapper | |||
class GraphParser(metaclass=abc.ABCMeta):
    """Graph parser.

    Interface for turning a serialized model file into a parseable form.
    """

    @classmethod
    @abc.abstractmethod
    def parse(cls, model_path: str):
        """Parse graph into readable format."""
class BaseGraph(metaclass=abc.ABCMeta):
    """Define basic graph."""

    # Keyword name of the mandatory constructor argument checked in __new__.
    _REQUIRED_PARAM_OF_MODEL = "model"

    @abc.abstractmethod
    def build(self, input_shape: tuple):
        """Build graph."""

    @abc.abstractmethod
    def to_ir(self, mapper):
        """Convert graph to ir graph."""

    @abc.abstractmethod
    def to_hierarchical_tree(self):
        """Convert to hierarchical tree."""

    @abc.abstractmethod
    def sub_graph_merging(self):
        """Merge split nodes into one."""

    @staticmethod
    @abc.abstractmethod
    def load_checkpoint(ckpt_path: str) -> Dict:
        """Load checkpoint file."""

    @staticmethod
    @abc.abstractmethod
    def load_metadata(**kwargs):
        """Load graph metadata."""

    @staticmethod
    @abc.abstractmethod
    def load_graph(graph_path: str):
        """Load graph file."""

    @classmethod
    @abc.abstractmethod
    def load(cls, model_path: str, sample_shape: tuple = None,
             checkpoint: str = None):
        """Factory method to initialize an graph object."""

    def __new__(cls, *args, **kwargs):
        """Control the create action of graph."""
        # Accept the model either positionally (first arg) or by keyword.
        # NOTE(review): any falsy model (not only None) is rejected here —
        # confirm this is intended.
        model_param = args[0] if args else kwargs.get(cls._REQUIRED_PARAM_OF_MODEL)
        if not model_param:
            raise ValueError(f"`{cls._REQUIRED_PARAM_OF_MODEL}` "
                             f"can not be None.")
        return super(BaseGraph, cls).__new__(cls)
class Graph(BaseGraph, abc.ABC):
    """
    Define Factory method to create Graph sub-class.

    Args:
        model (Union[Module, Any]): Graph file.
        checkpoint (dict): Checkpoint path.
    """

    # Whether `_topological_order` is up to date. NOTE: this name shadows
    # the builtin `sorted` inside the class namespace.
    sorted = False

    def __init__(self, model: Union[Module, Any],
                 **kwargs):
        super(Graph, self).__init__()
        self.model = model
        self.checkpoint = kwargs.get("checkpoint", None)
        # Node name -> node instance, kept in insertion order.
        self._nodes_collection = OrderedDict()
        self._nodes_record = dict()
        self._shape_dict = dict()
        # Names of nodes with in-degree zero / out-degree zero.
        self._input_nodes = []
        self._output_nodes = []
        self._topological_order = []
        self._input_shape = dict()

    @property
    def nodes_in_topological_order(self):
        """
        Return nodes in topological order.

        Returns:
            List[GraphNode], nodes.
        """
        # Lazily (re)compute the order when it is stale.
        if not self.sorted:
            self._topological_sort()
        return self._topological_order

    def _reset_topological_order(self):
        """
        Reset topological order queue.
        """
        # Seed the work queue with the graph inputs and mark the order stale.
        self._topological_order = self._input_nodes[:]
        self.sorted = False

    def get_node(self, node_name):
        """
        Get node reference.

        Args:
            node_name (str): Node name.

        Returns:
            GraphNode, node instance.
        """
        # Node names may carry a suffix after ':'; look up by the prefix.
        prefix = node_name.split(":")[0]
        if prefix not in self._nodes_collection:
            return None
        return self._nodes_collection[prefix]

    def build(self, input_shape: tuple):
        """
        Build graph.

        Args:
            input_shape (tuple): Input shape of model.
        """
        # Collect input nodes and output nodes.
        self._collect_ipt_and_opt_nodes()
        # Use topological sort to solve nodes order.
        self._topological_sort()

    def _collect_ipt_and_opt_nodes(self):
        """
        Collect input and output nodes in model.
        """
        for name, node in self._nodes_collection.items():
            if node.in_degree == 0:
                # NOTICE: what's usage of `scope`?
                self._input_nodes.append(name)
            if node.out_degree == 0:
                self._output_nodes.append(name)

    def _topological_sort(self):
        """Topological sort to arrange nodes order."""
        self._reset_topological_order()

        def is_connected(src, dst):
            """Judge two node whether are connected."""
            for precursor in dst.precursor_nodes:
                if src == precursor.split(":")[0]:
                    return 1
            return 0

        # Kahn's algorithm: consume the queue, decrementing each
        # successor's remaining in-degree; a successor is enqueued once
        # its in-degree reaches zero.
        idx = 0
        while idx < len(self._topological_order):
            cur_node_name = self._topological_order[idx]
            cur_node = self.get_node(cur_node_name)
            # `scsr` is abbreviation for `successor`.
            for scsr_name in cur_node.successor_nodes:
                scsr_node = self.get_node(scsr_name)
                scsr_node.cur_in_degree -= is_connected(cur_node_name,
                                                        scsr_node)
                if scsr_node.cur_in_degree == 0:
                    self._topological_order.append(scsr_name)
            idx += 1
        self.sorted = True

    def to_ir(self, mapper):
        raise NotImplementedError

    def to_hierarchical_tree(self):
        raise NotImplementedError

    def sub_graph_merging(self):
        raise NotImplementedError

    @staticmethod
    def load_checkpoint(ckpt_path: str) -> Dict:
        raise NotImplementedError

    @staticmethod
    def load_metadata(**kwargs):
        raise NotImplementedError

    @staticmethod
    def load_graph(graph_path: str):
        raise NotImplementedError

    @classmethod
    def load(cls, model_path: str, sample_shape: tuple = None,
             checkpoint: str = None) -> BaseGraph:
        """
        Load third party graph, metadata and checkpoint.

        Notes:
            `checkpoint` is optional, and it can not be supported currently.

        Args:
            model_path (str): Graph or model file path.
            sample_shape (tuple): Input shape of the model.
            checkpoint (str): Checkpoint file path.

        Returns:
            cls, graph instance.
        """
        src_graph = cls.load_graph(graph_path=model_path)
        ckpt = cls.load_checkpoint(ckpt_path=checkpoint) if checkpoint else None
        if ckpt is not None:
            # Create an instance of TensorflowGraph.
            return cls(model=src_graph, sample_shape=sample_shape,
                       checkpoint=ckpt)
        # Create an instance of PyTorchGraph.
        return cls(model=src_graph, sample_shape=sample_shape)
| class GraphNode(abc.ABC): | |||
| """ | |||
| Graph node. | |||
| Args: | |||
| node (torch._C.Node): PyTorch node. | |||
| """ | |||
| transformed = False | |||
    def __init__(self, node):
        """
        Init node bookkeeping state.

        Args:
            node: Source framework node (stringified for record-keeping);
                may be None to create an empty placeholder node.
        """
        # Store the edge from precursor.
        self.precursor_nodes = []
        # Store the edge to successor.
        self.successor_nodes = []
        # Control dependency.
        self._deleted_in_edge = 0
        # Source node in pytorch.
        self._src_node = str(node) if node else None
        # Original operation name in pytorch.
        self._op_name = None
        self._op_params = dict()
        self._scope_name = None
        self._op_shape = None
        # Operation in mindspore.
        self._op_in_ms = None
        # Params in mindspore.
        self._params_in_ms = dict()
        # Node type of current node, e.g. class, module, operation.
        self._node_type = None
        # Tag name on tree.
        self._tag_on_tree = None
        # Function, class or operation needed args.
        self._args_in_code = dict()
        # Variable name declared in init block.
        self._variable_name = None
        # Output variable name declared in construct block.
        self._opt_var_name = None
        # Function or class name in code.
        self._module_name = None
        # Unique key of node.
        self._hash_key = None
        # Input shape of current op.
        self._ipt_shape = None
        # Output shape of current op.
        self._opt_shape = None
    @property
    def opt_var_name(self):
        """
        Output variable name.

        NOTE(review): the getter always derives the name from
        `variable_name` and never reads `_opt_var_name`, so the setter
        below has no effect on what this returns — confirm intent.

        Returns:
            str, variable name.
        """
        return f"{self.variable_name}_opt"

    @opt_var_name.setter
    def opt_var_name(self, v):
        """
        Set variable name.

        Args:
            v (str): Name.
        """
        self._opt_var_name = v
| @property | |||
| def op_in_ms(self): | |||
| """ | |||
| Operation in mindspore. | |||
| Returns: | |||
| str, operation name. | |||
| """ | |||
| if self._op_in_ms and SEPARATOR_IN_ONNX_OP in self._op_in_ms: | |||
| return self._op_in_ms.replace(SEPARATOR_IN_ONNX_OP, ".") | |||
| return self._op_in_ms | |||
| @property | |||
| def args_in_code(self): | |||
| """ | |||
| Args in code. | |||
| Returns: | |||
| dict, args. | |||
| """ | |||
| return self._args_in_code | |||
| @args_in_code.setter | |||
| def args_in_code(self, args): | |||
| """ | |||
| Setter for args_in_code. | |||
| Args: | |||
| args (dict): Args. | |||
| """ | |||
| self._args_in_code = args | |||
| @property | |||
| def input_shape(self): | |||
| """ | |||
| Input tensor shape of current node. | |||
| Returns: | |||
| tuple, tensor shape of input. | |||
| """ | |||
| return self._ipt_shape | |||
| @property | |||
| def output_shape(self): | |||
| """ | |||
| Output tensor shape. | |||
| Returns: | |||
| tuple, output tensor shape. | |||
| """ | |||
| return self._opt_shape | |||
| @property | |||
| def tag(self): | |||
| """Tag on hierarchical tree.""" | |||
| return self._tag_on_tree | |||
| @tag.setter | |||
| def tag(self, t): | |||
| """Tag on hierarchical tree.""" | |||
| self._tag_on_tree = t | |||
| def is_empty(self): | |||
| """ | |||
| Whether is empty. | |||
| Returns: | |||
| bool, true or false. | |||
| """ | |||
| return not self._src_node | |||
| @property | |||
| def node_type(self): | |||
| """Get node type (ONNX op type).""" | |||
| return self._node_type | |||
| @node_type.setter | |||
| def node_type(self, m): | |||
| """ | |||
| Setter of node_type. | |||
| Args: | |||
| m (str): Node type. | |||
| """ | |||
| self._node_type = m | |||
| @property | |||
| def scope_name(self): | |||
| """ | |||
| Scope name. | |||
| Returns: | |||
| str, scope name. | |||
| """ | |||
| return self._scope_name | |||
| @property | |||
| def node_params(self): | |||
| """Get node params (ONNX op params).""" | |||
| return self._op_params | |||
| @property | |||
| def cur_in_degree(self): | |||
| """ | |||
| Current in-degree. | |||
| Returns: | |||
| int, current in-degree. | |||
| """ | |||
| return self.in_degree - self._deleted_in_edge | |||
| @cur_in_degree.setter | |||
| def cur_in_degree(self, e): | |||
| """ | |||
| Setter of cur_in_degree. | |||
| Args: | |||
| e (int): To be update value. | |||
| """ | |||
| self._deleted_in_edge += self.cur_in_degree - e | |||
| @property | |||
| def in_degree(self): | |||
| """ | |||
| Define in-degree. | |||
| Returns: | |||
| int, in-degree. | |||
| """ | |||
| return len(self.precursor_nodes) | |||
| @property | |||
| def out_degree(self): | |||
| """ | |||
| Define out-degree. | |||
| Returns: | |||
| int, out-degree. | |||
| """ | |||
| return len(self.successor_nodes) | |||
| @property | |||
| @abc.abstractmethod | |||
| def hash_key(self): | |||
| """ | |||
| Generate unique hash key for each node. | |||
| Use topological order as key. | |||
| """ | |||
| @abc.abstractmethod | |||
| def _get_raw_params(self, node): | |||
| """Get params in onnx.""" | |||
| @property | |||
| @abc.abstractmethod | |||
| def op_name(self): | |||
| """Return op_name.""" | |||
| @abc.abstractmethod | |||
| def replace_with_arg(self, arg): | |||
| """Replace actual parameter with formal parameter.""" | |||
| @abc.abstractmethod | |||
| def _get_arg_name(self, arg): | |||
| """Get arg name for func or class.""" | |||
| @abc.abstractmethod | |||
| def clear_args_of_declaration(self): | |||
| """Clear `_args_in_code`.""" | |||
| @property | |||
| @abc.abstractmethod | |||
| def real_name(self): | |||
| """Getter of `real_name`.""" | |||
| @real_name.setter | |||
| @abc.abstractmethod | |||
| def real_name(self, **kwargs): | |||
| """Setter of `real_name`.""" | |||
| @property | |||
| @abc.abstractmethod | |||
| def variable_name(self): | |||
| """Getter of `variable_name`.""" | |||
| @abc.abstractmethod | |||
| def to_code(self, ipt_args_in_construct: str, output_var: str): | |||
| """Graph node to MindSpore code.""" | |||
| @abc.abstractmethod | |||
| def to_ir(self): | |||
| """Graph node to ir node.""" | |||
| @abc.abstractmethod | |||
| def add_input_and_output_shape(self, input_shape, output_shape): | |||
| """Add the node input shape.""" | |||
| @abc.abstractmethod | |||
| def froze_node_type_and_module_name(self, node_type, module_name): | |||
| """Make node_type can not be changed.""" | |||
| @abc.abstractmethod | |||
| def convert_successful(self): | |||
| """Whether convert successful.""" | |||
| def param_transform(self, mapper: Mapper): | |||
| """ | |||
| Transform param in pytorch operation into mindspore. | |||
| Args: | |||
| mapper (ONNXToMindSporeMapper): Mapper between onnx operation | |||
| and mindspore. | |||
| Returns: | |||
| dict, transformed params. | |||
| """ | |||
| import copy | |||
| params = copy.deepcopy(self._op_params) | |||
| params.update({"input_shape": self.input_shape, | |||
| "output_shape": self.output_shape}) | |||
| op_name_in_mindspore, ms_params = mapper.convert(op_name=self.op_name, | |||
| params=params) | |||
| if op_name_in_mindspore: | |||
| self._op_in_ms = op_name_in_mindspore | |||
| self._params_in_ms = ms_params | |||
| else: | |||
| self._op_in_ms = self._op_name | |||
| self._params_in_ms = self._op_params | |||
| return self._op_in_ms, self._params_in_ms | |||
| @@ -0,0 +1,45 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Third party graph parser.""" | |||
| import os | |||
| from .base import GraphParser | |||
class PyTorchGraphParser(GraphParser):
    """Define pytorch graph parser."""

    @classmethod
    def parse(cls, model_path: str):
        """
        Parse pytorch graph.

        Loads the serialized model onto GPU when CUDA is available,
        otherwise onto CPU.

        Args:
            model_path (str): Model file path.

        Returns:
            object, torch model.

        Raises:
            FileNotFoundError: If `model_path` does not exist.
        """
        import torch
        if not os.path.exists(model_path):
            raise FileNotFoundError("`model_path` must be assigned with "
                                    "an existed file path.")
        # NOTE(review): torch.load deserializes via pickle — only load
        # model files from trusted sources.
        if torch.cuda.is_available():
            model = torch.load(f=model_path)
        else:
            model = torch.load(f=model_path, map_location="cpu")
        return model
| @@ -0,0 +1,109 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Define PyTorch graph node.""" | |||
| import os | |||
| from .base import GraphNode | |||
| from ..constant import SEPARATOR_IN_SCOPE, NodeType | |||
class InputNode(GraphNode):
    """
    Pytorch Input Node.

    Synthetic node prepended to the graph to model the network input;
    most abstract hooks are therefore no-ops.

    Args:
        input_shape: Input shape of module.
    """

    def __init__(self, input_shape):
        super(InputNode, self).__init__(node=None)
        self._op_name = 'Input'
        self._op_params = {'node_shape': input_shape}
        self._node_type = NodeType.INPUT.value

    def convert_successful(self):
        """
        Whether convert successful.

        Returns:
            bool, always False — an input node maps to no MindSpore op.
        """
        return False

    def froze_node_type_and_module_name(self, node_type, module_name):
        """No-op: the node type is fixed at construction."""

    def _get_raw_params(self, node):
        """No-op: input node has no raw params."""

    def clear_args_of_declaration(self):
        """No-op: input node declares no args."""

    @property
    def op_name(self):
        """Op name, always 'Input'."""
        return self._op_name

    # NOTE(review): the base class declares `hash_key` as a property, but
    # here it is a plain method — confirm callers before unifying.
    def hash_key(self):
        """No-op: input node carries no hash key."""

    def replace_with_arg(self, arg):
        """No-op: input node has no args to replace."""

    def _get_arg_name(self, arg):
        """No-op: input node has no args."""

    def add_input_and_output_shape(self, input_shape, output_shape):
        """No-op: the shape is supplied via `_op_params` at construction."""

    def set_scope_name(self, original_input_scope_name):
        """
        Set scope name.

        Args:
            original_input_scope_name: Original input scope name needed to be linked.
        """
        prefix_name = original_input_scope_name.split(SEPARATOR_IN_SCOPE)[0]
        node_name = ''.join((self.node_type, '[input]'))
        # NOTE(review): os.path.join builds the scope name, which yields a
        # backslash separator on Windows — confirm scope names are meant
        # to be platform dependent.
        self._scope_name = os.path.join(prefix_name, node_name)

    def set_successor_nodes(self, original_input_scope_names):
        """
        Set successor nodes.

        Args:
            original_input_scope_names: Original input scope names needed to be linked.

        Raises:
            ValueError: If the argument is neither a list nor a str.
        """
        if isinstance(original_input_scope_names, list):
            self.successor_nodes = original_input_scope_names
        elif isinstance(original_input_scope_names, str):
            self.successor_nodes.append(original_input_scope_names)
        else:
            raise ValueError("`original_input_scope_names` must be "
                             "a list or a str.")

    @property
    def real_name(self):
        """Input node has no real name; returns None."""
        return

    @property
    def variable_name(self):
        """Input node has no variable name; returns None."""
        return

    def to_ir(self):
        """
        No need to implement for now.
        """
        raise NotImplementedError()

    def to_code(self, ipt_args_in_construct: str, output_var: str):
        """No need to implement for now."""
        raise NotImplementedError()
| @@ -0,0 +1,268 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Define PyTorch graph.""" | |||
| import platform | |||
| import warnings | |||
| import re | |||
| from typing import Dict, NoReturn | |||
| import torch | |||
| from torch.nn import Module | |||
| from torch.onnx import OperatorExportTypes | |||
| from .base import Graph | |||
| from .input_node import InputNode | |||
| from .pytorch_graph_node import PyTorchGraphNode | |||
| from .graph_parser import PyTorchGraphParser | |||
| from .torch_utils import OverloadTorchModuleTemporarily, unique_state_dict | |||
| from .torch_utils import create_autograd_variable | |||
| from .torch_utils import onnx_tracer | |||
| from ..hierarchical_tree import HierarchicalTree | |||
| from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE | |||
| from ..constant import LEFT_BUCKET, RIGHT_BUCKET | |||
# Ops that carry no scope name in the traced graph; mapped to a readable
# alias that is appended to the normalized scope.
NONE_SCOPE_OP = {
    'onnx::Add': 'Add',
    'onnx::Flatten': 'Flatten',
}


def normalize_scope_name(node):
    """
    Rename scope name into uniform.

    Args:
        node (Node): PyTorch node.

    Returns:
        str, normalized scope name.
    """
    # Reading a module-level global needs no `global` declaration.
    name = node.scopeName().split(SEPARATOR_IN_SCOPE)
    scopes = []
    for segment in name:
        # Keep only the part before the link separator.
        segment = segment.split(LINK_IN_SCOPE)[0]
        left = segment.find(LEFT_BUCKET)
        right = segment.find(RIGHT_BUCKET)
        if left != -1:
            # "Conv[3]" -> "Conv_3"; "Sequential[features]" -> "features".
            if segment[left + 1: right].isdigit():
                scopes.append(f"{segment[:left]}_{segment[left + 1: right]}")
            else:
                scopes.append(segment[left + 1: right])
        else:
            scopes.append(segment)
    node_kind = node.kind()
    if node_kind in NONE_SCOPE_OP:
        scopes.append(NONE_SCOPE_OP[node_kind])
    return f"{SEPARATOR_IN_SCOPE.join(scopes)}_{PyTorchGraph.get_node_id(node)}"
class PyTorchGraph(Graph):
    """
    Define PyTorch graph.

    Builds a scope-named node graph by ONNX-tracing a PyTorch model and
    prepends a synthetic input node ahead of the first real node.

    Args:
        model (Module): PyTorch model.
        sample_shape (tuple): Input shape of the model.
    """

    def __init__(self, model: Module, sample_shape: tuple):
        super(PyTorchGraph, self).__init__(model=model)
        self._params_dict = unique_state_dict(model)
        self.build(sample_shape)

    @staticmethod
    def _check_input_shape(input_shape):
        """
        Check input shape.

        Args:
            input_shape (tuple): Input tensor shape.

        Raises:
            ValueError: If `input_shape` is empty or holds non-int items.
        """
        if not input_shape:
            raise ValueError("`input_shape` can not be None.")
        for item in input_shape:
            if not isinstance(item, int):
                raise ValueError("Only support model with one input now, "
                                 "and each shape value in `input_shape` should be int.")

    def build(self, input_shape):
        """
        Build graph tree.

        Args:
            input_shape (tuple): Input shape of model (without batch dim).
        """
        self._check_input_shape(input_shape)

        def _extract_shape(shape):
            # Node dumps differ by platform: macOS prints "dim:stride"
            # pairs, elsewhere dims may be suffixed with '!'.
            if platform.system() == "Darwin":
                return [int(x.split(":")[0]) for x in shape.split(',')]
            return [int(x.replace("!", "")) for x in shape.split(',')]

        feed_forward_ipt_shape = (1, *input_shape)
        batched_sample = create_autograd_variable(torch.rand(*feed_forward_ipt_shape))

        # Assign execution mode to eval.
        self.model.eval()
        with OverloadTorchModuleTemporarily() as _:
            # In pytorch higher version, trace function has a known issue;
            # the context manager patches Module._slow_forward around it.
            graph = onnx_tracer(self.model, batched_sample,
                                OperatorExportTypes.ONNX)
        nodes = list(graph.nodes())
        for node in nodes:
            node_name = normalize_scope_name(node)
            output_shape_str_list = re.findall(r'[^()!]+', str(node))
            output_shape_str = output_shape_str_list[1]
            output_shape = _extract_shape(output_shape_str)
            self._shape_dict[node_name] = output_shape
            self._nodes_collection[node_name] = PyTorchGraphNode(node)
            self._nodes_record[node_name] = node_name
            for node_input in list(node.inputs()):
                # Connect input node and src node.
                if PyTorchGraph.get_node_id(node_input.node()) and node_input.node().scopeName():
                    node_input_name = normalize_scope_name(
                        node_input.node()
                    )
                    self.build_connection(node_input_name, node_name)
        super(PyTorchGraph, self).build(input_shape=input_shape)

        # Add Input Node ahead of the first node found in input table.
        input_node = InputNode(input_shape)
        for node_name, node in self._nodes_collection.items():
            if node_name in self._input_nodes:
                input_node.set_scope_name(node.scope_name)
                node.precursor_nodes.append(input_node.scope_name)
                input_node.set_successor_nodes(node_name)
                self._nodes_collection[input_node.scope_name] = input_node
                self._input_shape[node_name] = feed_forward_ipt_shape
                break

    def sub_graph_merging(self):
        """
        Merge split operation into one.
        """
        raise NotImplementedError()

    def to_ir(self, mapper):
        """
        Convert graph to IR graph.
        """
        raise NotImplementedError()

    def to_hierarchical_tree(self):
        """
        Generate hierarchical tree based on graph.

        Returns:
            HierarchicalTree, tree built from the topological node order.
        """
        tree = HierarchicalTree()
        node_input = None
        for _, node_name in enumerate(self.nodes_in_topological_order):
            node_inst = self.get_node(node_name)
            node_output = self._shape_dict.get(node_name)
            if node_inst.in_degree == 0:
                # If in-degree equals to zero, then it's a input node.
                continue
            # If the node is on the top, then fetch its input
            # from input table.
            if not node_input:
                node_input = self._input_shape.get(node_name)
            if not node_input:
                raise ValueError(f"Cannot find {node_name}'s input shape.")
            tree.insert(node_inst, node_name, node_input, node_output)
            # Chain outputs forward: this node's output feeds the next.
            node_input = node_output
        return tree

    def build_connection(self, src, tgt) -> NoReturn:
        """
        Build connection between source node and target node.

        Args:
            src (str): Source node name.
            tgt (str): Target node name.
        """
        # If src and tgt are the same node, src not in node_collection or
        # tgt not in node_collection,
        # then skip this edge.
        if src == tgt or src not in self._nodes_collection or tgt not in self._nodes_collection:
            if src.split(':')[0] not in self._nodes_collection:
                warnings.warn(f"Graph construct a self-loop node {src}. Ignored.")
            return
        if tgt not in self._nodes_collection[src.split(':')[0]].successor_nodes:
            self._nodes_collection[src.split(':')[0]].successor_nodes.append(tgt)
        if src not in self._nodes_collection[tgt].precursor_nodes:
            # Append under the same key that was membership-tested above;
            # the original appended via `tgt.split(':')[0]`, which could
            # address a different entry than the one just checked.
            self._nodes_collection[tgt].precursor_nodes.append(src)

    @staticmethod
    def load_checkpoint(ckpt_path: str) -> Dict:
        """
        Load checkpoint.

        NOTE(review): not implemented yet — returns None rather than the
        documented dict; confirm callers tolerate this before relying on it.

        Args:
            ckpt_path (str): Checkpoint file path.

        Returns:
            dict, weights in model.
        """

    @staticmethod
    def load_metadata(**kwargs):
        """
        Load graph metadata.
        """
        raise NotImplementedError("class `PyTorchGraph` has not implemented "
                                  "`load_metadata()`.")

    @staticmethod
    def load_graph(graph_path: str):
        """
        Load graph.

        Args:
            graph_path (str): Graph path.

        Returns:
            object, pytorch model.
        """
        torch_model = PyTorchGraphParser.parse(graph_path)
        return torch_model

    @staticmethod
    def get_node_id(node):
        """
        Get node id using regular expr.

        Args:
            node (Node): PyTorch node.

        Returns:
            str, node id; None when the node dump contains no digits
            (callers already treat the result as a truth value).
        """
        node_id = re.search(r"\d+", str(node))
        return node_id.group() if node_id else None
| @@ -0,0 +1,282 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Define PyTorch graph node.""" | |||
| from .base import GraphNode | |||
| from .torch_utils import getitem_of_node | |||
| from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \ | |||
| SEPARATOR_IN_ONNX_OP | |||
| from ..mapper.base import Mapper | |||
class PyTorchGraphNode(GraphNode):
    """
    PyTorch graph node.

    Concrete GraphNode backed by a raw traced PyTorch node; generates
    the MindSpore init/construct statements for the op.

    Args:
        node (torch._C.Node): Node in raw PyTorch graph.
    """

    # Class-level defaults; flipped per instance once frozen.
    _type_frozen = False
    _module_name_frozen = False

    def __init__(self, node=None):
        super(PyTorchGraphNode, self).__init__(node=node)
        self._op_params = self._get_raw_params(node)
        self._op_name = node.kind() if node else None
        self._scope_name = node.scopeName() if node else None
        self._opt_var_name = None
        self._variable_name = self._extract_var_name(self._scope_name)
        self._module_name = None

    def clear_args_of_declaration(self):
        """
        Clear `self._args_in_code`.
        """
        self._args_in_code = dict()

    def _get_arg_name(self, arg):
        """
        Get arg name.

        Args:
            arg (str): Generate arg name.

        Returns:
            str, arg name in function or class declaration.
        """
        return f"{arg}_{self._variable_name}"

    @property
    def hash_key(self):
        """
        Return unique hash key of current node.

        For plain operation nodes the key is the lower-cased op name;
        class/func/module nodes keep whatever the setter assigned.

        Returns:
            str, hash key.
        """
        if self._node_type not in {NodeType.CLASS.value,
                                   NodeType.FUNC.value,
                                   NodeType.MODULE.value}:
            self._hash_key = self._op_name.lower()
        return self._hash_key

    @hash_key.setter
    def hash_key(self, h):
        """
        Setter of hash key.

        Args:
            h (str): Key.
        """
        self._hash_key = h

    @property
    def variable_name(self):
        """
        Variable name.

        Returns:
            str, variable name declared in init.
        """
        return self._variable_name

    @variable_name.setter
    def variable_name(self, v):
        """
        Setter of variable name.

        Args:
            v (str): Variable name.
        """
        self._variable_name = v

    @property
    def module_name(self):
        """
        Module name.

        Falls back to the tree tag until the name is frozen.

        Returns:
            str, module name.
        """
        if not self._module_name_frozen:
            module_name = self.tag
            return module_name
        return self._module_name

    def _froze_module_name(self, m):
        """
        Once module_name is set, then it's unchangeable.

        Args:
            m (str): Module name.
        """
        if not self._module_name_frozen:
            self._module_name = m
            self._module_name_frozen = True

    @property
    def op_name(self):
        """
        Op name in torch.

        Returns:
            str, op name.
        """
        return self._op_name

    @property
    def real_name(self):
        """Not used for pytorch nodes; returns None."""
        return

    def add_input_and_output_shape(self, input_shape, output_shape):
        """
        Add the node input and output shape.

        Args:
            input_shape (tuple): Input tensor shape.
            output_shape (tuple): Output tensor shape.
        """
        self._ipt_shape = input_shape
        self._opt_shape = output_shape

    def to_code(self, ipt_args_in_construct: str, output_var: str):
        """
        Generate statements.

        Args:
            ipt_args_in_construct (str): Args of input.
            output_var (str): Output variable name in construct.

        Returns:
            tuple (str, str), declaration in init and call in construct.
        """
        # Prefer the mapped MindSpore op; fall back to the module name.
        operator = self.op_in_ms or self.module_name
        self._opt_var_name = output_var
        args = self.args_in_code
        if self._node_type == NodeType.OPERATION.value and not self.convert_successful():
            # Unconverted ops keep their shapes as explicit args so the
            # generated code documents the expected tensor sizes.
            args.update({"input_shape": self.input_shape,
                         "output_shape": self.output_shape})
        # Strip the per-variable suffix from arg names for readability.
        expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}"
                          for k, v in args.items()])
        declare = f"self.{self._variable_name} = {operator}({expr})"
        call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_in_construct})"
        return declare, call

    def to_ir(self):
        """
        No need to implement for now.
        """
        raise NotImplementedError

    def _get_raw_params(self, node):
        """
        Get params in onnx.

        Args:
            node (Any): Node.

        Returns:
            dict, raw params.
        """
        raw_params = dict()
        if not node:
            return raw_params
        for k in node.attributeNames():
            raw_params[k] = getitem_of_node(node, k)
        return raw_params

    def replace_with_arg(self, arg):
        """
        Replace actual parameter with formal parameter.

        Args:
            arg (str): Arg name.
        """
        self._args_in_code[arg] = arg

    @staticmethod
    def _extract_var_name(scope_name: str):
        """
        Extract variable name from scope name.

        E.g. "Model/Conv2d[conv1]" -> "conv2d_conv1"-style lowered name.
        """
        if not scope_name:
            return None
        var = scope_name.split(SEPARATOR_IN_SCOPE)[-1].lower()
        var = var.replace(LEFT_BUCKET, SEPARATOR_BTW_NAME_AND_ID).replace(
            RIGHT_BUCKET, "")
        return var

    def param_transform(self, mapper: Mapper):
        """
        Transform torch params into mindspore.

        Idempotent: the transformation runs at most once per node.

        Args:
            mapper (Mapper): Mapper of params.
        """
        if not self.transformed:
            _, _ = super(PyTorchGraphNode, self).param_transform(mapper)
            for arg, value in self._params_in_ms.items():
                self._args_in_code[self._get_arg_name(arg)] = value
            self.transformed = True
        return self._op_in_ms, self._params_in_ms

    def froze_node_type_and_module_name(self, node_type, module_name):
        """
        Froze node type and module name.

        After node_type is frozen, then the `module_name`
        will be affected when `node_type` is `class`.
        Thus, this line must be placed before `nd_inst.data.module_name`.

        Args:
            module_name: Modified module name.
            node_type (str): Node type, class of func.
        """
        if not self._type_frozen:
            self._node_type = node_type
            self._type_frozen = True
        if not self._module_name_frozen:
            self._froze_module_name(module_name)

    def convert_successful(self):
        """
        Whether convert successfully.

        Returns:
            bool, true or false.
        """
        if self._op_in_ms and SEPARATOR_IN_ONNX_OP not in self._op_in_ms:
            return True
        return False
| @@ -0,0 +1,101 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Define pytorch tracer context manager.""" | |||
| import importlib | |||
| from torch.nn import Module | |||
| from torch.jit import _unique_state_dict | |||
| from torch.onnx.utils import _trace | |||
| from torch.onnx.utils import _node_getitem | |||
# `torch._C.ScriptMethod`, fetched via importlib to avoid a hard
# compile-time dependency on the private module path.
SCRIPT_METHOD = getattr(importlib.import_module("torch._C"),
                        "ScriptMethod")
# Public aliases for the private torch.onnx helpers used by the builder.
onnx_tracer = _trace
getitem_of_node = _node_getitem
def unique_state_dict(model):
    """
    Wrapper of torch.jit._unique_state_dict.

    Args:
        model (Module): Torch model.

    Returns:
        dict, params.
    """
    state_dict = _unique_state_dict(model)
    return state_dict
def create_autograd_variable(tensor):
    """
    Create autograd variable to trace the whole graph.

    Args:
        tensor (torch.Tensor): Tensor.

    Returns:
        torch.autograd.Variable, variable.
    """
    autograd_module = importlib.import_module("torch.autograd")
    variable_cls = getattr(autograd_module, "Variable")
    return variable_cls(tensor, requires_grad=False)
class OverloadTorchModuleTemporarily:
    """
    Fix bugs in new version of pytorch.

    Context manager that temporarily replaces ``Module._slow_forward``
    with a patched version so hierarchical scope names are recorded
    while tracing. PyTorch official solution.
    """

    def __init__(self):
        # Original Module._slow_forward; restored on exit.
        self.backup = None

    def __enter__(self):
        def _tracing_name(traced_module, tracing_state):
            # Resolve the attribute name of `traced_module` inside its
            # parent module on the tracing stack, if any.
            traced_module_stack = getattr(tracing_state, "_traced_module_stack")
            if not traced_module_stack:
                return None
            module = traced_module_stack[-1]
            for name, child in module.named_children():
                if child is traced_module:
                    return name
            return None

        def _slow_forward(self_, *inputs, **kwargs):
            # Patched Module._slow_forward: push one scope per module so
            # the traced graph carries hierarchical scope names.
            tracing_state = getattr(importlib.import_module("torch._C"),
                                    "_get_tracing_state")()
            if not tracing_state or isinstance(self_.forward, SCRIPT_METHOD):
                # Not tracing (or scripted): fall through to the plain call.
                return self_.forward(*inputs, **kwargs)
            if not hasattr(tracing_state, '_traced_module_stack'):
                tracing_state._traced_module_stack = []
            name = _tracing_name(self_, tracing_state)
            get_name_func = getattr(self_, "_get_name")
            if name:
                tracing_state.push_scope('%s[%s]' % (get_name_func(), name))
            else:
                tracing_state.push_scope(get_name_func())
            tracing_state._traced_module_stack.append(self_)
            try:
                result = self_.forward(*inputs, **kwargs)
            finally:
                # Always unwind the scope, even if forward raises.
                tracing_state.pop_scope()
                tracing_state._traced_module_stack.pop()
            return result

        self.backup = getattr(Module, "_slow_forward")
        setattr(Module, '_slow_forward', _slow_forward)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore the original implementation unconditionally.
        setattr(Module, '_slow_forward', self.backup)
| @@ -14,4 +14,6 @@ psutil>=5.6.1 | |||
| six>=1.12.0 | |||
| Werkzeug>=1.0.0 | |||
| tabulate>=0.8.6 | |||
| pandas>=1.0.4 | |||
| pandas>=1.0.4 | |||
| yapf>=0.30.0 | |||
| treelib>=1.6.1 | |||