
!1016 Add multi-output converter function for PyTorch.

From: @moran3
Tag: tags/v1.1.0
Committed by: mindspore-ci-bot (Gitee), 5 years ago
Parent commit: 3d80b76841
9 changed files with 247 additions and 27 deletions
  1. +9 -3 mindinsight/mindconverter/graph_based_converter/common/code_fragment.py
  2. +5 -0 mindinsight/mindconverter/graph_based_converter/constant.py
  3. +60 -7 mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py
  4. +12 -2 mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reshape_mapper.py
  5. +41 -0 mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/split_mapper.py
  6. +2 -1 mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json
  7. +23 -0 mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
  8. +67 -10 mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py
  9. +28 -4 mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py

+9 -3 mindinsight/mindconverter/graph_based_converter/common/code_fragment.py

@@ -44,6 +44,7 @@ class Fragment(abc.ABC):
operation (str): Operation name in MindSpore.
actual_args (dict): Actual arg values.
settings (namedTuple): Code generation setting.

"""

def __init__(self, operation, actual_args, input_shape, output_shape, settings=None):
@@ -89,9 +90,9 @@ class Fragment(abc.ABC):
self._declared_variable_name = var

@property
- def output_var_name(self) -> str:
+ def output_var_name(self) -> list:
"""Getter of output variable name."""
- return ", ".join(self._output_var_name)
+ return self._output_var_name

@output_var_name.setter
def output_var_name(self, opt_vars):
@@ -100,6 +101,7 @@ class Fragment(abc.ABC):

Args:
opt_vars (list[str]): Output variable name.

"""
self._output_var_name = opt_vars

@@ -119,8 +121,9 @@ class Fragment(abc.ABC):

Args:
ipt (Fragment): Where input comes from.

"""
- self._operation_inputs.append(ipt)
+ self._operation_inputs += ipt

@property
def operation(self):
@@ -139,6 +142,7 @@ class Fragment(abc.ABC):

Args:
op (str): Operation name.

"""
self._operation = op

@@ -158,6 +162,7 @@ class Fragment(abc.ABC):

Args:
formal_args (dict): To be updated args.

"""
return self._formal_args_list.update(formal_args)

@@ -194,6 +199,7 @@ class CodeFragment(Fragment):
operation (str): Operation name in MindSpore.
actual_args (dict): Actual arg values.
settings (namedTuple): Code generation setting.

"""

def __init__(self, operation, actual_args, settings, input_shape, output_shape,
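
The getter change above (str to list) moves the joining of output variable names out of Fragment and into its callers, which is what lets multi-output nodes receive one variable per output. A minimal, self-contained sketch of the new contract; the class and variable names here are simplified stand-ins, not the converter's real API:

class FragmentSketch:
    """Stripped-down stand-in for Fragment, showing the list-based getter."""

    def __init__(self):
        self._output_var_name = []

    @property
    def output_var_name(self):
        # Raw list: callers decide how to join or index the names.
        return self._output_var_name

    @output_var_name.setter
    def output_var_name(self, opt_vars):
        self._output_var_name = opt_vars


frag = FragmentSketch()
frag.output_var_name = ["split_0", "split_1"]
# Code generation joins the names only at emission time:
print(f"{', '.join(frag.output_var_name)} = self.split(x)")
# -> split_0, split_1 = self.split(x)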


+5 -0 mindinsight/mindconverter/graph_based_converter/constant.py

@@ -18,6 +18,7 @@ from enum import Enum, unique
SEPARATOR_IN_ONNX_OP = "::"
SEPARATOR_IN_SCOPE = "/"
SEPARATOR_BTW_NAME_AND_ID = "_"
SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT = "="
LINK_IN_SCOPE = "-"
LEFT_BUCKET = "["
RIGHT_BUCKET = "]"
@@ -52,6 +53,10 @@ EXPECTED_NUMBER = 1

MIN_SCOPE_LENGTH = 2

NO_CONVERTED_OPERATORS = [
"onnx::Constant"
]


@unique
class CodeFormatConfig(Enum):
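
The new NO_CONVERTED_OPERATORS list is a skip-list: operators on it (currently only onnx::Constant) produce no MindSpore statement, so code generation passes over them. A short sketch of the intended consumption pattern, mirroring the hierarchical_tree change below; the op names in the loop are hypothetical:

NO_CONVERTED_OPERATORS = ["onnx::Constant"]

# Skipped ops contribute no init/construct statement.
for op_name in ["onnx::Conv", "onnx::Constant", "onnx::Split"]:
    if op_name in NO_CONVERTED_OPERATORS:
        continue
    print(f"generate statement for {op_name}")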


+60 -7 mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py

@@ -29,7 +29,7 @@ from ..common.utils import is_converted, save_code_file_and_report
from ..mapper.base import Mapper
from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode
from ..third_party_graph.onnx_graph_node import OnnxGraphNode
- from ..constant import SEPARATOR_IN_SCOPE, get_imported_module
+ from ..constant import SEPARATOR_IN_SCOPE, get_imported_module, NO_CONVERTED_OPERATORS
from ..constant import CodeFormatConfig
from ..constant import SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT
from ..constant import NEW_LINE, SECOND_LEVEL_INDENT
@@ -472,6 +472,8 @@ class HierarchicalTree(Tree):

for idx, node_name in enumerate(node.successors(self.tree_identifier)):
nd_inst = self.get_node(node_name)
if nd_inst.data.op_name in NO_CONVERTED_OPERATORS:
continue

# Generate code statement.
init, construct = self._generate_stat(nd_inst, node, idx)
@@ -518,14 +520,25 @@ class HierarchicalTree(Tree):
"""

ipt_args_in_construct = "x"
opt_arg_in_construct = "output"
opt_arg_in_construct = ["output"]

if idx != 0:
- # Get previous node output variable name.
- ipt_args_in_construct = self._get_previous_opt_var(cur_nd_inst, pre_nd_inst)
+ if cur_nd_inst.data.is_in_multi_opt_graph:
+ ipt_args_in_construct = self._get_current_ipt_var(cur_nd_inst)
+ else:
+ # Get previous node output variable name.
+ ipt_args_in_construct = self._get_previous_opt_var(cur_nd_inst, pre_nd_inst)
if idx != len(pre_nd_inst.successors(self.tree_identifier)) - 1:
# Set opt variable name.
opt_arg_in_construct = f"{self.code_fragment_recorder[cur_nd_inst.identifier].declared_var_name}_opt"
if cur_nd_inst.data.node_type == NodeType.MODULE.value or not cur_nd_inst.data.is_in_multi_opt_graph:
opt_arg_in_construct = [
f"{self.code_fragment_recorder[cur_nd_inst.identifier].declared_var_name}_opt"
]
else:
opt_arg_in_construct = [
f"opt_{var_name}"
for var_name in self.code_fragment_recorder[cur_nd_inst.identifier].output_var_name
]

declare, call = cur_nd_inst.data.to_code(ipt_args_in_construct=ipt_args_in_construct,
variable_name=self.code_fragment_recorder[
@@ -548,6 +561,39 @@ class HierarchicalTree(Tree):
"""
return s.split(SEPARATOR_IN_SCOPE)[-1].lower().split(SEPARATOR_BTW_NAME_AND_ID)[0]

def _get_current_ipt_var(self, cur_nd):
""""
Get current input variable name from node_id.

Args:
cur_nd (Node): Current node.

Returns:
str, needed var names.
"""
if cur_nd.data.node_type != NodeType.OPERATION.value:
while True:
p_nd = cur_nd.successors(self.tree_identifier)
if not p_nd:
break
cur_nd = self.get_node(p_nd[0])

ipt_lst_raw = []
for operation_input in self.code_fragment_recorder[cur_nd.identifier].operation_inputs:
ipt_lst_raw.append(f"{operation_input}")

opt_var_names_p_nds = set()
for e in cur_nd.data.precursor_nodes:
p_nd = self.get_node(e)
if p_nd.data.op_name in NO_CONVERTED_OPERATORS:
continue

opt_var_names_p_nd = set(p_nd.data.opt_var_names)
opt_var_names_p_nds = set.union(opt_var_names_p_nds, opt_var_names_p_nd)

ipt_lst = [f"opt_{ipt}" for ipt in set(ipt_lst_raw).intersection(opt_var_names_p_nds)]
return ", ".join(ipt_lst)

def _find_all_previous_opt_var_(self, cur_nd, pre_nd):
"""
Find all input variable names.
@@ -557,9 +603,12 @@ class HierarchicalTree(Tree):
pre_nd (Node): Precursor node.

Returns:
- str, needed var names.
+ list, needed var names.
"""
ipt_lst = []
if cur_nd.tag in NO_CONVERTED_OPERATORS:
return ipt_lst

for e in cur_nd.data.precursor_nodes:
p_nd = self.get_node(e)
if e not in pre_nd.successors(self.tree_identifier):
@@ -575,7 +624,6 @@ class HierarchicalTree(Tree):
break
p_nd = self.get_node(pre_nd_name)
continue

ipt_lst.append(
f"{self.code_fragment_recorder[p_nd.identifier].declared_var_name}_opt"
)
@@ -671,6 +719,9 @@ class HierarchicalTree(Tree):
# Sub-modules in the module could have arg name conflicts.
for idx, successor_name in enumerate(node.successors(self.tree_identifier)):
nd_inst = self.get_node(successor_name)
if nd_inst.data.op_name in NO_CONVERTED_OPERATORS:
continue

# Generation of params must come after variable assignment.
if created:
variable_name = self._module_vars[module_key][idx]
@@ -680,6 +731,8 @@ class HierarchicalTree(Tree):

code_fragment = nd_inst.data.param_transform(mapper, variable_name)
code_fragment.declared_var_name = variable_name
code_fragment.output_var_name = nd_inst.data.opt_var_names
code_fragment.update_operation_inputs(nd_inst.data.ipt_var_names)
self.code_fragment_recorder[nd_inst.identifier] = code_fragment

module_args.update(nd_inst.data.args_in_code)
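
Taken together, these changes make _get_current_ipt_var pick a node's construct arguments by intersecting the operation inputs recorded for the node with the union of its precursors' output variable names. A self-contained sketch of that set logic with invented node ids; the real method reads these values from the code fragment recorder and the graph:

# Hypothetical traced ids: the node consumes %14/%15 plus a weight name that
# no precursor produces; its precursors expose {14, 15} and {3}.
operation_inputs = ["14", "15", "weight_0"]
precursor_outputs = [{"14", "15"}, {"3"}]

available = set().union(*precursor_outputs)
# sorted() here only keeps the example deterministic; the method joins the raw set.
ipt_args = ", ".join(f"opt_{ipt}" for ipt in sorted(set(operation_inputs) & available))
print(ipt_args)  # opt_14, opt_15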


+12 -2 mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reshape_mapper.py

@@ -34,9 +34,19 @@ class ReshapeMapper(ONNXToMindSporeMapper):

@staticmethod
def _convert_settings(**kwargs):
if kwargs.get("weights", None):
return ReshapeMapper._convert_settings_tf(**kwargs)
return ReshapeMapper._convert_settings_pytorch(**kwargs)

@staticmethod
def _convert_settings_pytorch(**kwargs):
params = kwargs.get("params")
shape = params.get("output_shape")
return Setting(op_extra_input={"input_shape": tuple(shape)})

@staticmethod
def _convert_settings_tf(**kwargs):
weights = kwargs.get("weights")
if not weights:
return Setting()
if len(weights) > 1:
raise ValueError("For reshape, `weights` length should equal to 1.")
shape = [-1]
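
The dispatch above encodes a format difference: a Reshape traced from PyTorch carries its target shape in the node params, while a TF-derived Reshape carries it as a weight tensor, so the presence of weights selects the branch. A hedged sketch of the rule, with Setting reduced to a plain dict for illustration:

def convert_settings(params=None, weights=None):
    # Presence of weights marks the TF path; otherwise read the traced
    # PyTorch output shape from params (field names as in the diff above).
    if weights:
        return {"branch": "tf"}
    return {"op_extra_input": {"input_shape": tuple(params["output_shape"])}}

print(convert_settings(params={"output_shape": [1, -1]}))
# -> {'op_extra_input': {'input_shape': (1, -1)}}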


+41 -0 mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/split_mapper.py

@@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting


class SplitMapper(ONNXToMindSporeMapper):
"""Split mapper."""

@staticmethod
def _operation_name_in_ms(*args, **kwargs):
return "P.Split"

@staticmethod
def _convert_params(**kwargs):
axis = kwargs["params"]["axis"]
split = kwargs["params"]["split"]
output_num = len(split)
return {"axis": axis,
"output_num": output_num}

@staticmethod
def _convert_trained_weights(**kwargs):
return dict()

@staticmethod
def _convert_settings(**kwargs):
return Setting()
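
For an ONNX Split with, say, axis=1 and split=[32, 32] (hypothetical values), the mapper yields axis=1 and output_num=2, which should translate into generated MindSpore code along these lines:

params = {"axis": 1, "split": [32, 32]}
converted = {"axis": params["axis"], "output_num": len(params["split"])}

declare = f"self.split = P.Split(axis={converted['axis']}, output_num={converted['output_num']})"
call = "opt_14, opt_15 = self.split(opt_13)"  # one output variable per chunk
print(declare)  # self.split = P.Split(axis=1, output_num=2)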

+2 -1 mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json

@@ -18,5 +18,6 @@
"onnx::Reshape": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.reshape_mapper.ReshapeMapper",
"onnx::Slice": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.slice_mapper.SliceMapper",
"onnx::Mul": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.mul_mapper.MulMapper",
"onnx::Sigmoid": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.sigmoid_mapper.SigmoidMapper"
"onnx::Sigmoid": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.sigmoid_mapper.SigmoidMapper",
"onnx::Split": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.split_mapper.SplitMapper"
}
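
Each entry maps an ONNX operator kind to the dotted path of its mapper class. A sketch of how such a table is typically resolved at runtime, assuming an importlib-based loader; the converter's actual loader is not part of this diff:

import importlib

def load_mapper(dotted_path):
    # "pkg.module.ClassName" -> the mapper class object.
    module_path, class_name = dotted_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), class_name)

# Hypothetical usage against the table above:
# mapper_cls = load_mapper(table["onnx::Split"])  # -> SplitMapper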

+23 -0 mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py

@@ -105,6 +105,7 @@ class Graph(BaseGraph, abc.ABC):
self._output_nodes = []
self._topological_order = []
self._input_shape = dict()
self._is_multi_opt_graph = False

def get_input_shape(self, name):
"""
@@ -303,11 +304,33 @@ class GraphNode(abc.ABC):
self._opt_shape = None
# Weight of current op.
self._weight = None
# Input variable names.
self._ipt_var_names = list()
# Output variable names.
self._opt_var_names = list()
# Is in multi output graph.
self._is_in_multi_opt_graph = False

@property
def weight(self):
return self._weight

@property
def ipt_var_names(self):
return self._ipt_var_names

@ipt_var_names.setter
def ipt_var_names(self, var_names):
self._ipt_var_names = var_names

@property
def opt_var_names(self):
return self._opt_var_names

@opt_var_names.setter
def opt_var_names(self, var_names):
self._opt_var_names = var_names

@staticmethod
def get_opt_var_name(variable_name):
"""


+67 -10 mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py

@@ -24,7 +24,7 @@ from .pytorch_graph_node import PyTorchGraphNode
from .pytorch_graph_parser import PyTorchGraphParser

from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, SCALAR_WITHOUT_SHAPE, \
- MIN_SCOPE_LENGTH
+ MIN_SCOPE_LENGTH, SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT
from ..constant import LEFT_BUCKET, RIGHT_BUCKET

NONE_SCOPE_OP = {
@@ -33,22 +33,32 @@ NONE_SCOPE_OP = {
"onnx::Concat": "Concat",
"onnx::Squeeze": "Squeeze",
"onnx::Unsqueeze": "Unsqueeze",
"onnx::Split": "Split",
"onnx::Reshape": "Reshape",
"onnx::Transpose": "Transpose",
"onnx::Constant": "Constant",
"onnx::ReduceMean": "ReduceMean"
}


- def normalize_scope_name(node):
+ def normalize_scope_name(node, scope_name_dict):
"""
Normalize the scope name into a uniform format.

Args:
node (Node): PyTorch node.
scope_name_dict (dict): Dictionary of scope names with the key node_id.

Returns:
str, normalized scope name.
"""
global NONE_SCOPE_OP

- name = node.scopeName().replace(SEPARATOR_BTW_NAME_AND_ID, '').split(SEPARATOR_IN_SCOPE)
+ scope_name = node.scopeName()
+ if not scope_name:
+ name = [retrieve_scope_name(node, scope_name_dict)]
+ else:
+ name = scope_name.replace(SEPARATOR_BTW_NAME_AND_ID, '').split(SEPARATOR_IN_SCOPE)
scopes = []
for segment in name:
segment = segment.split(LINK_IN_SCOPE)[0]
@@ -64,7 +74,43 @@ def normalize_scope_name(node):
if node.kind() in NONE_SCOPE_OP.keys():
scopes.append(NONE_SCOPE_OP[node.kind()])
scopes = [s for s in scopes if s]
return f"{SEPARATOR_IN_SCOPE.join(scopes)}_{PyTorchGraph.get_node_id(node)}"
node_id = PyTorchGraph.get_node_id(node)
return f"{SEPARATOR_IN_SCOPE.join(scopes)}_{'&'.join(node_id)}"


def retrieve_scope_name(node, scope_name_dict):
"""
Retrieve scope name from input nodes.

Args:
node (Node): PyTorch node.
scope_name_dict (dict): Dictionary of scope names with the key node_id.

Returns:
str, scope name.
"""
node_content = \
SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT.join(str(node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[1:])
node_inputs = re.findall(r"[(](.*?)[)]", node_content)[0]
node_inputs = re.sub(r"[\s%]", '', node_inputs).split(",")

scope_name_ipt_nodes = list()
for node_input in node_inputs:
if not scope_name_dict.get(node_input, None):
continue
scope_name_ipt_nodes.append(scope_name_dict[node_input])

scope_name_split = list()
for idx, _ in enumerate(scope_name_ipt_nodes):
if not scope_name_split:
scope_name_split = scope_name_ipt_nodes[idx]
else:
scope_name_split = [
sub_scope_name
for sub_scope_name in scope_name_split if sub_scope_name in scope_name_ipt_nodes[idx]
]
scope_name = SEPARATOR_IN_SCOPE.join(scope_name_split)
return scope_name
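
The fallback above recovers a scope for nodes that torch emits without a scopeName() by intersecting, in order, the scope segments of the node's inputs. A worked sketch with hypothetical scope lists:

scope_name_ipt_nodes = [
    ["ResNet", "layer1", "conv"],
    ["ResNet", "layer1", "relu"],
]

scope_name_split = []
for scopes in scope_name_ipt_nodes:
    if not scope_name_split:
        scope_name_split = scopes
    else:
        # Keep only segments common to every input's scope, preserving order.
        scope_name_split = [s for s in scope_name_split if s in scopes]

print("/".join(scope_name_split))  # ResNet/layer1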


class PyTorchGraph(Graph):
@@ -179,8 +225,12 @@ class PyTorchGraph(Graph):
graph = self._trace_torch_graph(feed_forward_ipt_shape)
nodes = list(graph.nodes())

scope_name_dict = dict()

for node in nodes:
- node_name = normalize_scope_name(node)
+ node_name = normalize_scope_name(node, scope_name_dict)
+ scope_name_dict[node_name.split(SEPARATOR_BTW_NAME_AND_ID)[-1]] \
+ = list(node_name.split(SEPARATOR_BTW_NAME_AND_ID)[0].split(SEPARATOR_IN_SCOPE))
output_shape_str_list = re.findall(r'[^()!]+', str(node))
output_shape_str = output_shape_str_list[1]
output_shape = self._extract_shape(output_shape_str)
@@ -204,7 +254,7 @@ class PyTorchGraph(Graph):

if nd_id and nd_scope_name:
node_input_name = normalize_scope_name(
- node_input.node()
+ node_input.node(), scope_name_dict
)
self.build_connection(node_input_name, node_name)

@@ -259,12 +309,16 @@ class PyTorchGraph(Graph):

return module_dict

- def _check_multi_ipt(self):
+ def _check_multi_ipt_opt(self):
"""Check whether multi-input or multi-output exists."""
module_dict = self._generate_module()
for _, nodes_per_module in module_dict.items():
prcs_nodes_out_from_module = set()
for node_name in nodes_per_module:
if re.search(r"[\d]+[&][\d]+", node_name):
self._is_multi_opt_graph = True
return True

node = self._nodes_collection.get(node_name, None)
if node:
prcs_nodes = node.precursor_nodes
@@ -284,11 +338,13 @@ class PyTorchGraph(Graph):

def _unmerge_multi_ipt_opt_script(self):
"""Unmerge all submodule."""
- if self._check_multi_ipt():
+ if self._check_multi_ipt_opt():
for node_key, node_inst in deepcopy(self._nodes_collection).items():
prsc_nodes = node_inst.precursor_nodes
scsr_nodes = node_inst.successor_nodes

node_inst.is_in_multi_opt_graph = self._is_multi_opt_graph

node_inst.precursor_nodes = [SEPARATOR_IN_SCOPE.join((prsc_node.split(SEPARATOR_IN_SCOPE)[0],
prsc_node.split(SEPARATOR_IN_SCOPE)[-1]))
for prsc_node in deepcopy(prsc_nodes)]
@@ -382,5 +438,6 @@ class PyTorchGraph(Graph):
Returns:
list, node ids.
"""
- node_id = re.search(r"[\d]+", str(node))
- return node_id.group()
+ node_title = str(node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[0]
+ node_id = re.findall(r"[%](.*?) [:]", node_title)
+ return node_id

+28 -4 mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py

@@ -13,11 +13,13 @@
# limitations under the License.
# ==============================================================================
"""Define PyTorch graph node."""
import re

from .base import GraphNode
from ..common.utils import is_converted

from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \
- SEPARATOR_IN_ONNX_OP
+ SEPARATOR_IN_ONNX_OP, SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT


class PyTorchGraphNode(GraphNode):
@@ -38,6 +40,19 @@ class PyTorchGraphNode(GraphNode):
self._op_name = node.kind() if node else None
self._scope_name = node.scopeName() if node else None
self._weight = weight
self._ipt_var_names, self._opt_var_names \
= self._extract_ipt_opt_var_names() if node else (list(), list())

def _extract_ipt_opt_var_names(self):
"""Extract ipt and opt var names."""
node_content = SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT.join(
str(self._src_node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[1:]
)
node_inputs = re.findall(r"[(](.*?)[)]", node_content)[0]
node_inputs = re.sub(r"[\s%]", '', node_inputs).split(",")
node_title = str(self._src_node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[0]
node_outputs = re.findall(r"[%](.*?) [:]", node_title)
return node_inputs, node_outputs

def clear_args_of_declaration(self):
"""
@@ -57,6 +72,14 @@ class PyTorchGraphNode(GraphNode):
"""
return f"{arg}_{variable_name}"

@property
def is_in_multi_opt_graph(self):
return self._is_in_multi_opt_graph

@is_in_multi_opt_graph.setter
def is_in_multi_opt_graph(self, multi_opt_state):
self._is_in_multi_opt_graph = multi_opt_state

@property
def hash_key(self):
"""
@@ -119,14 +142,14 @@ class PyTorchGraphNode(GraphNode):
self._ipt_shape = input_shape
self._opt_shape = output_shape

- def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: str, code_fragment):
+ def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: list, code_fragment):
"""
Generate statements.

Args:
variable_name (str): Variable name.
ipt_args_in_construct (str): Args of input.
- output_var (str): Output variable name in construct.
+ output_var (list): Output variable names in construct.
code_fragment (CodeFragment): CodeFragment instance.

Returns:
@@ -157,7 +180,8 @@ class PyTorchGraphNode(GraphNode):
operator = operator.replace(SEPARATOR_IN_ONNX_OP, ".")

declare = f"self.{variable_name} = {operator}({expr})"
call = f"{output_var} = self.{variable_name}({ipt_args_settings_in_construct})"
call = f"{', '.join([output for output in output_var])}" \
f" = self.{variable_name}({ipt_args_settings_in_construct})"

return declare, call
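
With output_var now a list, to_code emits one assignment target per output. A minimal sketch of the emission, with invented names:

output_var = ["opt_14", "opt_15"]
variable_name = "split"
ipt_args_settings_in_construct = "opt_13"

call = f"{', '.join(output_var)} = self.{variable_name}({ipt_args_settings_in_construct})"
print(call)  # opt_14, opt_15 = self.split(opt_13)

Note that ', '.join(output_var) already suffices here; the list comprehension in the diff is a functional no-op around the same join.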


