From: @moran3 Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -44,6 +44,7 @@ class Fragment(abc.ABC): | |||||
| operation (str): Operation name in MindSpore. | operation (str): Operation name in MindSpore. | ||||
| actual_args (dict): Actual arg values. | actual_args (dict): Actual arg values. | ||||
| settings (namedTuple): Code generation setting. | settings (namedTuple): Code generation setting. | ||||
| """ | """ | ||||
| def __init__(self, operation, actual_args, input_shape, output_shape, settings=None): | def __init__(self, operation, actual_args, input_shape, output_shape, settings=None): | ||||
| @@ -89,9 +90,9 @@ class Fragment(abc.ABC): | |||||
| self._declared_variable_name = var | self._declared_variable_name = var | ||||
| @property | @property | ||||
| def output_var_name(self) -> str: | |||||
| def output_var_name(self) -> list: | |||||
| """Getter of output variable name.""" | """Getter of output variable name.""" | ||||
| return ", ".join(self._output_var_name) | |||||
| return self._output_var_name | |||||
| @output_var_name.setter | @output_var_name.setter | ||||
| def output_var_name(self, opt_vars): | def output_var_name(self, opt_vars): | ||||
| @@ -100,6 +101,7 @@ class Fragment(abc.ABC): | |||||
| Args: | Args: | ||||
| opt_vars (list[str]): Output variable name. | opt_vars (list[str]): Output variable name. | ||||
| """ | """ | ||||
| self._output_var_name = opt_vars | self._output_var_name = opt_vars | ||||
| @@ -119,8 +121,9 @@ class Fragment(abc.ABC): | |||||
| Args: | Args: | ||||
| ipt (Fragment): Where input comes from. | ipt (Fragment): Where input comes from. | ||||
| """ | """ | ||||
| self._operation_inputs.append(ipt) | |||||
| self._operation_inputs += ipt | |||||
| @property | @property | ||||
| def operation(self): | def operation(self): | ||||
| @@ -139,6 +142,7 @@ class Fragment(abc.ABC): | |||||
| Args: | Args: | ||||
| op (str): Operation name. | op (str): Operation name. | ||||
| """ | """ | ||||
| self._operation = op | self._operation = op | ||||
| @@ -158,6 +162,7 @@ class Fragment(abc.ABC): | |||||
| Args: | Args: | ||||
| formal_args (dict): To be updated args. | formal_args (dict): To be updated args. | ||||
| """ | """ | ||||
| return self._formal_args_list.update(formal_args) | return self._formal_args_list.update(formal_args) | ||||
| @@ -194,6 +199,7 @@ class CodeFragment(Fragment): | |||||
| operation (str): Operation name in MindSpore. | operation (str): Operation name in MindSpore. | ||||
| actual_args (dict): Actual arg values. | actual_args (dict): Actual arg values. | ||||
| settings (namedTuple): Code generation setting. | settings (namedTuple): Code generation setting. | ||||
| """ | """ | ||||
| def __init__(self, operation, actual_args, settings, input_shape, output_shape, | def __init__(self, operation, actual_args, settings, input_shape, output_shape, | ||||
| @@ -18,6 +18,7 @@ from enum import Enum, unique | |||||
| SEPARATOR_IN_ONNX_OP = "::" | SEPARATOR_IN_ONNX_OP = "::" | ||||
| SEPARATOR_IN_SCOPE = "/" | SEPARATOR_IN_SCOPE = "/" | ||||
| SEPARATOR_BTW_NAME_AND_ID = "_" | SEPARATOR_BTW_NAME_AND_ID = "_" | ||||
| SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT = "=" | |||||
| LINK_IN_SCOPE = "-" | LINK_IN_SCOPE = "-" | ||||
| LEFT_BUCKET = "[" | LEFT_BUCKET = "[" | ||||
| RIGHT_BUCKET = "]" | RIGHT_BUCKET = "]" | ||||
| @@ -52,6 +53,10 @@ EXPECTED_NUMBER = 1 | |||||
| MIN_SCOPE_LENGTH = 2 | MIN_SCOPE_LENGTH = 2 | ||||
| NO_CONVERTED_OPERATORS = [ | |||||
| "onnx::Constant" | |||||
| ] | |||||
| @unique | @unique | ||||
| class CodeFormatConfig(Enum): | class CodeFormatConfig(Enum): | ||||
| @@ -29,7 +29,7 @@ from ..common.utils import is_converted, save_code_file_and_report | |||||
| from ..mapper.base import Mapper | from ..mapper.base import Mapper | ||||
| from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode | from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode | ||||
| from ..third_party_graph.onnx_graph_node import OnnxGraphNode | from ..third_party_graph.onnx_graph_node import OnnxGraphNode | ||||
| from ..constant import SEPARATOR_IN_SCOPE, get_imported_module | |||||
| from ..constant import SEPARATOR_IN_SCOPE, get_imported_module, NO_CONVERTED_OPERATORS | |||||
| from ..constant import CodeFormatConfig | from ..constant import CodeFormatConfig | ||||
| from ..constant import SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT | from ..constant import SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT | ||||
| from ..constant import NEW_LINE, SECOND_LEVEL_INDENT | from ..constant import NEW_LINE, SECOND_LEVEL_INDENT | ||||
| @@ -472,6 +472,8 @@ class HierarchicalTree(Tree): | |||||
| for idx, node_name in enumerate(node.successors(self.tree_identifier)): | for idx, node_name in enumerate(node.successors(self.tree_identifier)): | ||||
| nd_inst = self.get_node(node_name) | nd_inst = self.get_node(node_name) | ||||
| if nd_inst.data.op_name in NO_CONVERTED_OPERATORS: | |||||
| continue | |||||
| # Generate code statement. | # Generate code statement. | ||||
| init, construct = self._generate_stat(nd_inst, node, idx) | init, construct = self._generate_stat(nd_inst, node, idx) | ||||
| @@ -518,14 +520,25 @@ class HierarchicalTree(Tree): | |||||
| """ | """ | ||||
| ipt_args_in_construct = "x" | ipt_args_in_construct = "x" | ||||
| opt_arg_in_construct = "output" | |||||
| opt_arg_in_construct = ["output"] | |||||
| if idx != 0: | if idx != 0: | ||||
| # Get previous node output variable name. | |||||
| ipt_args_in_construct = self._get_previous_opt_var(cur_nd_inst, pre_nd_inst) | |||||
| if cur_nd_inst.data.is_in_multi_opt_graph: | |||||
| ipt_args_in_construct = self._get_current_ipt_var(cur_nd_inst) | |||||
| else: | |||||
| # Get previous node output variable name. | |||||
| ipt_args_in_construct = self._get_previous_opt_var(cur_nd_inst, pre_nd_inst) | |||||
| if idx != len(pre_nd_inst.successors(self.tree_identifier)) - 1: | if idx != len(pre_nd_inst.successors(self.tree_identifier)) - 1: | ||||
| # Set opt variable name. | # Set opt variable name. | ||||
| opt_arg_in_construct = f"{self.code_fragment_recorder[cur_nd_inst.identifier].declared_var_name}_opt" | |||||
| if cur_nd_inst.data.node_type == NodeType.MODULE.value or not cur_nd_inst.data.is_in_multi_opt_graph: | |||||
| opt_arg_in_construct = [ | |||||
| f"{self.code_fragment_recorder[cur_nd_inst.identifier].declared_var_name}_opt" | |||||
| ] | |||||
| else: | |||||
| opt_arg_in_construct = [ | |||||
| f"opt_{var_name}" | |||||
| for var_name in self.code_fragment_recorder[cur_nd_inst.identifier].output_var_name | |||||
| ] | |||||
| declare, call = cur_nd_inst.data.to_code(ipt_args_in_construct=ipt_args_in_construct, | declare, call = cur_nd_inst.data.to_code(ipt_args_in_construct=ipt_args_in_construct, | ||||
| variable_name=self.code_fragment_recorder[ | variable_name=self.code_fragment_recorder[ | ||||
| @@ -548,6 +561,39 @@ class HierarchicalTree(Tree): | |||||
| """ | """ | ||||
| return s.split(SEPARATOR_IN_SCOPE)[-1].lower().split(SEPARATOR_BTW_NAME_AND_ID)[0] | return s.split(SEPARATOR_IN_SCOPE)[-1].lower().split(SEPARATOR_BTW_NAME_AND_ID)[0] | ||||
| def _get_current_ipt_var(self, cur_nd): | |||||
| """" | |||||
| Get current input variable name from node_id. | |||||
| Args: | |||||
| cur_nd (Node): Current node. | |||||
| Returns: | |||||
| str, needed var names. | |||||
| """ | |||||
| if cur_nd.data.node_type != NodeType.OPERATION.value: | |||||
| while True: | |||||
| p_nd = cur_nd.successors(self.tree_identifier) | |||||
| if not p_nd: | |||||
| break | |||||
| cur_nd = self.get_node(p_nd[0]) | |||||
| ipt_lst_raw = [] | |||||
| for operation_input in self.code_fragment_recorder[cur_nd.identifier].operation_inputs: | |||||
| ipt_lst_raw.append(f"{operation_input}") | |||||
| opt_var_names_p_nds = set() | |||||
| for e in cur_nd.data.precursor_nodes: | |||||
| p_nd = self.get_node(e) | |||||
| if p_nd.data.op_name in NO_CONVERTED_OPERATORS: | |||||
| continue | |||||
| opt_var_names_p_nd = set(p_nd.data.opt_var_names) | |||||
| opt_var_names_p_nds = set.union(opt_var_names_p_nds, opt_var_names_p_nd) | |||||
| ipt_lst = [f"opt_{ipt}" for ipt in set(ipt_lst_raw).intersection(opt_var_names_p_nds)] | |||||
| return ", ".join(ipt_lst) | |||||
| def _find_all_previous_opt_var_(self, cur_nd, pre_nd): | def _find_all_previous_opt_var_(self, cur_nd, pre_nd): | ||||
| """ | """ | ||||
| Find all input variable names. | Find all input variable names. | ||||
| @@ -557,9 +603,12 @@ class HierarchicalTree(Tree): | |||||
| pre_nd (Node): Precursor node. | pre_nd (Node): Precursor node. | ||||
| Returns: | Returns: | ||||
| str, needed var names. | |||||
| list, needed var names list. | |||||
| """ | """ | ||||
| ipt_lst = [] | ipt_lst = [] | ||||
| if cur_nd.tag in NO_CONVERTED_OPERATORS: | |||||
| return ipt_lst | |||||
| for e in cur_nd.data.precursor_nodes: | for e in cur_nd.data.precursor_nodes: | ||||
| p_nd = self.get_node(e) | p_nd = self.get_node(e) | ||||
| if e not in pre_nd.successors(self.tree_identifier): | if e not in pre_nd.successors(self.tree_identifier): | ||||
| @@ -575,7 +624,6 @@ class HierarchicalTree(Tree): | |||||
| break | break | ||||
| p_nd = self.get_node(pre_nd_name) | p_nd = self.get_node(pre_nd_name) | ||||
| continue | continue | ||||
| ipt_lst.append( | ipt_lst.append( | ||||
| f"{self.code_fragment_recorder[p_nd.identifier].declared_var_name}_opt" | f"{self.code_fragment_recorder[p_nd.identifier].declared_var_name}_opt" | ||||
| ) | ) | ||||
| @@ -671,6 +719,9 @@ class HierarchicalTree(Tree): | |||||
| # Sub-modules in the module could have arg name conflicts. | # Sub-modules in the module could have arg name conflicts. | ||||
| for idx, successor_name in enumerate(node.successors(self.tree_identifier)): | for idx, successor_name in enumerate(node.successors(self.tree_identifier)): | ||||
| nd_inst = self.get_node(successor_name) | nd_inst = self.get_node(successor_name) | ||||
| if nd_inst.data.op_name in NO_CONVERTED_OPERATORS: | |||||
| continue | |||||
| # Generation of params must behind variable assigment. | # Generation of params must behind variable assigment. | ||||
| if created: | if created: | ||||
| variable_name = self._module_vars[module_key][idx] | variable_name = self._module_vars[module_key][idx] | ||||
| @@ -680,6 +731,8 @@ class HierarchicalTree(Tree): | |||||
| code_fragment = nd_inst.data.param_transform(mapper, variable_name) | code_fragment = nd_inst.data.param_transform(mapper, variable_name) | ||||
| code_fragment.declared_var_name = variable_name | code_fragment.declared_var_name = variable_name | ||||
| code_fragment.output_var_name = nd_inst.data.opt_var_names | |||||
| code_fragment.update_operation_inputs(nd_inst.data.ipt_var_names) | |||||
| self.code_fragment_recorder[nd_inst.identifier] = code_fragment | self.code_fragment_recorder[nd_inst.identifier] = code_fragment | ||||
| module_args.update(nd_inst.data.args_in_code) | module_args.update(nd_inst.data.args_in_code) | ||||
| @@ -34,9 +34,19 @@ class ReshapeMapper(ONNXToMindSporeMapper): | |||||
| @staticmethod | @staticmethod | ||||
| def _convert_settings(**kwargs): | def _convert_settings(**kwargs): | ||||
| if kwargs.get("weights", None): | |||||
| return ReshapeMapper._convert_settings_tf(**kwargs) | |||||
| return ReshapeMapper._convert_settings_pytorch(**kwargs) | |||||
| @staticmethod | |||||
| def _convert_settings_pytorch(**kwargs): | |||||
| params = kwargs.get("params") | |||||
| shape = params.get("output_shape") | |||||
| return Setting(op_extra_input={"input_shape": tuple(shape)}) | |||||
| @staticmethod | |||||
| def _convert_settings_tf(**kwargs): | |||||
| weights = kwargs.get("weights") | weights = kwargs.get("weights") | ||||
| if not weights: | |||||
| return Setting() | |||||
| if len(weights) > 1: | if len(weights) > 1: | ||||
| raise ValueError("For reshape, `weights` length should equal to 1.") | raise ValueError("For reshape, `weights` length should equal to 1.") | ||||
| shape = [-1] | shape = [-1] | ||||
| @@ -0,0 +1,41 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """Mapper module.""" | |||||
| from ...base import ONNXToMindSporeMapper | |||||
| from ...gen_setting import Setting | |||||
| class SplitMapper(ONNXToMindSporeMapper): | |||||
| """Split mapper.""" | |||||
| @staticmethod | |||||
| def _operation_name_in_ms(*args, **kwargs): | |||||
| return "P.Split" | |||||
| @staticmethod | |||||
| def _convert_params(**kwargs): | |||||
| axis = kwargs["params"]["axis"] | |||||
| split = kwargs["params"]["split"] | |||||
| output_num = len(split) | |||||
| return {"axis": axis, | |||||
| "output_num": output_num} | |||||
| @staticmethod | |||||
| def _convert_trained_weights(**kwargs): | |||||
| return dict() | |||||
| @staticmethod | |||||
| def _convert_settings(**kwargs): | |||||
| return Setting() | |||||
| @@ -18,5 +18,6 @@ | |||||
| "onnx::Reshape": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.reshape_mapper.ReshapeMapper", | "onnx::Reshape": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.reshape_mapper.ReshapeMapper", | ||||
| "onnx::Slice": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.slice_mapper.SliceMapper", | "onnx::Slice": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.slice_mapper.SliceMapper", | ||||
| "onnx::Mul": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.mul_mapper.MulMapper", | "onnx::Mul": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.mul_mapper.MulMapper", | ||||
| "onnx::Sigmoid": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.sigmoid_mapper.SigmoidMapper" | |||||
| "onnx::Sigmoid": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.sigmoid_mapper.SigmoidMapper", | |||||
| "onnx::Split": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.split_mapper.SplitMapper" | |||||
| } | } | ||||
| @@ -105,6 +105,7 @@ class Graph(BaseGraph, abc.ABC): | |||||
| self._output_nodes = [] | self._output_nodes = [] | ||||
| self._topological_order = [] | self._topological_order = [] | ||||
| self._input_shape = dict() | self._input_shape = dict() | ||||
| self._is_multi_opt_graph = False | |||||
| def get_input_shape(self, name): | def get_input_shape(self, name): | ||||
| """ | """ | ||||
| @@ -303,11 +304,33 @@ class GraphNode(abc.ABC): | |||||
| self._opt_shape = None | self._opt_shape = None | ||||
| # Weight of current op. | # Weight of current op. | ||||
| self._weight = None | self._weight = None | ||||
| # Input variable names. | |||||
| self._ipt_var_names = list() | |||||
| # Output variable names. | |||||
| self._opt_var_names = list() | |||||
| # Is in multi output graph. | |||||
| self._is_in_multi_opt_graph = False | |||||
| @property | @property | ||||
| def weight(self): | def weight(self): | ||||
| return self._weight | return self._weight | ||||
| @property | |||||
| def ipt_var_names(self): | |||||
| return self._ipt_var_names | |||||
| @ipt_var_names.setter | |||||
| def ipt_var_names(self, var_names): | |||||
| self._ipt_var_names = var_names | |||||
| @property | |||||
| def opt_var_names(self): | |||||
| return self._opt_var_names | |||||
| @opt_var_names.setter | |||||
| def opt_var_names(self, var_names): | |||||
| self._opt_var_names = var_names | |||||
| @staticmethod | @staticmethod | ||||
| def get_opt_var_name(variable_name): | def get_opt_var_name(variable_name): | ||||
| """ | """ | ||||
| @@ -24,7 +24,7 @@ from .pytorch_graph_node import PyTorchGraphNode | |||||
| from .pytorch_graph_parser import PyTorchGraphParser | from .pytorch_graph_parser import PyTorchGraphParser | ||||
| from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, SCALAR_WITHOUT_SHAPE, \ | from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, SCALAR_WITHOUT_SHAPE, \ | ||||
| MIN_SCOPE_LENGTH | |||||
| MIN_SCOPE_LENGTH, SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT | |||||
| from ..constant import LEFT_BUCKET, RIGHT_BUCKET | from ..constant import LEFT_BUCKET, RIGHT_BUCKET | ||||
| NONE_SCOPE_OP = { | NONE_SCOPE_OP = { | ||||
| @@ -33,22 +33,32 @@ NONE_SCOPE_OP = { | |||||
| "onnx::Concat": "Concat", | "onnx::Concat": "Concat", | ||||
| "onnx::Squeeze": "Squeeze", | "onnx::Squeeze": "Squeeze", | ||||
| "onnx::Unsqueeze": "Unsqueeze", | "onnx::Unsqueeze": "Unsqueeze", | ||||
| "onnx::Split": "Split", | |||||
| "onnx::Reshape": "Reshape", | |||||
| "onnx::Transpose": "Transpose", | |||||
| "onnx::Constant": "Constant", | |||||
| "onnx::ReduceMean": "ReduceMean" | |||||
| } | } | ||||
| def normalize_scope_name(node): | |||||
| def normalize_scope_name(node, scope_name_dict): | |||||
| """ | """ | ||||
| Rename scope name into uniform. | Rename scope name into uniform. | ||||
| Args: | Args: | ||||
| node (Node): PyTorch node. | node (Node): PyTorch node. | ||||
| scope_name_dict (dict): Dictionary of scope names with the key node_id. | |||||
| Returns: | Returns: | ||||
| str, normalized scope name. | str, normalized scope name. | ||||
| """ | """ | ||||
| global NONE_SCOPE_OP | global NONE_SCOPE_OP | ||||
| name = node.scopeName().replace(SEPARATOR_BTW_NAME_AND_ID, '').split(SEPARATOR_IN_SCOPE) | |||||
| scope_name = node.scopeName() | |||||
| if not scope_name: | |||||
| name = [retrieve_scope_name(node, scope_name_dict)] | |||||
| else: | |||||
| name = scope_name.replace(SEPARATOR_BTW_NAME_AND_ID, '').split(SEPARATOR_IN_SCOPE) | |||||
| scopes = [] | scopes = [] | ||||
| for segment in name: | for segment in name: | ||||
| segment = segment.split(LINK_IN_SCOPE)[0] | segment = segment.split(LINK_IN_SCOPE)[0] | ||||
| @@ -64,7 +74,43 @@ def normalize_scope_name(node): | |||||
| if node.kind() in NONE_SCOPE_OP.keys(): | if node.kind() in NONE_SCOPE_OP.keys(): | ||||
| scopes.append(NONE_SCOPE_OP[node.kind()]) | scopes.append(NONE_SCOPE_OP[node.kind()]) | ||||
| scopes = [s for s in scopes if s] | scopes = [s for s in scopes if s] | ||||
| return f"{SEPARATOR_IN_SCOPE.join(scopes)}_{PyTorchGraph.get_node_id(node)}" | |||||
| node_id = PyTorchGraph.get_node_id(node) | |||||
| return f"{SEPARATOR_IN_SCOPE.join(scopes)}_{'&'.join(node_id)}" | |||||
| def retrieve_scope_name(node, scope_name_dict): | |||||
| """ | |||||
| Retrieve scope name from input nodes. | |||||
| Args: | |||||
| node (Node): PyTorch node. | |||||
| scope_name_dict (dict): Dictionary of scope names with the key node_id. | |||||
| Return: | |||||
| str: Scope name. | |||||
| """ | |||||
| node_content = \ | |||||
| SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT.join(str(node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[1:]) | |||||
| node_inputs = re.findall(r"[(](.*?)[)]", node_content)[0] | |||||
| node_inputs = re.sub(r"[\s%]", '', node_inputs).split(",") | |||||
| scope_name_ipt_nodes = list() | |||||
| for node_input in node_inputs: | |||||
| if not scope_name_dict.get(node_input, None): | |||||
| continue | |||||
| scope_name_ipt_nodes.append(scope_name_dict[node_input]) | |||||
| scope_name_split = list() | |||||
| for idx, _ in enumerate(scope_name_ipt_nodes): | |||||
| if not scope_name_split: | |||||
| scope_name_split = scope_name_ipt_nodes[idx] | |||||
| else: | |||||
| scope_name_split = [ | |||||
| sub_scope_name | |||||
| for sub_scope_name in scope_name_split if sub_scope_name in scope_name_ipt_nodes[idx] | |||||
| ] | |||||
| scope_name = SEPARATOR_IN_SCOPE.join(scope_name_split) | |||||
| return scope_name | |||||
| class PyTorchGraph(Graph): | class PyTorchGraph(Graph): | ||||
| @@ -179,8 +225,12 @@ class PyTorchGraph(Graph): | |||||
| graph = self._trace_torch_graph(feed_forward_ipt_shape) | graph = self._trace_torch_graph(feed_forward_ipt_shape) | ||||
| nodes = list(graph.nodes()) | nodes = list(graph.nodes()) | ||||
| scope_name_dict = dict() | |||||
| for node in nodes: | for node in nodes: | ||||
| node_name = normalize_scope_name(node) | |||||
| node_name = normalize_scope_name(node, scope_name_dict) | |||||
| scope_name_dict[node_name.split(SEPARATOR_BTW_NAME_AND_ID)[-1]] \ | |||||
| = list(node_name.split(SEPARATOR_BTW_NAME_AND_ID)[0].split(SEPARATOR_IN_SCOPE)) | |||||
| output_shape_str_list = re.findall(r'[^()!]+', str(node)) | output_shape_str_list = re.findall(r'[^()!]+', str(node)) | ||||
| output_shape_str = output_shape_str_list[1] | output_shape_str = output_shape_str_list[1] | ||||
| output_shape = self._extract_shape(output_shape_str) | output_shape = self._extract_shape(output_shape_str) | ||||
| @@ -204,7 +254,7 @@ class PyTorchGraph(Graph): | |||||
| if nd_id and nd_scope_name: | if nd_id and nd_scope_name: | ||||
| node_input_name = normalize_scope_name( | node_input_name = normalize_scope_name( | ||||
| node_input.node() | |||||
| node_input.node(), scope_name_dict | |||||
| ) | ) | ||||
| self.build_connection(node_input_name, node_name) | self.build_connection(node_input_name, node_name) | ||||
| @@ -259,12 +309,16 @@ class PyTorchGraph(Graph): | |||||
| return module_dict | return module_dict | ||||
| def _check_multi_ipt(self): | |||||
| def _check_multi_ipt_opt(self): | |||||
| """Check whether multi-input exists.""" | """Check whether multi-input exists.""" | ||||
| module_dict = self._generate_module() | module_dict = self._generate_module() | ||||
| for _, nodes_per_module in module_dict.items(): | for _, nodes_per_module in module_dict.items(): | ||||
| prcs_nodes_out_from_module = set() | prcs_nodes_out_from_module = set() | ||||
| for node_name in nodes_per_module: | for node_name in nodes_per_module: | ||||
| if re.search(r"[\d]+[&][\d]+", node_name): | |||||
| self._is_multi_opt_graph = True | |||||
| return True | |||||
| node = self._nodes_collection.get(node_name, None) | node = self._nodes_collection.get(node_name, None) | ||||
| if node: | if node: | ||||
| prcs_nodes = node.precursor_nodes | prcs_nodes = node.precursor_nodes | ||||
| @@ -284,11 +338,13 @@ class PyTorchGraph(Graph): | |||||
| def _unmerge_multi_ipt_opt_script(self): | def _unmerge_multi_ipt_opt_script(self): | ||||
| """Unmerge all submodule.""" | """Unmerge all submodule.""" | ||||
| if self._check_multi_ipt(): | |||||
| if self._check_multi_ipt_opt(): | |||||
| for node_key, node_inst in deepcopy(self._nodes_collection).items(): | for node_key, node_inst in deepcopy(self._nodes_collection).items(): | ||||
| prsc_nodes = node_inst.precursor_nodes | prsc_nodes = node_inst.precursor_nodes | ||||
| scsr_nodes = node_inst.successor_nodes | scsr_nodes = node_inst.successor_nodes | ||||
| node_inst.is_in_multi_opt_graph = self._is_multi_opt_graph | |||||
| node_inst.precursor_nodes = [SEPARATOR_IN_SCOPE.join((prsc_node.split(SEPARATOR_IN_SCOPE)[0], | node_inst.precursor_nodes = [SEPARATOR_IN_SCOPE.join((prsc_node.split(SEPARATOR_IN_SCOPE)[0], | ||||
| prsc_node.split(SEPARATOR_IN_SCOPE)[-1])) | prsc_node.split(SEPARATOR_IN_SCOPE)[-1])) | ||||
| for prsc_node in deepcopy(prsc_nodes)] | for prsc_node in deepcopy(prsc_nodes)] | ||||
| @@ -382,5 +438,6 @@ class PyTorchGraph(Graph): | |||||
| Returns: | Returns: | ||||
| str, node id. | str, node id. | ||||
| """ | """ | ||||
| node_id = re.search(r"[\d]+", str(node)) | |||||
| return node_id.group() | |||||
| node_title = str(node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[0] | |||||
| node_id = re.findall(r"[%](.*?) [:]", node_title) | |||||
| return node_id | |||||
| @@ -13,11 +13,13 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Define PyTorch graph node.""" | """Define PyTorch graph node.""" | ||||
| import re | |||||
| from .base import GraphNode | from .base import GraphNode | ||||
| from ..common.utils import is_converted | from ..common.utils import is_converted | ||||
| from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \ | from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \ | ||||
| SEPARATOR_IN_ONNX_OP | |||||
| SEPARATOR_IN_ONNX_OP, SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT | |||||
| class PyTorchGraphNode(GraphNode): | class PyTorchGraphNode(GraphNode): | ||||
| @@ -38,6 +40,19 @@ class PyTorchGraphNode(GraphNode): | |||||
| self._op_name = node.kind() if node else None | self._op_name = node.kind() if node else None | ||||
| self._scope_name = node.scopeName() if node else None | self._scope_name = node.scopeName() if node else None | ||||
| self._weight = weight | self._weight = weight | ||||
| self._ipt_var_names, self._opt_var_names \ | |||||
| = self._extract_ipt_opt_var_names() if node else (list(), list()) | |||||
| def _extract_ipt_opt_var_names(self): | |||||
| """Extract ipt and opt var names.""" | |||||
| node_content = SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT.join( | |||||
| str(self._src_node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[1:] | |||||
| ) | |||||
| node_inputs = re.findall(r"[(](.*?)[)]", node_content)[0] | |||||
| node_inputs = re.sub(r"[\s%]", '', node_inputs).split(",") | |||||
| node_title = str(self._src_node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[0] | |||||
| node_outputs = re.findall(r"[%](.*?) [:]", node_title) | |||||
| return node_inputs, node_outputs | |||||
| def clear_args_of_declaration(self): | def clear_args_of_declaration(self): | ||||
| """ | """ | ||||
| @@ -57,6 +72,14 @@ class PyTorchGraphNode(GraphNode): | |||||
| """ | """ | ||||
| return f"{arg}_{variable_name}" | return f"{arg}_{variable_name}" | ||||
| @property | |||||
| def is_in_multi_opt_graph(self): | |||||
| return self._is_in_multi_opt_graph | |||||
| @is_in_multi_opt_graph.setter | |||||
| def is_in_multi_opt_graph(self, multi_opt_state): | |||||
| self._is_in_multi_opt_graph = multi_opt_state | |||||
| @property | @property | ||||
| def hash_key(self): | def hash_key(self): | ||||
| """ | """ | ||||
| @@ -119,14 +142,14 @@ class PyTorchGraphNode(GraphNode): | |||||
| self._ipt_shape = input_shape | self._ipt_shape = input_shape | ||||
| self._opt_shape = output_shape | self._opt_shape = output_shape | ||||
| def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: str, code_fragment): | |||||
| def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: list, code_fragment): | |||||
| """ | """ | ||||
| Generate statements. | Generate statements. | ||||
| Args: | Args: | ||||
| variable_name (str): Variable name. | variable_name (str): Variable name. | ||||
| ipt_args_in_construct (str): Args of input. | ipt_args_in_construct (str): Args of input. | ||||
| output_var (str): Output variable name in construct. | |||||
| output_var (list): Output variable names in construct. | |||||
| code_fragment (CodeFragment): CodeFragment instance. | code_fragment (CodeFragment): CodeFragment instance. | ||||
| Returns: | Returns: | ||||
| @@ -157,7 +180,8 @@ class PyTorchGraphNode(GraphNode): | |||||
| operator = operator.replace(SEPARATOR_IN_ONNX_OP, ".") | operator = operator.replace(SEPARATOR_IN_ONNX_OP, ".") | ||||
| declare = f"self.{variable_name} = {operator}({expr})" | declare = f"self.{variable_name} = {operator}({expr})" | ||||
| call = f"{output_var} = self.{variable_name}({ipt_args_settings_in_construct})" | |||||
| call = f"{', '.join([output for output in output_var])}" \ | |||||
| f" = self.{variable_name}({ipt_args_settings_in_construct})" | |||||
| return declare, call | return declare, call | ||||