Browse Source

Graph based converter.

tags/v1.0.0
liuchongming74 5 years ago
parent
commit
3a66a65a60
29 changed files with 2864 additions and 1 deletions
  1. +4
    -0
      .gitignore
  2. +18
    -0
      mindinsight/mindconverter/graph_based_converter/__init__.py
  3. +42
    -0
      mindinsight/mindconverter/graph_based_converter/constant.py
  4. +101
    -0
      mindinsight/mindconverter/graph_based_converter/framework.py
  5. +20
    -0
      mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py
  6. +687
    -0
      mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py
  7. +98
    -0
      mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py
  8. +20
    -0
      mindinsight/mindconverter/graph_based_converter/mapper/__init__.py
  9. +118
    -0
      mindinsight/mindconverter/graph_based_converter/mapper/base.py
  10. +15
    -0
      mindinsight/mindconverter/graph_based_converter/mapper/impl/__init__.py
  11. +15
    -0
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/__init__.py
  12. +39
    -0
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/avg_pool_mapper.py
  13. +38
    -0
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py
  14. +41
    -0
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv2d_mapper.py
  15. +37
    -0
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py
  16. +36
    -0
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/flatten_mapper.py
  17. +37
    -0
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/max_pool_mapper.py
  18. +36
    -0
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/relu_mapper.py
  19. +15
    -0
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/__init__.py
  20. +36
    -0
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py
  21. +10
    -0
      mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json
  22. +46
    -0
      mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py
  23. +547
    -0
      mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
  24. +45
    -0
      mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py
  25. +109
    -0
      mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py
  26. +268
    -0
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py
  27. +282
    -0
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py
  28. +101
    -0
      mindinsight/mindconverter/graph_based_converter/third_party_graph/torch_utils.py
  29. +3
    -1
      requirements.txt

+ 4
- 0
.gitignore View File

@@ -78,6 +78,10 @@ TESTS*.xml
# vscode settings
.vscode


# OS files
*.DS_Store

package-lock.json

build/*


+ 18
- 0
mindinsight/mindconverter/graph_based_converter/__init__.py View File

@@ -0,0 +1,18 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Graph based scripts converter definition."""
from .framework import graph_based_converter

__all__ = ["graph_based_converter"]

+ 42
- 0
mindinsight/mindconverter/graph_based_converter/constant.py View File

@@ -0,0 +1,42 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Constant definition."""
from enum import Enum, unique

SEPARATOR_IN_ONNX_OP = "::"
SEPARATOR_IN_SCOPE = "/"
SEPARATOR_BTW_NAME_AND_ID = "_"
LINK_IN_SCOPE = "-"
LEFT_BUCKET = "["
RIGHT_BUCKET = "]"

BLANK_SYM = " "
FIRST_LEVEL_INDENT = BLANK_SYM * 4
SECOND_LEVEL_INDENT = BLANK_SYM * 8
NEW_LINE = "\n"


@unique
class CodeFormatConfig(Enum):
PEP8 = "pep8"


@unique
class NodeType(Enum):
MODULE = "module"
OPERATION = "operation"
CLASS = "class"
FUNC = "func"
INPUT = "DataInput"

+ 101
- 0
mindinsight/mindconverter/graph_based_converter/framework.py View File

@@ -0,0 +1,101 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Graph based scripts converter workflow."""
import os
import argparse
from importlib.util import find_spec

import mindinsight
from .mapper import ONNXToMindSporeMapper

permissions = os.R_OK | os.W_OK | os.X_OK
os.umask(permissions << 3 | permissions)

parser = argparse.ArgumentParser(
prog="MindConverter",
description="Graph based MindConverter CLI entry point (version: {})".format(
mindinsight.__version__)
)

parser.add_argument("--graph", type=str, required=True,
help="Third party framework's graph path.")
parser.add_argument("--sample_shape", nargs='+', type=int, required=True,
help="Input shape of the model.")
parser.add_argument("--ckpt", type=str, required=False,
help="Third party framework's checkpoint path.")
parser.add_argument("--output", type=str, required=True,
help="Generated scripts output folder path.")
parser.add_argument("--report", type=str, required=False,
help="Generated reports output folder path.")


def torch_installation_validation(func):
"""
Validate args of func.

Args:
func (type): Function.

Returns:
type, inner function.
"""

def _f(graph_path: str, sample_shape: tuple,
output_folder: str, report_folder: str = None,
checkpoint_path: str = None):
# Check whether pytorch is installed.
if not find_spec("torch"):
raise ModuleNotFoundError("PyTorch is required when using graph based "
"scripts converter, and PyTorch vision must "
"be consisted with model generation runtime.")

func(graph_path=graph_path, sample_shape=sample_shape,
output_folder=output_folder, report_folder=report_folder,
checkpoint_path=checkpoint_path)

return _f


@torch_installation_validation
def graph_based_converter(graph_path: str, sample_shape: tuple,
output_folder: str, report_folder: str = None,
checkpoint_path: str = None):
"""
Graph based scripts converter.

Args:
graph_path (str): Graph file path.
sample_shape (tuple): Input shape of the model.
output_folder (str): Output folder.
report_folder (str): Report output folder path.
checkpoint_path (str): Checkpoint file path.

"""
from .third_party_graph import GraphFactory

graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape,
checkpoint=checkpoint_path)
hierarchical_tree = graph_obj.to_hierarchical_tree()
hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper,
report_folder=report_folder)


if __name__ == '__main__':
args, _ = parser.parse_known_args()
graph_based_converter(graph_path=args.graph,
sample_shape=args.sample_shape,
output_folder=args.output,
report_folder=args.report,
checkpoint_path=args.ckpt)

+ 20
- 0
mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py View File

@@ -0,0 +1,20 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hierarchical tree module."""
from .hierarchical_tree import HierarchicalTree

__all__ = [
"HierarchicalTree"
]

+ 687
- 0
mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py View File

@@ -0,0 +1,687 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define hierarchical tree."""
import os
from copy import deepcopy
from typing import NoReturn, Union
from queue import Queue

from yapf.yapflib.yapf_api import FormatCode
from treelib import Tree, Node

from .name_mgr import ModuleNameMgr, GlobalVarNameMgr
from ..mapper.base import Mapper
from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode
from ..constant import SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT, CodeFormatConfig
from ..constant import NEW_LINE, SECOND_LEVEL_INDENT
from ..constant import NodeType

GLOBAL_VAR_NAME_MGR = GlobalVarNameMgr()


class HierarchicalTree(Tree):
"""Define hierarchical tree."""
_root_created = False
ROOT_LEVEL = 0

def __init__(self):
super(HierarchicalTree, self).__init__()
self._hierarchical_order = dict()
# Manage mapping of unique key and module name.
self._merged_module = dict()
# Manage mapping of unique key and module args.
self._merged_module_args = dict()
# Record creation of module with unique key.
self._created_module = dict()
# Manage module name to used.
self._module_mgr = ModuleNameMgr()
# Manage variable name in a module.
self._args_mgr_in_module = dict()
self._module_vars = dict()

@property
def tree_identifier(self):
"""
Return identifier of tree.

Returns:
tree, id of tree.
"""
return self.identifier

def insert(self, node: PyTorchGraphNode, node_name: str, input_shape, output_shape):
"""
Insert node into hierarchical tree.

Args:
node_name (str): Node name.
node (PyTorchGraphNode): Node to be inserted.
output_shape (tuple): Output tensor shape.
input_shape (tuple): Input tensor shape.

"""
node.add_input_and_output_shape(input_shape, output_shape)
scopes = node_name.split(SEPARATOR_IN_SCOPE)
for idx, scope in enumerate(scopes):
parent = SEPARATOR_IN_SCOPE.join(scopes[:idx])
identifier = SEPARATOR_IN_SCOPE.join(scopes[:idx + 1])
try_parent = f"{parent}{SEPARATOR_IN_SCOPE}{scope}" \
if not parent else scope
if self.contains(try_parent):
# Whether current node existed.
parent = try_parent

if not parent and not self._root_created:
# If no root node, then create it and mark it.
parent = None
self._root_created = True
elif not parent and self._root_created:
# Already have root node, skip it.
continue

if not self.contains(identifier):
# Insert node into tree.
tgt_node = node if idx == len(scopes) - 1 else PyTorchGraphNode()
tgt_node.successor_nodes = node.successor_nodes
tgt_node.precursor_nodes = node.precursor_nodes
tgt_node.node_type = (NodeType.OPERATION if idx == len(scopes) - 1
else NodeType.MODULE).value
tgt_node.tag = scope.split(SEPARATOR_BTW_NAME_AND_ID)[0]
tgt_node.variable_name = self._get_var_name(identifier)
self.create_node(
tag=tgt_node.tag,
identifier=identifier,
parent=parent,
data=tgt_node
)

def remove(self, node: Node, keep_sub=False):
"""
Remove node into hierarchical tree.

Args:
node (Node): Node to be removed.
keep_sub (bool): Whether keep sub-tree.

"""
if not keep_sub:
self.remove_node(node.identifier)
return

def shrink(self, node: Node):
"""
Shrink sub-tree into one node.

Args:
node (Node): List of nodes to be merged.

"""
node_name = node.identifier
parent_node = self[node.predecessor(self.tree_identifier)]
# Keep successors of parent.
brothers = deepcopy(parent_node.successors(self.tree_identifier))
child = node.successors(self.tree_identifier)[0]
self.move_node(source=child,
destination=node.predecessor(self.tree_identifier))
self.remove(node)
brothers[brothers.index(node_name)] = child
parent_node.set_successors(brothers, tree_id=self.tree_identifier)

def save_source_files(self, out_folder: str, mapper: Mapper,
report_folder: str = None) -> NoReturn:
"""
Save source codes to target folder.

Args:
report_folder (str): Report folder.
mapper (Mapper): Mapper of third party framework and mindspore.
out_folder (str): Output folder.

"""
self._adjust_structure()

code_fragments = self._generate_codes(mapper)

out_folder = os.path.abspath(out_folder)
if not report_folder:
report_folder = out_folder
else:
report_folder = os.path.abspath(report_folder)

if not os.path.exists(out_folder):
os.makedirs(out_folder)
if not os.path.exists(report_folder):
os.makedirs(report_folder)

for file_name in code_fragments:
code, report = code_fragments[file_name]
with open(os.path.join(os.path.abspath(out_folder),
f"{file_name}.py"), "w") as file:
file.write(code)

with open(os.path.join(report_folder,
f"report_of_{file_name}.txt"), "w") as rpt_f:
rpt_f.write(report)

def _preprocess_node_args(self, node, module_key):
"""
Remove unused args.

Args:
node (Node): Node instance.
module_key (str): Nodule key.

Returns:
Node, node.
"""
if module_key in self._merged_module_args:
node = self._clear_unused_args(node, self._merged_module_args[module_key])
else:
node.data.clear_args_of_declaration()
return node

def _postprocess_node_args(self, node, precursor_module_key):
"""
Post process args in node.

Args:
node (Node): Node instance.
precursor_module_key (str): Parent node module name.

Returns:
Node, node.
"""
if node.data.node_type == NodeType.MODULE.value:
# If current node is class or function, then
# remove unused args in __init__.
cur_module_key = node.data.hash_key or self.hash_key(node)
if cur_module_key in self._merged_module_args:
node = self._clear_unused_args(node,
self._merged_module_args[cur_module_key])

if precursor_module_key in self._merged_module_args:
# If parent node is in `_merged_module_args`, then
# replace current node args with arg name declared
# in _merged_module_args.
for arg in node.data.args_in_code.keys():
if arg in self._merged_module_args[precursor_module_key]:
node.data.replace_with_arg(arg)
return node

@staticmethod
def _clear_unused_args(node, used_args):
"""
Clear unused args.

Args:
node (Node): Node.
used_args (list): Args list.

Returns:
Node, node instance.
"""
args_in_code = list(node.data.args_in_code.keys())
for arg in args_in_code:
if arg not in used_args:
node.data.args_in_code.pop(arg)
return node

def _generate_codes(self, mapper):
"""
Generate code files.

- 1. Generate args.
- 2. Merge module.
- 3. Pre-process node args.
- 4. Post-process child node args.
- 5. Generate class/func code.
- 6. Merge code snippets.

Args:
mapper (Mapper): Mapper of third party operation and mindspore.

Returns:
Dict, codes.
"""
code_blocks = [self._get_imported_module()]
depths = sorted(list(self._hierarchical_order.keys()), reverse=True)

for depth in depths:
node_collection = self._hierarchical_order[depth]
for node_name in node_collection:
# Traverse nodes in topological order.
node = self.get_node(node_name)
# 1. Generate args for each node in this level.
if node.data.node_type == NodeType.MODULE.value:
self._create_module_args_and_vars(node, mapper)

# 2. Get nodes can be merged.
self._module_merging(node_collection)

snippets = set()
for node_name in node_collection:
nd_inst = self.get_node(node_name)
if nd_inst.data.node_type != NodeType.MODULE.value:
continue

# Generate hash key for node.
module_key = self.hash_key(nd_inst)

# Get code generation func.
func, node_type = self._fetch_func_and_type(nd_inst)

if module_key in self._created_module:
# If the module has already been created,
# then assign the created module name to current node,
# and delete unused args.
module_name = self._created_module[module_key]
nd_inst.data.froze_node_type_and_module_name(node_type,
module_name)
self._preprocess_node_args(nd_inst, module_key)
continue

module_name = nd_inst.data.module_name
if node_type == NodeType.CLASS.value:
module_name = f"{module_name[0].upper()}{module_name[1:]}"

# After node_type and module_name is frozen,
# then it's unchangeable.
module_name = self._module_mgr.get_name(module_name)
nd_inst.data.froze_node_type_and_module_name(node_type,
module_name)

# 3. Pre-process node args.
nd_inst = self._preprocess_node_args(nd_inst, module_key)
# 4. Post-process child node args.
for scsr_nd_name in nd_inst.successors(self.tree_identifier):
self._postprocess_node_args(self.get_node(scsr_nd_name),
module_key)
# 5. Generate code.
snippets.add(func(nd_inst, nd_inst.data.module_name, module_key))

code_blocks.extend(snippets)

formatted_code, _ = FormatCode("".join(code_blocks),
style_config=CodeFormatConfig.PEP8.value)

return {"model": (formatted_code, "No report content.")}

def _fetch_func_and_type(self, node) -> Union[object, str]:
"""
Generate code snippet.

Args:
node (Node): Node.

Returns:
Union[object, str], code snippet func.
"""

def _is_func():
"""
The correct thought is to check whether have more than one
path in this block.
"""
nonlocal node

tgt_type = {NodeType.MODULE.value,
NodeType.FUNC.value, NodeType.CLASS.value}
md_type_lst = [self.get_node(child).data.node_type
for child in node.successors(self.tree_identifier)]
diff_set = set(md_type_lst) - tgt_type
return not diff_set

if _is_func():
return self._generate_func_snippet, NodeType.FUNC.value
return self._generate_class_snippet, NodeType.CLASS.value

def _generate_func_snippet(self, node, func_name, func_key):
"""
Generate function snippet.

Args:
node (Node): Node inst.

Returns:
str, code snippet.
"""
definition = ""

if func_key.lower() in self._merged_module_args and \
self._merged_module_args[func_key.lower()]:
definition = ", ".join(self._merged_module_args[func_key.lower()])

module_list = []
for node_name in node.successors(self.tree_identifier):
c_nd = self.get_node(node_name)
operator = c_nd.data.op_in_ms or c_nd.data.module_name

if c_nd.data.node_type != NodeType.OPERATION.value:
hash_key = c_nd.data.hash_key or self.hash_key(c_nd)
if hash_key in self._created_module:
operator = self._created_module[hash_key]

args = c_nd.data.args_in_code
if c_nd.data.node_type == NodeType.OPERATION.value and \
not c_nd.data.convert_successful():
args.update({"input_shape": c_nd.data.input_shape,
"output_shape": c_nd.data.output_shape})

# Generate code statement.
expr = ", ".join([f"{k}={v}" for k, v in args.items()])
code_line = f"{operator}({expr})"

module_list.append(code_line)

body = f",{NEW_LINE}{SECOND_LEVEL_INDENT}".join(module_list)
snippet = f"{FIRST_LEVEL_INDENT}module_list = [{NEW_LINE}" \
f"{SECOND_LEVEL_INDENT}{body}{NEW_LINE}" \
f"{FIRST_LEVEL_INDENT}]{NEW_LINE}" \
f"{FIRST_LEVEL_INDENT}return nn.SequentialCell(*module_list)"
definition = f"def {func_name}({definition}):{NEW_LINE}"

# Mark the structure has been created.
self._created_module[func_key.lower()] = func_name

return f"{definition}{snippet}{NEW_LINE * 3}"

def _generate_class_snippet(self, node, class_name, class_key):
"""
Generate class-type code snippet.

Args:
node (Node): Node.

Returns:
str, code snippet.
"""
super_call = f"super({class_name}, self).__init__()"

if class_key.lower() in self._merged_module_args and \
self._merged_module_args[class_key.lower()]:
args = f"{', '.join(self._merged_module_args[class_key.lower()])}"
class_init = f"{FIRST_LEVEL_INDENT}def __init__(self, " \
f"{args}):" \
f"{NEW_LINE}{SECOND_LEVEL_INDENT}" \
f"{super_call}{NEW_LINE}{SECOND_LEVEL_INDENT}"
else:
class_init = f"{FIRST_LEVEL_INDENT}def __init__(self):{NEW_LINE}{SECOND_LEVEL_INDENT}" \
f"{super_call}{NEW_LINE}{SECOND_LEVEL_INDENT}"

init_block = []
construct_block = []

for idx, node_name in enumerate(node.successors(self.tree_identifier)):
nd_inst = self.get_node(node_name)

# Generate code statement.
init, construct = self._generate_stat(nd_inst, node, idx)

construct_block.append(construct)
init_block.append(init)

class_construct = f"{NEW_LINE}{FIRST_LEVEL_INDENT}def construct(self, x):{NEW_LINE}{SECOND_LEVEL_INDENT}"
init_body = f"{NEW_LINE}{SECOND_LEVEL_INDENT}".join(init_block)
csrt_body = f"{NEW_LINE}{SECOND_LEVEL_INDENT}".join(construct_block)
csrt_rtn = f"{NEW_LINE}{SECOND_LEVEL_INDENT}return output{NEW_LINE}"

cls_definition = f"class {class_name}(nn.Cell):{NEW_LINE * 2}"

# Mark the structure has been created.
self._created_module[class_key.lower()] = class_name

return f"{cls_definition}" \
f"{class_init}" \
f"{init_body}{NEW_LINE}" \
f"{class_construct}" \
f"{csrt_body}{csrt_rtn}{NEW_LINE * 2}"

def _generate_stat(self, cur_nd_inst, pre_nd_inst, idx):
"""
Generate statements.

Args:
cur_nd_inst (Node): Current node instance.
pre_nd_inst (Node): Precursor node instance.
idx (int): Index of cur node.

Returns:
Tuple[str, str], declare in init and call in construct.
"""

ipt_args_in_construct = "x"
opt_arg_in_construct = "output"

if idx != 0:
# Get previous node output variable name.
ipt_args_in_construct = self._get_previous_opt_var(cur_nd_inst, pre_nd_inst)
if idx != len(pre_nd_inst.successors(self.tree_identifier)) - 1:
# Set opt variable name.
opt_arg_in_construct = cur_nd_inst.data.opt_var_name

declare, call = cur_nd_inst.data.to_code(ipt_args_in_construct=ipt_args_in_construct,
output_var=opt_arg_in_construct)

return declare, call

@staticmethod
def _get_var_name(s):
"""
Get variable name using scope name.

Args:
s (str): String.

Returns:
str, variable name.
"""
return s.split(SEPARATOR_IN_SCOPE)[-1].lower().split(SEPARATOR_BTW_NAME_AND_ID)[0]

def _get_previous_opt_var(self, cur_nd, pre_nd):
"""
Get needed input variable names.

Args:
cur_nd (Node): Current node.
pre_nd (Node): Precursor node.

Returns:
str, needed var names.
"""
ipt_lst = []
if cur_nd.data.node_type == NodeType.OPERATION.value:
for e in cur_nd.data.precursor_nodes:
p_nd = self.get_node(e)
if e not in pre_nd.successors(self.tree_identifier):
while True:
if p_nd.identifier in pre_nd.successors(self.tree_identifier):
ipt_lst.append(p_nd.data.opt_var_name)
break
pre_nd_name = p_nd.predecessor(self.tree_identifier)
if not pre_nd_name:
ipt_lst.append("x")
break
p_nd = self.get_node(pre_nd_name)
continue

ipt_lst.append(p_nd.data.opt_var_name)
else:
idx = pre_nd.successors(self.tree_identifier).index(cur_nd.identifier) - 1
p_nd = self.get_node(pre_nd.successors(self.tree_identifier)[idx])
ipt_lst.append(p_nd.data.opt_var_name)

return ", ".join(ipt_lst)

def hash_key(self, node):
"""
Generate hash key for each node.

Args:
node (Node): Node.

Returns:
str, hash key.
"""
scsr_topo_order = []
for s in node.successors(self.tree_identifier):
cur_nd = self.get_node(s)
if cur_nd.data.hash_key:
scsr_topo_order.append(cur_nd.data.hash_key)
continue
if cur_nd.data.node_type in {NodeType.MODULE.value,
NodeType.FUNC.value,
NodeType.CLASS.value}:
scsr_topo_order.append(self.hash_key(cur_nd))
continue
unique_key = "->".join(scsr_topo_order)
node.data.hash_key = unique_key
return unique_key

def _module_merging(self, nodes):
"""
Generate sub-module and corresponding params.

Args:
nodes (List[str]): Nodes name.

"""
merged_module = dict()
merged_module_args = dict()
for node_name in nodes:
nd_inst = self.get_node(node_name)
if nd_inst.data.node_type != NodeType.MODULE.value:
continue

module_key = self.hash_key(nd_inst)
if module_key not in merged_module:
merged_module[module_key] = [nd_inst.data.args_in_code]
else:
merged_module[module_key].append(nd_inst.data.args_in_code)

for module_key, module_args in merged_module.items():
if module_key not in merged_module_args:
merged_module_args[module_key] = []
# Take first element's args as base.
keys = module_args[0].keys()
for key in keys:
for i in range(1, len(module_args)):
if module_args[0][key] != module_args[i][key]:
merged_module_args[module_key].append(key)
break

self._merged_module.update(merged_module)
self._merged_module_args.update(merged_module_args)

def _create_module_args_and_vars(self, node, mapper):
"""
Create module args.

Args:
node (Node): Node on tree.
mapper (Mapper): Mapper of params.

"""
module_args = dict()
module_key = self.hash_key(node)

created = False

if module_key not in self._args_mgr_in_module:
self._args_mgr_in_module[module_key] = GLOBAL_VAR_NAME_MGR
self._module_vars[module_key] = []
else:
created = True

for idx, successor_name in enumerate(node.successors(self.tree_identifier)):
nd_inst = self.get_node(successor_name)
# Generate variable name here, then
# to generate args.
# if nd_inst.data.node_type == NodeType.OPERATION.value:
if created:
nd_inst.data.variable_name = self._module_vars[module_key][idx]
else:
variable_name = nd_inst.data.op_name or nd_inst.data.module_name
variable_name = self._args_mgr_in_module[module_key].get_name(variable_name)
nd_inst.data.variable_name = variable_name

if nd_inst.data.node_type == NodeType.OPERATION.value:
# Generation of params must behind variable assigment.
nd_inst.data.param_transform(mapper)

module_args.update(nd_inst.data.args_in_code)

if not created:
self._module_vars[module_key].append(nd_inst.data.variable_name)

node.data.args_in_code = module_args

@staticmethod
def _create_operation_args(node, mapper):
"""
Create operation args.

Args:
node (Node): Node on tree.
mapper (Mapper): Mapper of params.

"""
node.data.param_transform(mapper)

def update_hierarchical_order(self) -> NoReturn:
"""
Update hierarchical order.
"""
hierarchical_order = dict()
queue = Queue()
queue.put(item=(self.root, self.ROOT_LEVEL), block=False)
while not queue.empty():
node_name, cur_level = queue.get(block=False)
node_inst = self[node_name]
if cur_level not in hierarchical_order:
hierarchical_order[cur_level] = []
hierarchical_order[cur_level].append(node_name)
for successor_name in node_inst.successors(self.tree_identifier):
queue.put(item=(successor_name, cur_level + 1), block=False)
self._hierarchical_order = hierarchical_order

def sub_graph_merging(self) -> NoReturn:
"""
Shrink subtree.
"""
self.update_hierarchical_order()
depths = sorted(list(self._hierarchical_order.keys()), reverse=True)
for depth in depths:
for node_name in self._hierarchical_order[depth]:
node_inst = self[node_name]
if not node_inst.data and len(node_inst.successors(self.tree_identifier)) == 1:
self.shrink(node_inst)

def _adjust_structure(self) -> NoReturn:
"""
Adjust tree structure to generate source code.
"""
self.sub_graph_merging()
self.update_hierarchical_order()

@staticmethod
def _get_imported_module():
"""
Generate imported module header.

Returns:
str, imported module.
"""
return f"from mindspore import nn{NEW_LINE}" \
f"from mindspore.ops import operations as P{NEW_LINE * 3}"

+ 98
- 0
mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py View File

@@ -0,0 +1,98 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Name manager."""
import abc


class NameMgr(abc.ABC):
"""Module name manager."""
PLACEHOLDER = 1

def __init__(self):
self.record = dict()
self.topo_order = []

def get_name(self, old_name):
"""
Get module/variable name.

If the module already existed, then add a suffix to it.

Args:
old_name (str): Name.

Returns:
str, module name.
"""
if old_name not in self.record:
self.record[old_name] = [self.PLACEHOLDER]
suffix = ""
else:
self.record[old_name].append(self.PLACEHOLDER)
suffix = f"{len(self.record[old_name]) - 1}"

new_name = f"{old_name}{suffix}"
self.topo_order.append(new_name)

return new_name


class ModuleNameMgr(NameMgr):
"""Module name manager."""


class VariableNameMgrInModule(NameMgr):
"""Variable name mgr for a module."""


global_op_namespace = dict()
START_IDX = 0


class GlobalVarNameMgr:
"""Global variable name mgr."""

@staticmethod
def _get_name(name):
"""Deal with op name."""
if "::" in name:
return name.split("::")[1]
return name

def get_name(self, op_type):
"""
Get module/variable name.

If the module already existed, then add a suffix to it.

conv1 onnx::conv

Args:
op_type (str): Operator type in onnx.

Returns:
str, module name.
"""
op_type = op_type.lower()
if op_type not in global_op_namespace:
global_op_namespace[op_type] = START_IDX
suffix = ""
else:
global_op_namespace[op_type] += 1
suffix = f"{global_op_namespace[op_type] - 1}"

new_name = f"{self._get_name(op_type)}{suffix}"

return new_name

+ 20
- 0
mindinsight/mindconverter/graph_based_converter/mapper/__init__.py View File

@@ -0,0 +1,20 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mapper module."""
from .base import ONNXToMindSporeMapper

__all__ = [
"ONNXToMindSporeMapper"
]

+ 118
- 0
mindinsight/mindconverter/graph_based_converter/mapper/base.py View File

@@ -0,0 +1,118 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mapper module."""
import abc
import importlib
import json
import os
from typing import Dict

CONFIG_JSON = "onnx_to_ms.json"
OPERATION_TABLE = os.path.join(
os.path.abspath(os.path.dirname(__file__)),
CONFIG_JSON
)

with open(OPERATION_TABLE) as file:
# Load mapping table which key is operation name in ONNX and
# value is corresponding module path.
TABLE = json.load(file)

# Define global func name.
GET_OP_NAME = "_operation_name_in_ms"
GET_OP_PARAMS = "_convert_params"
GET_OP_WEIGHTS = "_convert_trained_weights"


class Mapper(metaclass=abc.ABCMeta):
"""Mapper between third-party-operation and MindSpore."""

@staticmethod
@abc.abstractmethod
def _operation_name_in_ms():
"""Corresponding operation name in mindspore."""

@staticmethod
@abc.abstractmethod
def _convert_params(params):
"""Convert third party operation's param into MindSpore operation."""

@staticmethod
@abc.abstractmethod
def _convert_trained_weights(weights):
"""Convert third party operation's weights into MindSpore operation."""

@classmethod
@abc.abstractmethod
def convert(cls, op_name: str, params: Dict, weights: Dict = None):
"""Convert third party operation's param into MindSpore operation."""


class ONNXToMindSporeMapper(Mapper, abc.ABC):
"""ONNX operation to MindSpore."""

@classmethod
def convert(cls, op_name: str, params: Dict, weights: Dict = None):
"""
Convert third party operation's param into MindSpore operation.

Args:
op_name (str): Operation name in ONNX.
params (dict): Params in onnx.
weights (dict): Weights in onnx.

Returns:
Tuple[str, dict], operation name and params.
"""
global TABLE
module_name = TABLE.get(op_name)

if not module_name:
return None, dict()

pos = module_name.rfind(".")
try:
converter = getattr(importlib.import_module(module_name[:pos]),
module_name[pos + 1:])
op_name_converter = getattr(converter, GET_OP_NAME)
params_converter = getattr(converter, GET_OP_PARAMS)
weights_converter = getattr(converter, GET_OP_WEIGHTS)
except (ModuleNotFoundError,) as e:
# If mapper can not be found, then skip it.
print(f"Converting {op_name} failed, see {e}")
return None, dict()

try:
converter_name = op_name_converter()
converted_params = params_converter(params)
converted_weights = weights_converter(weights) if weights else dict()
converted_params.update(converted_weights)
except (AttributeError,) as _:
print(f"Converting {op_name} failed.")
return None, dict()

return converter_name, converted_params

@staticmethod
def _operation_name_in_ms():
raise NotImplementedError

@staticmethod
def _convert_params(params):
raise NotImplementedError

@staticmethod
def _convert_trained_weights(weights):
raise NotImplementedError

+ 15
- 0
mindinsight/mindconverter/graph_based_converter/mapper/impl/__init__.py View File

@@ -0,0 +1,15 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implemented mapper module."""

+ 15
- 0
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/__init__.py View File

@@ -0,0 +1,15 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implemented mapper."""

+ 39
- 0
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/avg_pool_mapper.py View File

@@ -0,0 +1,39 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mapper module."""
from ...base import ONNXToMindSporeMapper


class GlobalAvgPoolMapper(ONNXToMindSporeMapper):
"""AvgPool mapper."""

@staticmethod
def _operation_name_in_ms():
return "nn.AvgPool2d"

@staticmethod
def _convert_params(params):
kernel_size_height = params['input_shape'][2] // params['output_shape'][2]
kernel_size_width = params['input_shape'][3] // params['output_shape'][3]
kernel_size = [kernel_size_height, kernel_size_width]
return {
'kernel_size': kernel_size
}

@staticmethod
def _convert_trained_weights(weights):
if weights:
pass
return dict()

+ 38
- 0
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py View File

@@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mapper module."""
from ...base import ONNXToMindSporeMapper


class BatchNormMapper(ONNXToMindSporeMapper):
"""BatchNorm mapper."""

@staticmethod
def _operation_name_in_ms():
return "nn.BatchNorm2d"

@staticmethod
def _convert_params(params):
return {
'num_features': params['input_shape'][1],
'eps': params['epsilon'],
'momentum': params['momentum']
}

@staticmethod
def _convert_trained_weights(weights):
if weights:
pass
return dict()

+ 41
- 0
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv2d_mapper.py View File

@@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mapper module."""
from ...base import ONNXToMindSporeMapper


class Conv2dMapper(ONNXToMindSporeMapper):
"""Conv2d mapper."""

@staticmethod
def _operation_name_in_ms():
return "nn.Conv2d"

@staticmethod
def _convert_params(params):
return {
'in_channels': params['input_shape'][1],
'out_channels': params['output_shape'][1],
'kernel_size': params['kernel_shape'],
'stride': params['strides'][0],
'pad': params['pads'][0],
'dilation': params['dilations'][0],
'group': params['group']}

@staticmethod
def _convert_trained_weights(weights):
if weights:
pass
return dict()

+ 37
- 0
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py View File

@@ -0,0 +1,37 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mapper module."""
from ...base import ONNXToMindSporeMapper


class DenseMapper(ONNXToMindSporeMapper):
"""Dense mapper."""

@staticmethod
def _operation_name_in_ms():
return "nn.Dense"

@staticmethod
def _convert_params(params):
return {
'in_channels': params['input_shape'][1],
'out_channels': params['output_shape'][1]
}

@staticmethod
def _convert_trained_weights(weights):
if weights:
pass
return dict()

+ 36
- 0
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/flatten_mapper.py View File

@@ -0,0 +1,36 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mapper module."""
from ...base import ONNXToMindSporeMapper


class FlattenMapper(ONNXToMindSporeMapper):
"""Flatten mapper."""

@staticmethod
def _operation_name_in_ms():
return "nn.Flatten"

@staticmethod
def _convert_params(params):
if params:
pass
return dict()

@staticmethod
def _convert_trained_weights(weights):
if weights:
pass
return dict()

+ 37
- 0
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/max_pool_mapper.py View File

@@ -0,0 +1,37 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mapper module."""
from ...base import ONNXToMindSporeMapper


class MaxPoolMapper(ONNXToMindSporeMapper):
"""MaxPool mapper."""

@staticmethod
def _operation_name_in_ms():
return "nn.MaxPool2d"

@staticmethod
def _convert_params(params):
return {
'kernel_size': params['kernel_shape'],
'stride': params['strides']
}

@staticmethod
def _convert_trained_weights(weights):
if weights:
pass
return dict()

+ 36
- 0
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/relu_mapper.py View File

@@ -0,0 +1,36 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mapper module."""
from ...base import ONNXToMindSporeMapper


class ReLUMapper(ONNXToMindSporeMapper):
"""ReLU mapper."""

@staticmethod
def _operation_name_in_ms():
return "nn.ReLU"

@staticmethod
def _convert_params(params):
if params:
pass
return dict()

@staticmethod
def _convert_trained_weights(weights):
if weights:
pass
return dict()

+ 15
- 0
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/__init__.py View File

@@ -0,0 +1,15 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implemented mapper."""

+ 36
- 0
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py View File

@@ -0,0 +1,36 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mapper module."""
from ...base import ONNXToMindSporeMapper


class AddMapper(ONNXToMindSporeMapper):
"""Add mapper."""

@staticmethod
def _operation_name_in_ms():
return "P.TensorAdd"

@staticmethod
def _convert_params(params):
if params:
pass
return dict()

@staticmethod
def _convert_trained_weights(weights):
if weights:
pass
return dict()

+ 10
- 0
mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json View File

@@ -0,0 +1,10 @@
{
"onnx::Conv": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.conv2d_mapper.Conv2dMapper",
"onnx::Gemm": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.dense_mapper.DenseMapper",
"onnx::BatchNormalization": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.batch_norm_mapper.BatchNormMapper",
"onnx::Relu": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.relu_mapper.ReLUMapper",
"onnx::MaxPool": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.max_pool_mapper.MaxPoolMapper",
"onnx::GlobalAveragePool": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.avg_pool_mapper.GlobalAvgPoolMapper",
"onnx::Flatten": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.flatten_mapper.FlattenMapper",
"onnx::Add": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.add_mapper.AddMapper"
}

+ 46
- 0
mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py View File

@@ -0,0 +1,46 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Graph associated definition module."""
from .base import Graph
from .pytorch_graph import PyTorchGraph
from .pytorch_graph_node import PyTorchGraphNode


class GraphFactory:
"""Graph factory."""

@classmethod
def init(cls, graph_path: str, sample_shape: tuple, checkpoint: str = None):
"""
Init an instance of graph.

Args:
graph_path (str): Graph or model file path.
sample_shape (tuple): Input shape of the model.
checkpoint (str): Checkpoint file path.

Returns:
Graph, graph instance.
"""
if checkpoint:
pass

return PyTorchGraph.load(model_path=graph_path, sample_shape=sample_shape)


__all__ = [
"GraphFactory",
"PyTorchGraphNode",
]

+ 547
- 0
mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py View File

@@ -0,0 +1,547 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define graph entity."""
import abc
from collections import OrderedDict
from typing import Dict, Union, Any

from torch.nn import Module

from ..constant import SEPARATOR_IN_ONNX_OP
from ..mapper.base import Mapper


class GraphParser(metaclass=abc.ABCMeta):
"""Graph parser."""

@classmethod
@abc.abstractmethod
def parse(cls, model_path: str):
"""Parse graph into readable format."""


class BaseGraph(metaclass=abc.ABCMeta):
"""Define basic graph."""
_REQUIRED_PARAM_OF_MODEL = "model"

@abc.abstractmethod
def build(self, input_shape: tuple):
"""Build graph."""

@abc.abstractmethod
def to_ir(self, mapper):
"""Convert graph to ir graph."""

@abc.abstractmethod
def to_hierarchical_tree(self):
"""Convert to hierarchical tree."""

@abc.abstractmethod
def sub_graph_merging(self):
"""Merge split nodes into one."""

@staticmethod
@abc.abstractmethod
def load_checkpoint(ckpt_path: str) -> Dict:
"""Load checkpoint file."""

@staticmethod
@abc.abstractmethod
def load_metadata(**kwargs):
"""Load graph metadata."""

@staticmethod
@abc.abstractmethod
def load_graph(graph_path: str):
"""Load graph file."""

@classmethod
@abc.abstractmethod
def load(cls, model_path: str, sample_shape: tuple = None,
checkpoint: str = None):
"""Factory method to initialize an graph object."""

def __new__(cls, *args, **kwargs):
"""Control the create action of graph."""
model_param = args[0] if args else kwargs.get(cls._REQUIRED_PARAM_OF_MODEL)
if not model_param:
raise ValueError(f"`{cls._REQUIRED_PARAM_OF_MODEL}` "
f"can not be None.")

return super(BaseGraph, cls).__new__(cls)


class Graph(BaseGraph, abc.ABC):
"""
Define Factory method to create Graph sub-class.

Args:
model (Union[Module, Any]): Graph file.
checkpoint (dict): Checkpoint path.

"""

sorted = False

def __init__(self, model: Union[Module, Any],
**kwargs):
super(Graph, self).__init__()
self.model = model
self.checkpoint = kwargs.get("checkpoint", None)
self._nodes_collection = OrderedDict()
self._nodes_record = dict()
self._shape_dict = dict()
self._input_nodes = []
self._output_nodes = []
self._topological_order = []
self._input_shape = dict()

@property
def nodes_in_topological_order(self):
"""
Return nodes in topological order.

Returns:
List[GraphNode], nodes.
"""
if not self.sorted:
self._topological_sort()
return self._topological_order

def _reset_topological_order(self):
"""
Reset topological order queue.
"""
self._topological_order = self._input_nodes[:]
self.sorted = False

def get_node(self, node_name):
"""
Get node reference.

Args:
node_name (str): Node name.

Returns:
GraphNode, node instance.
"""
prefix = node_name.split(":")[0]
if prefix not in self._nodes_collection:
return None
return self._nodes_collection[prefix]

def build(self, input_shape: tuple):
"""
Build graph.

Args:
input_shape (tuple): Input shape of model.

"""
# Collect input nodes and output nodes.
self._collect_ipt_and_opt_nodes()
# Use topological sort to solve nodes order.
self._topological_sort()

def _collect_ipt_and_opt_nodes(self):
"""
Collect input and output nodes in model.
"""
for name, node in self._nodes_collection.items():
if node.in_degree == 0:
# NOTICE: what's usage of `scope`?
self._input_nodes.append(name)

if node.out_degree == 0:
self._output_nodes.append(name)

def _topological_sort(self):
"""Topological sort to arrange nodes order."""
self._reset_topological_order()

def is_connected(src, dst):
"""Judge two node whether are connected."""
for precursor in dst.precursor_nodes:
if src == precursor.split(":")[0]:
return 1
return 0

idx = 0
while idx < len(self._topological_order):
cur_node_name = self._topological_order[idx]
cur_node = self.get_node(cur_node_name)
# `scsr` is abbreviation for `successor`.
for scsr_name in cur_node.successor_nodes:
scsr_node = self.get_node(scsr_name)
scsr_node.cur_in_degree -= is_connected(cur_node_name,
scsr_node)
if scsr_node.cur_in_degree == 0:
self._topological_order.append(scsr_name)
idx += 1
self.sorted = True

def to_ir(self, mapper):
raise NotImplementedError

def to_hierarchical_tree(self):
raise NotImplementedError

def sub_graph_merging(self):
raise NotImplementedError

@staticmethod
def load_checkpoint(ckpt_path: str) -> Dict:
raise NotImplementedError

@staticmethod
def load_metadata(**kwargs):
raise NotImplementedError

@staticmethod
def load_graph(graph_path: str):
raise NotImplementedError

@classmethod
def load(cls, model_path: str, sample_shape: tuple = None,
checkpoint: str = None) -> BaseGraph:
"""
Load third party graph, metadata and checkpoint.

Notes:
`checkpoint` is optional, and it can not be supported currently.

Args:
model_path (str): Graph or model file path.
sample_shape (tuple): Input shape of the model.
checkpoint (str): Checkpoint file path.

Returns:
cls, graph instance.
"""
src_graph = cls.load_graph(graph_path=model_path)
ckpt = cls.load_checkpoint(ckpt_path=checkpoint) if checkpoint else None

if ckpt is not None:
# Create an instance of TensorflowGraph.
return cls(model=src_graph, sample_shape=sample_shape,
checkpoint=ckpt)

# Create an instance of PyTorchGraph.
return cls(model=src_graph, sample_shape=sample_shape)


class GraphNode(abc.ABC):
"""
Graph node.

Args:
node (torch._C.Node): PyTorch node.

"""
transformed = False

def __init__(self, node):
# Store the edge from precursor.
self.precursor_nodes = []
# Store the edge to successor.
self.successor_nodes = []
# Control dependency.
self._deleted_in_edge = 0
# Source node in pytorch.
self._src_node = str(node) if node else None
# Original operation name in pytorch.
self._op_name = None
self._op_params = dict()
self._scope_name = None
self._op_shape = None
# Operation in mindspore.
self._op_in_ms = None
# Params in mindspore.
self._params_in_ms = dict()
# Node type of current node, e.g. class, module, operation.
self._node_type = None
# Tag name on tree.
self._tag_on_tree = None
# Function, class or operation needed args.
self._args_in_code = dict()
# Variable name declared in init block.
self._variable_name = None
# Output variable name declared in construct block.
self._opt_var_name = None
# Function or class name in code.
self._module_name = None
# Unique key of node.
self._hash_key = None
# Input shape of current op.
self._ipt_shape = None
# Output shape of current op.
self._opt_shape = None

@property
def opt_var_name(self):
"""
Output variable name.

Returns:
str, variable name.
"""
return f"{self.variable_name}_opt"

@opt_var_name.setter
def opt_var_name(self, v):
"""
Set variable name.

Args:
v (str): Name.

"""
self._opt_var_name = v

@property
def op_in_ms(self):
"""
Operation in mindspore.

Returns:
str, operation name.
"""
if self._op_in_ms and SEPARATOR_IN_ONNX_OP in self._op_in_ms:
return self._op_in_ms.replace(SEPARATOR_IN_ONNX_OP, ".")
return self._op_in_ms

@property
def args_in_code(self):
"""
Args in code.

Returns:
dict, args.
"""
return self._args_in_code

@args_in_code.setter
def args_in_code(self, args):
"""
Setter for args_in_code.

Args:
args (dict): Args.

"""
self._args_in_code = args

@property
def input_shape(self):
"""
Input tensor shape of current node.

Returns:
tuple, tensor shape of input.
"""
return self._ipt_shape

@property
def output_shape(self):
"""
Output tensor shape.

Returns:
tuple, output tensor shape.
"""
return self._opt_shape

@property
def tag(self):
"""Tag on hierarchical tree."""
return self._tag_on_tree

@tag.setter
def tag(self, t):
"""Tag on hierarchical tree."""
self._tag_on_tree = t

def is_empty(self):
"""
Whether is empty.

Returns:
bool, true or false.
"""
return not self._src_node

@property
def node_type(self):
"""Get node type (ONNX op type)."""
return self._node_type

@node_type.setter
def node_type(self, m):
"""
Setter of node_type.

Args:
m (str): Node type.

"""
self._node_type = m

@property
def scope_name(self):
"""
Scope name.

Returns:
str, scope name.
"""
return self._scope_name

@property
def node_params(self):
"""Get node params (ONNX op params)."""
return self._op_params

@property
def cur_in_degree(self):
"""
Current in-degree.

Returns:
int, current in-degree.
"""
return self.in_degree - self._deleted_in_edge

@cur_in_degree.setter
def cur_in_degree(self, e):
"""
Setter of cur_in_degree.

Args:
e (int): To be update value.

"""
self._deleted_in_edge += self.cur_in_degree - e

@property
def in_degree(self):
"""
Define in-degree.

Returns:
int, in-degree.
"""
return len(self.precursor_nodes)

@property
def out_degree(self):
"""
Define out-degree.

Returns:
int, out-degree.
"""
return len(self.successor_nodes)

@property
@abc.abstractmethod
def hash_key(self):
"""
Generate unique hash key for each node.

Use topological order as key.
"""

@abc.abstractmethod
def _get_raw_params(self, node):
"""Get params in onnx."""

@property
@abc.abstractmethod
def op_name(self):
"""Return op_name."""

@abc.abstractmethod
def replace_with_arg(self, arg):
"""Replace actual parameter with formal parameter."""

@abc.abstractmethod
def _get_arg_name(self, arg):
"""Get arg name for func or class."""

@abc.abstractmethod
def clear_args_of_declaration(self):
"""Clear `_args_in_code`."""

@property
@abc.abstractmethod
def real_name(self):
"""Getter of `real_name`."""

@real_name.setter
@abc.abstractmethod
def real_name(self, **kwargs):
"""Setter of `real_name`."""

@property
@abc.abstractmethod
def variable_name(self):
"""Getter of `variable_name`."""

@abc.abstractmethod
def to_code(self, ipt_args_in_construct: str, output_var: str):
"""Graph node to MindSpore code."""

@abc.abstractmethod
def to_ir(self):
"""Graph node to ir node."""

@abc.abstractmethod
def add_input_and_output_shape(self, input_shape, output_shape):
"""Add the node input shape."""

@abc.abstractmethod
def froze_node_type_and_module_name(self, node_type, module_name):
"""Make node_type can not be changed."""

@abc.abstractmethod
def convert_successful(self):
"""Whether convert successful."""

def param_transform(self, mapper: Mapper):
"""
Transform param in pytorch operation into mindspore.

Args:
mapper (ONNXToMindSporeMapper): Mapper between onnx operation
and mindspore.

Returns:
dict, transformed params.
"""
import copy
params = copy.deepcopy(self._op_params)
params.update({"input_shape": self.input_shape,
"output_shape": self.output_shape})

op_name_in_mindspore, ms_params = mapper.convert(op_name=self.op_name,
params=params)
if op_name_in_mindspore:
self._op_in_ms = op_name_in_mindspore
self._params_in_ms = ms_params
else:
self._op_in_ms = self._op_name
self._params_in_ms = self._op_params

return self._op_in_ms, self._params_in_ms

+ 45
- 0
mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py View File

@@ -0,0 +1,45 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Third party graph parser."""
import os
from .base import GraphParser


class PyTorchGraphParser(GraphParser):
"""Define pytorch graph parser."""

@classmethod
def parse(cls, model_path: str):
"""
Parser pytorch graph.

Args:
model_path (str): Model file path.

Returns:
object, torch model.
"""
import torch

if not os.path.exists(model_path):
raise FileNotFoundError("`model_path` must be assigned with "
"an existed file path.")

if torch.cuda.is_available():
model = torch.load(f=model_path)
else:
model = torch.load(f=model_path, map_location="cpu")

return model

+ 109
- 0
mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py View File

@@ -0,0 +1,109 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define PyTorch graph node."""
import os

from .base import GraphNode
from ..constant import SEPARATOR_IN_SCOPE, NodeType


class InputNode(GraphNode):
"""
Pytorch Input Node.

Args:
input_shape: Input shape of module.

"""

def convert_successful(self):
"""
Whether convert successful.

Returns:
bool, true or false.
"""
return False

def froze_node_type_and_module_name(self, node_type, module_name):
pass

def _get_raw_params(self, node):
pass

def clear_args_of_declaration(self):
pass

@property
def op_name(self):
return self._op_name

def hash_key(self):
pass

def replace_with_arg(self, arg):
pass

def _get_arg_name(self, arg):
pass

def add_input_and_output_shape(self, input_shape, output_shape):
pass

def __init__(self, input_shape):
super(InputNode, self).__init__(node=None)
self._op_name = 'Input'
self._op_params = {'node_shape': input_shape}
self._node_type = NodeType.INPUT.value

def set_scope_name(self, original_input_scope_name):
"""
Set scope name.
Args:
original_input_scope_name: Original input scope name needed to be linked.
"""
prefix_name = original_input_scope_name.split(SEPARATOR_IN_SCOPE)[0]
node_name = ''.join((self.node_type, '[input]'))
self._scope_name = os.path.join(prefix_name, node_name)

def set_successor_nodes(self, original_input_scope_names):
"""
Set successor nodes.
Args:
original_input_scope_names: Original input scope names needed to be linked.
"""
if isinstance(original_input_scope_names, list):
self.successor_nodes = original_input_scope_names
elif isinstance(original_input_scope_names, str):
self.successor_nodes.append(original_input_scope_names)
else:
raise ValueError

@property
def real_name(self):
return

@property
def variable_name(self):
return

def to_ir(self):
"""
No need to implement for now.
"""
raise NotImplementedError()

def to_code(self, ipt_args_in_construct: str, output_var: str):
raise NotImplementedError()

+ 268
- 0
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py View File

@@ -0,0 +1,268 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define PyTorch graph."""
import platform
import warnings
import re
from typing import Dict, NoReturn

import torch
from torch.nn import Module
from torch.onnx import OperatorExportTypes

from .base import Graph
from .input_node import InputNode
from .pytorch_graph_node import PyTorchGraphNode
from .graph_parser import PyTorchGraphParser
from .torch_utils import OverloadTorchModuleTemporarily, unique_state_dict
from .torch_utils import create_autograd_variable
from .torch_utils import onnx_tracer
from ..hierarchical_tree import HierarchicalTree
from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE
from ..constant import LEFT_BUCKET, RIGHT_BUCKET

NONE_SCOPE_OP = {
'onnx::Add': 'Add',
'onnx::Flatten': 'Flatten',
}


def normalize_scope_name(node):
"""
Rename scope name into uniform.

Args:
node (Node): PyTorch node.

Returns:
str, normalized scope name.
"""
global NONE_SCOPE_OP

name = node.scopeName().split(SEPARATOR_IN_SCOPE)
scopes = []
for segment in name:
segment = segment.split(LINK_IN_SCOPE)[0]
left = segment.find(LEFT_BUCKET)
right = segment.find(RIGHT_BUCKET)
if left != -1:
if segment[left + 1: right].isdigit():
scopes.append(f"{segment[:left]}_{segment[left + 1: right]}")
else:
scopes.append(segment[left + 1: right])
else:
scopes.append(segment)
if node.kind() in NONE_SCOPE_OP.keys():
scopes.append(NONE_SCOPE_OP[node.kind()])
return f"{SEPARATOR_IN_SCOPE.join(scopes)}_{PyTorchGraph.get_node_id(node)}"


class PyTorchGraph(Graph):
"""
Define PyTorch graph.

Args:
model (Module): PyTorch model.
sample_shape (tuple): Input shape of the model.

"""

def __init__(self, model: Module, sample_shape: tuple):
super(PyTorchGraph, self).__init__(model=model)
self._params_dict = unique_state_dict(model)
self.build(sample_shape)

@staticmethod
def _check_input_shape(input_shape):
"""
Check input shape.

Args:
input_shape (tuple): Input tensor shape.

"""
if not input_shape:
raise ValueError("`input_shape` can not be None.")

for item in input_shape:
if not isinstance(item, int):
raise ValueError(f"Only support model with one input now, "
f"and each shape value in `input_shape` should be int.")

def build(self, input_shape):
"""
Build graph tree.

Args:
input_shape (tuple): Input shape of model.

"""
self._check_input_shape(input_shape)

def _extract_shape(shape):
if platform.system() == "Darwin":
return [int(x.split(":")[0]) for x in shape.split(',')]
return [int(x.replace("!", "")) for x in shape.split(',')]

feed_forward_ipt_shape = (1, *input_shape)
batched_sample = create_autograd_variable(torch.rand(*feed_forward_ipt_shape))

# Assign execution mode to eval.
self.model.eval()

with OverloadTorchModuleTemporarily() as _:
# In pytorch higher version, trace function has a known.
graph = onnx_tracer(self.model, batched_sample,
OperatorExportTypes.ONNX)

nodes = list(graph.nodes())

for node in nodes:
node_name = normalize_scope_name(node)
output_shape_str_list = re.findall(r'[^()!]+', str(node))
output_shape_str = output_shape_str_list[1]
output_shape = _extract_shape(output_shape_str)

self._shape_dict[node_name] = output_shape
self._nodes_collection[node_name] = PyTorchGraphNode(node)
self._nodes_record[node_name] = node_name

for node_input in list(node.inputs()):
# Connect input node and src node.
if PyTorchGraph.get_node_id(node_input.node()) and node_input.node().scopeName():
node_input_name = normalize_scope_name(
node_input.node()
)
self.build_connection(node_input_name, node_name)

super(PyTorchGraph, self).build(input_shape=input_shape)

# Add Input Node
input_node = InputNode(input_shape)
for node_name, node in self._nodes_collection.items():
if node_name in self._input_nodes:
input_node.set_scope_name(node.scope_name)
node.precursor_nodes.append(input_node.scope_name)
input_node.set_successor_nodes(node_name)
self._nodes_collection[input_node.scope_name] = input_node
self._input_shape[node_name] = feed_forward_ipt_shape
break

def sub_graph_merging(self):
"""
Merge split operation into one.
"""
raise NotImplementedError()

def to_ir(self, mapper):
"""
Convert graph to IR graph.
"""
raise NotImplementedError()

def to_hierarchical_tree(self):
"""
Generate hierarchical tree based on graph.
"""
tree = HierarchicalTree()
node_input = None
for _, node_name in enumerate(self.nodes_in_topological_order):
node_inst = self.get_node(node_name)
node_output = self._shape_dict.get(node_name)
if node_inst.in_degree == 0:
# If in-degree equals to zero, then it's a input node.
continue

# If the node is on the top, then fetch its input
# from input table.
if not node_input:
node_input = self._input_shape.get(node_name)

if not node_input:
raise ValueError(f"Cannot find {node_name}'s input shape.")

tree.insert(node_inst, node_name, node_input, node_output)
node_input = node_output
return tree

def build_connection(self, src, tgt) -> NoReturn:
"""
Build connection between source node and target node.

Args:
src (str): Source node name.
tgt (str): Target node name.

"""
# If src and tgt are the same node, src not in node_collection or
# tgt not in node_collection,
# then skip this edge.
if src == tgt or src not in self._nodes_collection or tgt not in self._nodes_collection:
if src.split(':')[0] not in self._nodes_collection:
warnings.warn(f"Graph construct a self-loop node {src}. Ignored.")
return

if tgt not in self._nodes_collection[src.split(':')[0]].successor_nodes:
self._nodes_collection[src.split(':')[0]].successor_nodes.append(tgt)
if src not in self._nodes_collection[tgt].precursor_nodes:
self._nodes_collection[tgt.split(':')[0]].precursor_nodes.append(src)

@staticmethod
def load_checkpoint(ckpt_path: str) -> Dict:
"""
Load checkpoint.

Args:
ckpt_path (str): Checkpoint file path.

Returns:
dict, weights in model.
"""

@staticmethod
def load_metadata(**kwargs):
"""
Load graph metadata.
"""
raise NotImplementedError("class `PyTorchGraph` has not implemented "
"`load_metadata()`.")

@staticmethod
def load_graph(graph_path: str):
"""
Load graph.

Args:
graph_path (str): Graph path.

Returns:
object, pytorch model.
"""
torch_model = PyTorchGraphParser.parse(graph_path)
return torch_model

@staticmethod
def get_node_id(node):
"""
Get node id using regular expr.

Args:
node (Node): PyTorch node.

Returns:
str, node id.
"""
node_id = re.search(r"[\d]+", str(node))
return node_id.group()

+ 282
- 0
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py View File

@@ -0,0 +1,282 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define PyTorch graph node."""
from .base import GraphNode
from .torch_utils import getitem_of_node
from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \
SEPARATOR_IN_ONNX_OP
from ..mapper.base import Mapper


class PyTorchGraphNode(GraphNode):
"""
PyTorch graph node.

Args:
node (torch._C.Node): Node in raw PyTorch graph.

"""

_type_frozen = False
_module_name_frozen = False

def __init__(self, node=None):
super(PyTorchGraphNode, self).__init__(node=node)
self._op_params = self._get_raw_params(node)
self._op_name = node.kind() if node else None
self._scope_name = node.scopeName() if node else None
self._opt_var_name = None
self._variable_name = self._extract_var_name(self._scope_name)
self._module_name = None

def clear_args_of_declaration(self):
"""
Clear `self._args_in_code`.
"""
self._args_in_code = dict()

def _get_arg_name(self, arg):
"""
Get arg name.

Args:
arg (str): Generate arg name.

Returns:
str, arg name in function or class declaration.
"""
return f"{arg}_{self._variable_name}"

@property
def hash_key(self):
"""
Return unique hash key of current node.

Returns:
str, hash key.
"""
if self._node_type not in {NodeType.CLASS.value,
NodeType.FUNC.value,
NodeType.MODULE.value}:
self._hash_key = self._op_name.lower()
return self._hash_key

@hash_key.setter
def hash_key(self, h):
"""
Setter of hash key.

Args:
h (str): Key.

"""
self._hash_key = h

@property
def variable_name(self):
"""
Variable name.

Returns:
str, variable name declared in init.
"""
return self._variable_name

@variable_name.setter
def variable_name(self, v):
"""
Setter of variable name.

Args:
v (str): Variable name.

"""
self._variable_name = v

@property
def module_name(self):
"""
Module name.

Returns:
str, module name.
"""
if not self._module_name_frozen:
module_name = self.tag
# if self._node_type == NodeType.CLASS.value:
# module_name = f"{module_name[0].upper()}{module_name[1:]}"
return module_name

return self._module_name

def _froze_module_name(self, m):
"""
Once module_name is set, then it's unchangeable.

Args:
m (str): Module name.

"""
if not self._module_name_frozen:
self._module_name = m
self._module_name_frozen = True

@property
def op_name(self):
"""
Op name in torch.

Returns:
str, op name.
"""
return self._op_name # if self.is_empty() else self.tag

@property
def real_name(self):
return

def add_input_and_output_shape(self, input_shape, output_shape):
"""
Add the node input shape.

Args:
output_shape (tuple): Output tensor shape.
input_shape (tuple): Input tensor shape.

"""
self._ipt_shape = input_shape
self._opt_shape = output_shape

def to_code(self, ipt_args_in_construct: str, output_var: str):
"""
Generate statements.

Args:
ipt_args_in_construct (str): Args of input.
output_var (str): Output variable name in construct.

Returns:
Union[str, str], declare in init and call in construct.
"""
operator = self.op_in_ms or self.module_name
self._opt_var_name = output_var

args = self.args_in_code
if self._node_type == NodeType.OPERATION.value and not self.convert_successful():
args.update({"input_shape": self.input_shape,
"output_shape": self.output_shape})

expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}"
for k, v in args.items()])
declare = f"self.{self._variable_name} = {operator}({expr})"
call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_in_construct})"

return declare, call

def to_ir(self):
"""
No need to implement for now.
"""
raise NotImplementedError

def _get_raw_params(self, node):
"""
Get params in onnx.

Args:
node (Any): Node.

Returns:
dict, raw params.
"""
raw_params = dict()

if not node:
return raw_params

for k in node.attributeNames():
raw_params[k] = getitem_of_node(node, k)
return raw_params

def replace_with_arg(self, arg):
"""
Replace actual parameter with formal parameter.

Args:
arg (str): Arg name.

"""
self._args_in_code[arg] = arg

@staticmethod
def _extract_var_name(scope_name: str):
"""
Extract variable name from scope name.
"""
if not scope_name:
return None
var = scope_name.split(SEPARATOR_IN_SCOPE)[-1].lower()
var = var.replace(LEFT_BUCKET, SEPARATOR_BTW_NAME_AND_ID).replace(
RIGHT_BUCKET, "")
return var

def param_transform(self, mapper: Mapper):
"""
Transform torch params into mindspore.

Args:
mapper (Mapper): Mapper of params.

"""
if not self.transformed:
_, _ = super(PyTorchGraphNode, self).param_transform(mapper)

for arg, value in self._params_in_ms.items():
self._args_in_code[self._get_arg_name(arg)] = value

self.transformed = True

return self._op_in_ms, self._params_in_ms

def froze_node_type_and_module_name(self, node_type, module_name):
"""
Froze node type and module name.

After node_type is frozen, then the `module_name`
will be affected when `node_type` is `class`.
Thus, this line must be placed before `nd_inst.data.module_name`.

Args:
module_name: Modified module name.
node_type (str): Node type, class of func.

"""
if not self._type_frozen:
self._node_type = node_type
self._type_frozen = True

if not self._module_name_frozen:
self._froze_module_name(module_name)

def convert_successful(self):
"""
Whether convert successfully.

Returns:
bool, true or false.
"""
if self._op_in_ms and SEPARATOR_IN_ONNX_OP not in self._op_in_ms:
return True
return False

+ 101
- 0
mindinsight/mindconverter/graph_based_converter/third_party_graph/torch_utils.py View File

@@ -0,0 +1,101 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define pytorch tracer context manager."""
import importlib

from torch.nn import Module
from torch.jit import _unique_state_dict
from torch.onnx.utils import _trace
from torch.onnx.utils import _node_getitem

SCRIPT_METHOD = getattr(importlib.import_module("torch._C"),
"ScriptMethod")
onnx_tracer = _trace
getitem_of_node = _node_getitem


def unique_state_dict(model):
"""
Wrapper of torch.jit._unique_state_dict.

Args:
model (Module): Torch model.

Returns:
dict, params.
"""
return _unique_state_dict(model)


def create_autograd_variable(tensor):
"""
Create autograd variable to trace the whole graph.

Args:
tensor (torch.Tensor): Tensor.

Returns:
torch.autograd.Variable, variable.
"""
variable = getattr(importlib.import_module("torch.autograd"), "Variable")
return variable(tensor, requires_grad=False)


class OverloadTorchModuleTemporarily:
"""
Fix bugs in new version of pytorch.
PyTorch official solution.
"""

def __init__(self):
self.backup = None

def __enter__(self):
def _tracing_name(traced_module, tracing_state):
traced_module_stack = getattr(tracing_state, "_traced_module_stack")
if not traced_module_stack:
return None
module = traced_module_stack[-1]
for name, child in module.named_children():
if child is traced_module:
return name
return None

def _slow_forward(self_, *inputs, **kwargs):
tracing_state = getattr(importlib.import_module("torch._C"),
"_get_tracing_state")()
if not tracing_state or isinstance(self_.forward, SCRIPT_METHOD):
return self_.forward(*inputs, **kwargs)
if not hasattr(tracing_state, '_traced_module_stack'):
tracing_state._traced_module_stack = []
name = _tracing_name(self_, tracing_state)
get_name_func = getattr(self_, "_get_name")
if name:
tracing_state.push_scope('%s[%s]' % (get_name_func(), name))
else:
tracing_state.push_scope(get_name_func())
tracing_state._traced_module_stack.append(self_)
try:
result = self_.forward(*inputs, **kwargs)
finally:
tracing_state.pop_scope()
tracing_state._traced_module_stack.pop()
return result

self.backup = getattr(Module, "_slow_forward")
setattr(Module, '_slow_forward', _slow_forward)

def __exit__(self, exc_type, exc_val, exc_tb):
setattr(Module, '_slow_forward', self.backup)

+ 3
- 1
requirements.txt View File

@@ -14,4 +14,6 @@ psutil>=5.6.1
six>=1.12.0
Werkzeug>=1.0.0
tabulate>=0.8.6
pandas>=1.0.4
pandas>=1.0.4
yapf>=0.30.0
treelib>=1.6.1

Loading…
Cancel
Save