Browse Source

!1016 Add multi-output converter function in PyTorch.

From: @moran3
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
3d80b76841
9 changed files with 247 additions and 27 deletions
  1. +9
    -3
      mindinsight/mindconverter/graph_based_converter/common/code_fragment.py
  2. +5
    -0
      mindinsight/mindconverter/graph_based_converter/constant.py
  3. +60
    -7
      mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py
  4. +12
    -2
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reshape_mapper.py
  5. +41
    -0
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/split_mapper.py
  6. +2
    -1
      mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json
  7. +23
    -0
      mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
  8. +67
    -10
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py
  9. +28
    -4
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py

+ 9
- 3
mindinsight/mindconverter/graph_based_converter/common/code_fragment.py View File

@@ -44,6 +44,7 @@ class Fragment(abc.ABC):
operation (str): Operation name in MindSpore. operation (str): Operation name in MindSpore.
actual_args (dict): Actual arg values. actual_args (dict): Actual arg values.
settings (namedTuple): Code generation setting. settings (namedTuple): Code generation setting.

""" """


def __init__(self, operation, actual_args, input_shape, output_shape, settings=None): def __init__(self, operation, actual_args, input_shape, output_shape, settings=None):
@@ -89,9 +90,9 @@ class Fragment(abc.ABC):
self._declared_variable_name = var self._declared_variable_name = var


@property @property
def output_var_name(self) -> str:
def output_var_name(self) -> list:
"""Getter of output variable name.""" """Getter of output variable name."""
return ", ".join(self._output_var_name)
return self._output_var_name


@output_var_name.setter @output_var_name.setter
def output_var_name(self, opt_vars): def output_var_name(self, opt_vars):
@@ -100,6 +101,7 @@ class Fragment(abc.ABC):


Args: Args:
opt_vars (list[str]): Output variable name. opt_vars (list[str]): Output variable name.

""" """
self._output_var_name = opt_vars self._output_var_name = opt_vars


@@ -119,8 +121,9 @@ class Fragment(abc.ABC):


Args: Args:
ipt (Fragment): Where input comes from. ipt (Fragment): Where input comes from.

""" """
self._operation_inputs.append(ipt)
self._operation_inputs += ipt


@property @property
def operation(self): def operation(self):
@@ -139,6 +142,7 @@ class Fragment(abc.ABC):


Args: Args:
op (str): Operation name. op (str): Operation name.

""" """
self._operation = op self._operation = op


@@ -158,6 +162,7 @@ class Fragment(abc.ABC):


Args: Args:
formal_args (dict): To be updated args. formal_args (dict): To be updated args.

""" """
return self._formal_args_list.update(formal_args) return self._formal_args_list.update(formal_args)


@@ -194,6 +199,7 @@ class CodeFragment(Fragment):
operation (str): Operation name in MindSpore. operation (str): Operation name in MindSpore.
actual_args (dict): Actual arg values. actual_args (dict): Actual arg values.
settings (namedTuple): Code generation setting. settings (namedTuple): Code generation setting.

""" """


def __init__(self, operation, actual_args, settings, input_shape, output_shape, def __init__(self, operation, actual_args, settings, input_shape, output_shape,


+ 5
- 0
mindinsight/mindconverter/graph_based_converter/constant.py View File

@@ -18,6 +18,7 @@ from enum import Enum, unique
SEPARATOR_IN_ONNX_OP = "::" SEPARATOR_IN_ONNX_OP = "::"
SEPARATOR_IN_SCOPE = "/" SEPARATOR_IN_SCOPE = "/"
SEPARATOR_BTW_NAME_AND_ID = "_" SEPARATOR_BTW_NAME_AND_ID = "_"
SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT = "="
LINK_IN_SCOPE = "-" LINK_IN_SCOPE = "-"
LEFT_BUCKET = "[" LEFT_BUCKET = "["
RIGHT_BUCKET = "]" RIGHT_BUCKET = "]"
@@ -52,6 +53,10 @@ EXPECTED_NUMBER = 1


MIN_SCOPE_LENGTH = 2 MIN_SCOPE_LENGTH = 2


NO_CONVERTED_OPERATORS = [
"onnx::Constant"
]



@unique @unique
class CodeFormatConfig(Enum): class CodeFormatConfig(Enum):


+ 60
- 7
mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py View File

@@ -29,7 +29,7 @@ from ..common.utils import is_converted, save_code_file_and_report
from ..mapper.base import Mapper from ..mapper.base import Mapper
from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode
from ..third_party_graph.onnx_graph_node import OnnxGraphNode from ..third_party_graph.onnx_graph_node import OnnxGraphNode
from ..constant import SEPARATOR_IN_SCOPE, get_imported_module
from ..constant import SEPARATOR_IN_SCOPE, get_imported_module, NO_CONVERTED_OPERATORS
from ..constant import CodeFormatConfig from ..constant import CodeFormatConfig
from ..constant import SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT from ..constant import SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT
from ..constant import NEW_LINE, SECOND_LEVEL_INDENT from ..constant import NEW_LINE, SECOND_LEVEL_INDENT
@@ -472,6 +472,8 @@ class HierarchicalTree(Tree):


for idx, node_name in enumerate(node.successors(self.tree_identifier)): for idx, node_name in enumerate(node.successors(self.tree_identifier)):
nd_inst = self.get_node(node_name) nd_inst = self.get_node(node_name)
if nd_inst.data.op_name in NO_CONVERTED_OPERATORS:
continue


# Generate code statement. # Generate code statement.
init, construct = self._generate_stat(nd_inst, node, idx) init, construct = self._generate_stat(nd_inst, node, idx)
@@ -518,14 +520,25 @@ class HierarchicalTree(Tree):
""" """


ipt_args_in_construct = "x" ipt_args_in_construct = "x"
opt_arg_in_construct = "output"
opt_arg_in_construct = ["output"]


if idx != 0: if idx != 0:
# Get previous node output variable name.
ipt_args_in_construct = self._get_previous_opt_var(cur_nd_inst, pre_nd_inst)
if cur_nd_inst.data.is_in_multi_opt_graph:
ipt_args_in_construct = self._get_current_ipt_var(cur_nd_inst)
else:
# Get previous node output variable name.
ipt_args_in_construct = self._get_previous_opt_var(cur_nd_inst, pre_nd_inst)
if idx != len(pre_nd_inst.successors(self.tree_identifier)) - 1: if idx != len(pre_nd_inst.successors(self.tree_identifier)) - 1:
# Set opt variable name. # Set opt variable name.
opt_arg_in_construct = f"{self.code_fragment_recorder[cur_nd_inst.identifier].declared_var_name}_opt"
if cur_nd_inst.data.node_type == NodeType.MODULE.value or not cur_nd_inst.data.is_in_multi_opt_graph:
opt_arg_in_construct = [
f"{self.code_fragment_recorder[cur_nd_inst.identifier].declared_var_name}_opt"
]
else:
opt_arg_in_construct = [
f"opt_{var_name}"
for var_name in self.code_fragment_recorder[cur_nd_inst.identifier].output_var_name
]


declare, call = cur_nd_inst.data.to_code(ipt_args_in_construct=ipt_args_in_construct, declare, call = cur_nd_inst.data.to_code(ipt_args_in_construct=ipt_args_in_construct,
variable_name=self.code_fragment_recorder[ variable_name=self.code_fragment_recorder[
@@ -548,6 +561,39 @@ class HierarchicalTree(Tree):
""" """
return s.split(SEPARATOR_IN_SCOPE)[-1].lower().split(SEPARATOR_BTW_NAME_AND_ID)[0] return s.split(SEPARATOR_IN_SCOPE)[-1].lower().split(SEPARATOR_BTW_NAME_AND_ID)[0]


def _get_current_ipt_var(self, cur_nd):
""""
Get current input variable name from node_id.

Args:
cur_nd (Node): Current node.

Returns:
str, needed var names.
"""
if cur_nd.data.node_type != NodeType.OPERATION.value:
while True:
p_nd = cur_nd.successors(self.tree_identifier)
if not p_nd:
break
cur_nd = self.get_node(p_nd[0])

ipt_lst_raw = []
for operation_input in self.code_fragment_recorder[cur_nd.identifier].operation_inputs:
ipt_lst_raw.append(f"{operation_input}")

opt_var_names_p_nds = set()
for e in cur_nd.data.precursor_nodes:
p_nd = self.get_node(e)
if p_nd.data.op_name in NO_CONVERTED_OPERATORS:
continue

opt_var_names_p_nd = set(p_nd.data.opt_var_names)
opt_var_names_p_nds = set.union(opt_var_names_p_nds, opt_var_names_p_nd)

ipt_lst = [f"opt_{ipt}" for ipt in set(ipt_lst_raw).intersection(opt_var_names_p_nds)]
return ", ".join(ipt_lst)

def _find_all_previous_opt_var_(self, cur_nd, pre_nd): def _find_all_previous_opt_var_(self, cur_nd, pre_nd):
""" """
Find all input variable names. Find all input variable names.
@@ -557,9 +603,12 @@ class HierarchicalTree(Tree):
pre_nd (Node): Precursor node. pre_nd (Node): Precursor node.


Returns: Returns:
str, needed var names.
list, needed var names list.
""" """
ipt_lst = [] ipt_lst = []
if cur_nd.tag in NO_CONVERTED_OPERATORS:
return ipt_lst

for e in cur_nd.data.precursor_nodes: for e in cur_nd.data.precursor_nodes:
p_nd = self.get_node(e) p_nd = self.get_node(e)
if e not in pre_nd.successors(self.tree_identifier): if e not in pre_nd.successors(self.tree_identifier):
@@ -575,7 +624,6 @@ class HierarchicalTree(Tree):
break break
p_nd = self.get_node(pre_nd_name) p_nd = self.get_node(pre_nd_name)
continue continue

ipt_lst.append( ipt_lst.append(
f"{self.code_fragment_recorder[p_nd.identifier].declared_var_name}_opt" f"{self.code_fragment_recorder[p_nd.identifier].declared_var_name}_opt"
) )
@@ -671,6 +719,9 @@ class HierarchicalTree(Tree):
# Sub-modules in the module could have arg name conflicts. # Sub-modules in the module could have arg name conflicts.
for idx, successor_name in enumerate(node.successors(self.tree_identifier)): for idx, successor_name in enumerate(node.successors(self.tree_identifier)):
nd_inst = self.get_node(successor_name) nd_inst = self.get_node(successor_name)
if nd_inst.data.op_name in NO_CONVERTED_OPERATORS:
continue

# Generation of params must behind variable assigment. # Generation of params must behind variable assigment.
if created: if created:
variable_name = self._module_vars[module_key][idx] variable_name = self._module_vars[module_key][idx]
@@ -680,6 +731,8 @@ class HierarchicalTree(Tree):


code_fragment = nd_inst.data.param_transform(mapper, variable_name) code_fragment = nd_inst.data.param_transform(mapper, variable_name)
code_fragment.declared_var_name = variable_name code_fragment.declared_var_name = variable_name
code_fragment.output_var_name = nd_inst.data.opt_var_names
code_fragment.update_operation_inputs(nd_inst.data.ipt_var_names)
self.code_fragment_recorder[nd_inst.identifier] = code_fragment self.code_fragment_recorder[nd_inst.identifier] = code_fragment


module_args.update(nd_inst.data.args_in_code) module_args.update(nd_inst.data.args_in_code)


+ 12
- 2
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reshape_mapper.py View File

@@ -34,9 +34,19 @@ class ReshapeMapper(ONNXToMindSporeMapper):


@staticmethod @staticmethod
def _convert_settings(**kwargs): def _convert_settings(**kwargs):
if kwargs.get("weights", None):
return ReshapeMapper._convert_settings_tf(**kwargs)
return ReshapeMapper._convert_settings_pytorch(**kwargs)

@staticmethod
def _convert_settings_pytorch(**kwargs):
params = kwargs.get("params")
shape = params.get("output_shape")
return Setting(op_extra_input={"input_shape": tuple(shape)})

@staticmethod
def _convert_settings_tf(**kwargs):
weights = kwargs.get("weights") weights = kwargs.get("weights")
if not weights:
return Setting()
if len(weights) > 1: if len(weights) > 1:
raise ValueError("For reshape, `weights` length should equal to 1.") raise ValueError("For reshape, `weights` length should equal to 1.")
shape = [-1] shape = [-1]


+ 41
- 0
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/split_mapper.py View File

@@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mapper module."""
from ...base import ONNXToMindSporeMapper
from ...gen_setting import Setting


class SplitMapper(ONNXToMindSporeMapper):
"""Split mapper."""

@staticmethod
def _operation_name_in_ms(*args, **kwargs):
return "P.Split"

@staticmethod
def _convert_params(**kwargs):
axis = kwargs["params"]["axis"]
split = kwargs["params"]["split"]
output_num = len(split)
return {"axis": axis,
"output_num": output_num}

@staticmethod
def _convert_trained_weights(**kwargs):
return dict()

@staticmethod
def _convert_settings(**kwargs):
return Setting()

+ 2
- 1
mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json View File

@@ -18,5 +18,6 @@
"onnx::Reshape": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.reshape_mapper.ReshapeMapper", "onnx::Reshape": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.reshape_mapper.ReshapeMapper",
"onnx::Slice": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.slice_mapper.SliceMapper", "onnx::Slice": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.slice_mapper.SliceMapper",
"onnx::Mul": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.mul_mapper.MulMapper", "onnx::Mul": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.mul_mapper.MulMapper",
"onnx::Sigmoid": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.sigmoid_mapper.SigmoidMapper"
"onnx::Sigmoid": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.sigmoid_mapper.SigmoidMapper",
"onnx::Split": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.split_mapper.SplitMapper"
} }

+ 23
- 0
mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py View File

@@ -105,6 +105,7 @@ class Graph(BaseGraph, abc.ABC):
self._output_nodes = [] self._output_nodes = []
self._topological_order = [] self._topological_order = []
self._input_shape = dict() self._input_shape = dict()
self._is_multi_opt_graph = False


def get_input_shape(self, name): def get_input_shape(self, name):
""" """
@@ -303,11 +304,33 @@ class GraphNode(abc.ABC):
self._opt_shape = None self._opt_shape = None
# Weight of current op. # Weight of current op.
self._weight = None self._weight = None
# Input variable names.
self._ipt_var_names = list()
# Output variable names.
self._opt_var_names = list()
# Is in multi output graph.
self._is_in_multi_opt_graph = False


@property @property
def weight(self): def weight(self):
return self._weight return self._weight


@property
def ipt_var_names(self):
return self._ipt_var_names

@ipt_var_names.setter
def ipt_var_names(self, var_names):
self._ipt_var_names = var_names

@property
def opt_var_names(self):
return self._opt_var_names

@opt_var_names.setter
def opt_var_names(self, var_names):
self._opt_var_names = var_names

@staticmethod @staticmethod
def get_opt_var_name(variable_name): def get_opt_var_name(variable_name):
""" """


+ 67
- 10
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py View File

@@ -24,7 +24,7 @@ from .pytorch_graph_node import PyTorchGraphNode
from .pytorch_graph_parser import PyTorchGraphParser from .pytorch_graph_parser import PyTorchGraphParser


from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, SCALAR_WITHOUT_SHAPE, \ from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, SCALAR_WITHOUT_SHAPE, \
MIN_SCOPE_LENGTH
MIN_SCOPE_LENGTH, SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT
from ..constant import LEFT_BUCKET, RIGHT_BUCKET from ..constant import LEFT_BUCKET, RIGHT_BUCKET


NONE_SCOPE_OP = { NONE_SCOPE_OP = {
@@ -33,22 +33,32 @@ NONE_SCOPE_OP = {
"onnx::Concat": "Concat", "onnx::Concat": "Concat",
"onnx::Squeeze": "Squeeze", "onnx::Squeeze": "Squeeze",
"onnx::Unsqueeze": "Unsqueeze", "onnx::Unsqueeze": "Unsqueeze",
"onnx::Split": "Split",
"onnx::Reshape": "Reshape",
"onnx::Transpose": "Transpose",
"onnx::Constant": "Constant",
"onnx::ReduceMean": "ReduceMean"
} }




def normalize_scope_name(node):
def normalize_scope_name(node, scope_name_dict):
""" """
Rename scope name into uniform. Rename scope name into uniform.


Args: Args:
node (Node): PyTorch node. node (Node): PyTorch node.
scope_name_dict (dict): Dictionary of scope names with the key node_id.


Returns: Returns:
str, normalized scope name. str, normalized scope name.
""" """
global NONE_SCOPE_OP global NONE_SCOPE_OP


name = node.scopeName().replace(SEPARATOR_BTW_NAME_AND_ID, '').split(SEPARATOR_IN_SCOPE)
scope_name = node.scopeName()
if not scope_name:
name = [retrieve_scope_name(node, scope_name_dict)]
else:
name = scope_name.replace(SEPARATOR_BTW_NAME_AND_ID, '').split(SEPARATOR_IN_SCOPE)
scopes = [] scopes = []
for segment in name: for segment in name:
segment = segment.split(LINK_IN_SCOPE)[0] segment = segment.split(LINK_IN_SCOPE)[0]
@@ -64,7 +74,43 @@ def normalize_scope_name(node):
if node.kind() in NONE_SCOPE_OP.keys(): if node.kind() in NONE_SCOPE_OP.keys():
scopes.append(NONE_SCOPE_OP[node.kind()]) scopes.append(NONE_SCOPE_OP[node.kind()])
scopes = [s for s in scopes if s] scopes = [s for s in scopes if s]
return f"{SEPARATOR_IN_SCOPE.join(scopes)}_{PyTorchGraph.get_node_id(node)}"
node_id = PyTorchGraph.get_node_id(node)
return f"{SEPARATOR_IN_SCOPE.join(scopes)}_{'&'.join(node_id)}"


def retrieve_scope_name(node, scope_name_dict):
"""
Retrieve scope name from input nodes.

Args:
node (Node): PyTorch node.
scope_name_dict (dict): Dictionary of scope names with the key node_id.

Return:
str: Scope name.
"""
node_content = \
SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT.join(str(node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[1:])
node_inputs = re.findall(r"[(](.*?)[)]", node_content)[0]
node_inputs = re.sub(r"[\s%]", '', node_inputs).split(",")

scope_name_ipt_nodes = list()
for node_input in node_inputs:
if not scope_name_dict.get(node_input, None):
continue
scope_name_ipt_nodes.append(scope_name_dict[node_input])

scope_name_split = list()
for idx, _ in enumerate(scope_name_ipt_nodes):
if not scope_name_split:
scope_name_split = scope_name_ipt_nodes[idx]
else:
scope_name_split = [
sub_scope_name
for sub_scope_name in scope_name_split if sub_scope_name in scope_name_ipt_nodes[idx]
]
scope_name = SEPARATOR_IN_SCOPE.join(scope_name_split)
return scope_name




class PyTorchGraph(Graph): class PyTorchGraph(Graph):
@@ -179,8 +225,12 @@ class PyTorchGraph(Graph):
graph = self._trace_torch_graph(feed_forward_ipt_shape) graph = self._trace_torch_graph(feed_forward_ipt_shape)
nodes = list(graph.nodes()) nodes = list(graph.nodes())


scope_name_dict = dict()

for node in nodes: for node in nodes:
node_name = normalize_scope_name(node)
node_name = normalize_scope_name(node, scope_name_dict)
scope_name_dict[node_name.split(SEPARATOR_BTW_NAME_AND_ID)[-1]] \
= list(node_name.split(SEPARATOR_BTW_NAME_AND_ID)[0].split(SEPARATOR_IN_SCOPE))
output_shape_str_list = re.findall(r'[^()!]+', str(node)) output_shape_str_list = re.findall(r'[^()!]+', str(node))
output_shape_str = output_shape_str_list[1] output_shape_str = output_shape_str_list[1]
output_shape = self._extract_shape(output_shape_str) output_shape = self._extract_shape(output_shape_str)
@@ -204,7 +254,7 @@ class PyTorchGraph(Graph):


if nd_id and nd_scope_name: if nd_id and nd_scope_name:
node_input_name = normalize_scope_name( node_input_name = normalize_scope_name(
node_input.node()
node_input.node(), scope_name_dict
) )
self.build_connection(node_input_name, node_name) self.build_connection(node_input_name, node_name)


@@ -259,12 +309,16 @@ class PyTorchGraph(Graph):


return module_dict return module_dict


def _check_multi_ipt(self):
def _check_multi_ipt_opt(self):
"""Check whether multi-input exists.""" """Check whether multi-input exists."""
module_dict = self._generate_module() module_dict = self._generate_module()
for _, nodes_per_module in module_dict.items(): for _, nodes_per_module in module_dict.items():
prcs_nodes_out_from_module = set() prcs_nodes_out_from_module = set()
for node_name in nodes_per_module: for node_name in nodes_per_module:
if re.search(r"[\d]+[&][\d]+", node_name):
self._is_multi_opt_graph = True
return True

node = self._nodes_collection.get(node_name, None) node = self._nodes_collection.get(node_name, None)
if node: if node:
prcs_nodes = node.precursor_nodes prcs_nodes = node.precursor_nodes
@@ -284,11 +338,13 @@ class PyTorchGraph(Graph):


def _unmerge_multi_ipt_opt_script(self): def _unmerge_multi_ipt_opt_script(self):
"""Unmerge all submodule.""" """Unmerge all submodule."""
if self._check_multi_ipt():
if self._check_multi_ipt_opt():
for node_key, node_inst in deepcopy(self._nodes_collection).items(): for node_key, node_inst in deepcopy(self._nodes_collection).items():
prsc_nodes = node_inst.precursor_nodes prsc_nodes = node_inst.precursor_nodes
scsr_nodes = node_inst.successor_nodes scsr_nodes = node_inst.successor_nodes


node_inst.is_in_multi_opt_graph = self._is_multi_opt_graph

node_inst.precursor_nodes = [SEPARATOR_IN_SCOPE.join((prsc_node.split(SEPARATOR_IN_SCOPE)[0], node_inst.precursor_nodes = [SEPARATOR_IN_SCOPE.join((prsc_node.split(SEPARATOR_IN_SCOPE)[0],
prsc_node.split(SEPARATOR_IN_SCOPE)[-1])) prsc_node.split(SEPARATOR_IN_SCOPE)[-1]))
for prsc_node in deepcopy(prsc_nodes)] for prsc_node in deepcopy(prsc_nodes)]
@@ -382,5 +438,6 @@ class PyTorchGraph(Graph):
Returns: Returns:
str, node id. str, node id.
""" """
node_id = re.search(r"[\d]+", str(node))
return node_id.group()
node_title = str(node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[0]
node_id = re.findall(r"[%](.*?) [:]", node_title)
return node_id

+ 28
- 4
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py View File

@@ -13,11 +13,13 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Define PyTorch graph node.""" """Define PyTorch graph node."""
import re

from .base import GraphNode from .base import GraphNode
from ..common.utils import is_converted from ..common.utils import is_converted


from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \ from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \
SEPARATOR_IN_ONNX_OP
SEPARATOR_IN_ONNX_OP, SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT




class PyTorchGraphNode(GraphNode): class PyTorchGraphNode(GraphNode):
@@ -38,6 +40,19 @@ class PyTorchGraphNode(GraphNode):
self._op_name = node.kind() if node else None self._op_name = node.kind() if node else None
self._scope_name = node.scopeName() if node else None self._scope_name = node.scopeName() if node else None
self._weight = weight self._weight = weight
self._ipt_var_names, self._opt_var_names \
= self._extract_ipt_opt_var_names() if node else (list(), list())

def _extract_ipt_opt_var_names(self):
"""Extract ipt and opt var names."""
node_content = SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT.join(
str(self._src_node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[1:]
)
node_inputs = re.findall(r"[(](.*?)[)]", node_content)[0]
node_inputs = re.sub(r"[\s%]", '', node_inputs).split(",")
node_title = str(self._src_node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[0]
node_outputs = re.findall(r"[%](.*?) [:]", node_title)
return node_inputs, node_outputs


def clear_args_of_declaration(self): def clear_args_of_declaration(self):
""" """
@@ -57,6 +72,14 @@ class PyTorchGraphNode(GraphNode):
""" """
return f"{arg}_{variable_name}" return f"{arg}_{variable_name}"


@property
def is_in_multi_opt_graph(self):
return self._is_in_multi_opt_graph

@is_in_multi_opt_graph.setter
def is_in_multi_opt_graph(self, multi_opt_state):
self._is_in_multi_opt_graph = multi_opt_state

@property @property
def hash_key(self): def hash_key(self):
""" """
@@ -119,14 +142,14 @@ class PyTorchGraphNode(GraphNode):
self._ipt_shape = input_shape self._ipt_shape = input_shape
self._opt_shape = output_shape self._opt_shape = output_shape


def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: str, code_fragment):
def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: list, code_fragment):
""" """
Generate statements. Generate statements.


Args: Args:
variable_name (str): Variable name. variable_name (str): Variable name.
ipt_args_in_construct (str): Args of input. ipt_args_in_construct (str): Args of input.
output_var (str): Output variable name in construct.
output_var (list): Output variable names in construct.
code_fragment (CodeFragment): CodeFragment instance. code_fragment (CodeFragment): CodeFragment instance.


Returns: Returns:
@@ -157,7 +180,8 @@ class PyTorchGraphNode(GraphNode):
operator = operator.replace(SEPARATOR_IN_ONNX_OP, ".") operator = operator.replace(SEPARATOR_IN_ONNX_OP, ".")


declare = f"self.{variable_name} = {operator}({expr})" declare = f"self.{variable_name} = {operator}({expr})"
call = f"{output_var} = self.{variable_name}({ipt_args_settings_in_construct})"
call = f"{', '.join([output for output in output_var])}" \
f" = self.{variable_name}({ipt_args_settings_in_construct})"


return declare, call return declare, call




Loading…
Cancel
Save