From: @moran3 Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -44,6 +44,7 @@ class Fragment(abc.ABC): | |||
| operation (str): Operation name in MindSpore. | |||
| actual_args (dict): Actual arg values. | |||
| settings (namedTuple): Code generation setting. | |||
| """ | |||
| def __init__(self, operation, actual_args, input_shape, output_shape, settings=None): | |||
| @@ -89,9 +90,9 @@ class Fragment(abc.ABC): | |||
| self._declared_variable_name = var | |||
| @property | |||
| def output_var_name(self) -> str: | |||
| def output_var_name(self) -> list: | |||
| """Getter of output variable name.""" | |||
| return ", ".join(self._output_var_name) | |||
| return self._output_var_name | |||
| @output_var_name.setter | |||
| def output_var_name(self, opt_vars): | |||
| @@ -100,6 +101,7 @@ class Fragment(abc.ABC): | |||
| Args: | |||
| opt_vars (list[str]): Output variable name. | |||
| """ | |||
| self._output_var_name = opt_vars | |||
| @@ -119,8 +121,9 @@ class Fragment(abc.ABC): | |||
| Args: | |||
| ipt (Fragment): Where input comes from. | |||
| """ | |||
| self._operation_inputs.append(ipt) | |||
| self._operation_inputs += ipt | |||
| @property | |||
| def operation(self): | |||
| @@ -139,6 +142,7 @@ class Fragment(abc.ABC): | |||
| Args: | |||
| op (str): Operation name. | |||
| """ | |||
| self._operation = op | |||
| @@ -158,6 +162,7 @@ class Fragment(abc.ABC): | |||
| Args: | |||
| formal_args (dict): To be updated args. | |||
| """ | |||
| return self._formal_args_list.update(formal_args) | |||
| @@ -194,6 +199,7 @@ class CodeFragment(Fragment): | |||
| operation (str): Operation name in MindSpore. | |||
| actual_args (dict): Actual arg values. | |||
| settings (namedTuple): Code generation setting. | |||
| """ | |||
| def __init__(self, operation, actual_args, settings, input_shape, output_shape, | |||
| @@ -18,6 +18,7 @@ from enum import Enum, unique | |||
| SEPARATOR_IN_ONNX_OP = "::" | |||
| SEPARATOR_IN_SCOPE = "/" | |||
| SEPARATOR_BTW_NAME_AND_ID = "_" | |||
| SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT = "=" | |||
| LINK_IN_SCOPE = "-" | |||
| LEFT_BUCKET = "[" | |||
| RIGHT_BUCKET = "]" | |||
| @@ -52,6 +53,10 @@ EXPECTED_NUMBER = 1 | |||
| MIN_SCOPE_LENGTH = 2 | |||
| NO_CONVERTED_OPERATORS = [ | |||
| "onnx::Constant" | |||
| ] | |||
| @unique | |||
| class CodeFormatConfig(Enum): | |||
| @@ -29,7 +29,7 @@ from ..common.utils import is_converted, save_code_file_and_report | |||
| from ..mapper.base import Mapper | |||
| from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode | |||
| from ..third_party_graph.onnx_graph_node import OnnxGraphNode | |||
| from ..constant import SEPARATOR_IN_SCOPE, get_imported_module | |||
| from ..constant import SEPARATOR_IN_SCOPE, get_imported_module, NO_CONVERTED_OPERATORS | |||
| from ..constant import CodeFormatConfig | |||
| from ..constant import SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT | |||
| from ..constant import NEW_LINE, SECOND_LEVEL_INDENT | |||
| @@ -472,6 +472,8 @@ class HierarchicalTree(Tree): | |||
| for idx, node_name in enumerate(node.successors(self.tree_identifier)): | |||
| nd_inst = self.get_node(node_name) | |||
| if nd_inst.data.op_name in NO_CONVERTED_OPERATORS: | |||
| continue | |||
| # Generate code statement. | |||
| init, construct = self._generate_stat(nd_inst, node, idx) | |||
| @@ -518,14 +520,25 @@ class HierarchicalTree(Tree): | |||
| """ | |||
| ipt_args_in_construct = "x" | |||
| opt_arg_in_construct = "output" | |||
| opt_arg_in_construct = ["output"] | |||
| if idx != 0: | |||
| # Get previous node output variable name. | |||
| ipt_args_in_construct = self._get_previous_opt_var(cur_nd_inst, pre_nd_inst) | |||
| if cur_nd_inst.data.is_in_multi_opt_graph: | |||
| ipt_args_in_construct = self._get_current_ipt_var(cur_nd_inst) | |||
| else: | |||
| # Get previous node output variable name. | |||
| ipt_args_in_construct = self._get_previous_opt_var(cur_nd_inst, pre_nd_inst) | |||
| if idx != len(pre_nd_inst.successors(self.tree_identifier)) - 1: | |||
| # Set opt variable name. | |||
| opt_arg_in_construct = f"{self.code_fragment_recorder[cur_nd_inst.identifier].declared_var_name}_opt" | |||
| if cur_nd_inst.data.node_type == NodeType.MODULE.value or not cur_nd_inst.data.is_in_multi_opt_graph: | |||
| opt_arg_in_construct = [ | |||
| f"{self.code_fragment_recorder[cur_nd_inst.identifier].declared_var_name}_opt" | |||
| ] | |||
| else: | |||
| opt_arg_in_construct = [ | |||
| f"opt_{var_name}" | |||
| for var_name in self.code_fragment_recorder[cur_nd_inst.identifier].output_var_name | |||
| ] | |||
| declare, call = cur_nd_inst.data.to_code(ipt_args_in_construct=ipt_args_in_construct, | |||
| variable_name=self.code_fragment_recorder[ | |||
| @@ -548,6 +561,39 @@ class HierarchicalTree(Tree): | |||
| """ | |||
| return s.split(SEPARATOR_IN_SCOPE)[-1].lower().split(SEPARATOR_BTW_NAME_AND_ID)[0] | |||
| def _get_current_ipt_var(self, cur_nd): | |||
| """" | |||
| Get current input variable name from node_id. | |||
| Args: | |||
| cur_nd (Node): Current node. | |||
| Returns: | |||
| str, needed var names. | |||
| """ | |||
| if cur_nd.data.node_type != NodeType.OPERATION.value: | |||
| while True: | |||
| p_nd = cur_nd.successors(self.tree_identifier) | |||
| if not p_nd: | |||
| break | |||
| cur_nd = self.get_node(p_nd[0]) | |||
| ipt_lst_raw = [] | |||
| for operation_input in self.code_fragment_recorder[cur_nd.identifier].operation_inputs: | |||
| ipt_lst_raw.append(f"{operation_input}") | |||
| opt_var_names_p_nds = set() | |||
| for e in cur_nd.data.precursor_nodes: | |||
| p_nd = self.get_node(e) | |||
| if p_nd.data.op_name in NO_CONVERTED_OPERATORS: | |||
| continue | |||
| opt_var_names_p_nd = set(p_nd.data.opt_var_names) | |||
| opt_var_names_p_nds = set.union(opt_var_names_p_nds, opt_var_names_p_nd) | |||
| ipt_lst = [f"opt_{ipt}" for ipt in set(ipt_lst_raw).intersection(opt_var_names_p_nds)] | |||
| return ", ".join(ipt_lst) | |||
| def _find_all_previous_opt_var_(self, cur_nd, pre_nd): | |||
| """ | |||
| Find all input variable names. | |||
| @@ -557,9 +603,12 @@ class HierarchicalTree(Tree): | |||
| pre_nd (Node): Precursor node. | |||
| Returns: | |||
| str, needed var names. | |||
| list, needed var names list. | |||
| """ | |||
| ipt_lst = [] | |||
| if cur_nd.tag in NO_CONVERTED_OPERATORS: | |||
| return ipt_lst | |||
| for e in cur_nd.data.precursor_nodes: | |||
| p_nd = self.get_node(e) | |||
| if e not in pre_nd.successors(self.tree_identifier): | |||
| @@ -575,7 +624,6 @@ class HierarchicalTree(Tree): | |||
| break | |||
| p_nd = self.get_node(pre_nd_name) | |||
| continue | |||
| ipt_lst.append( | |||
| f"{self.code_fragment_recorder[p_nd.identifier].declared_var_name}_opt" | |||
| ) | |||
| @@ -671,6 +719,9 @@ class HierarchicalTree(Tree): | |||
| # Sub-modules in the module could have arg name conflicts. | |||
| for idx, successor_name in enumerate(node.successors(self.tree_identifier)): | |||
| nd_inst = self.get_node(successor_name) | |||
| if nd_inst.data.op_name in NO_CONVERTED_OPERATORS: | |||
| continue | |||
| # Generation of params must behind variable assigment. | |||
| if created: | |||
| variable_name = self._module_vars[module_key][idx] | |||
| @@ -680,6 +731,8 @@ class HierarchicalTree(Tree): | |||
| code_fragment = nd_inst.data.param_transform(mapper, variable_name) | |||
| code_fragment.declared_var_name = variable_name | |||
| code_fragment.output_var_name = nd_inst.data.opt_var_names | |||
| code_fragment.update_operation_inputs(nd_inst.data.ipt_var_names) | |||
| self.code_fragment_recorder[nd_inst.identifier] = code_fragment | |||
| module_args.update(nd_inst.data.args_in_code) | |||
| @@ -34,9 +34,19 @@ class ReshapeMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| if kwargs.get("weights", None): | |||
| return ReshapeMapper._convert_settings_tf(**kwargs) | |||
| return ReshapeMapper._convert_settings_pytorch(**kwargs) | |||
| @staticmethod | |||
| def _convert_settings_pytorch(**kwargs): | |||
| params = kwargs.get("params") | |||
| shape = params.get("output_shape") | |||
| return Setting(op_extra_input={"input_shape": tuple(shape)}) | |||
| @staticmethod | |||
| def _convert_settings_tf(**kwargs): | |||
| weights = kwargs.get("weights") | |||
| if not weights: | |||
| return Setting() | |||
| if len(weights) > 1: | |||
| raise ValueError("For reshape, `weights` length should equal to 1.") | |||
| shape = [-1] | |||
| @@ -0,0 +1,41 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| from ...gen_setting import Setting | |||
| class SplitMapper(ONNXToMindSporeMapper): | |||
| """Split mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "P.Split" | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| axis = kwargs["params"]["axis"] | |||
| split = kwargs["params"]["split"] | |||
| output_num = len(split) | |||
| return {"axis": axis, | |||
| "output_num": output_num} | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return Setting() | |||
| @@ -18,5 +18,6 @@ | |||
| "onnx::Reshape": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.reshape_mapper.ReshapeMapper", | |||
| "onnx::Slice": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.slice_mapper.SliceMapper", | |||
| "onnx::Mul": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.mul_mapper.MulMapper", | |||
| "onnx::Sigmoid": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.sigmoid_mapper.SigmoidMapper" | |||
| "onnx::Sigmoid": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.sigmoid_mapper.SigmoidMapper", | |||
| "onnx::Split": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.split_mapper.SplitMapper" | |||
| } | |||
| @@ -105,6 +105,7 @@ class Graph(BaseGraph, abc.ABC): | |||
| self._output_nodes = [] | |||
| self._topological_order = [] | |||
| self._input_shape = dict() | |||
| self._is_multi_opt_graph = False | |||
| def get_input_shape(self, name): | |||
| """ | |||
| @@ -303,11 +304,33 @@ class GraphNode(abc.ABC): | |||
| self._opt_shape = None | |||
| # Weight of current op. | |||
| self._weight = None | |||
| # Input variable names. | |||
| self._ipt_var_names = list() | |||
| # Output variable names. | |||
| self._opt_var_names = list() | |||
| # Is in multi output graph. | |||
| self._is_in_multi_opt_graph = False | |||
| @property | |||
| def weight(self): | |||
| return self._weight | |||
| @property | |||
| def ipt_var_names(self): | |||
| return self._ipt_var_names | |||
| @ipt_var_names.setter | |||
| def ipt_var_names(self, var_names): | |||
| self._ipt_var_names = var_names | |||
| @property | |||
| def opt_var_names(self): | |||
| return self._opt_var_names | |||
| @opt_var_names.setter | |||
| def opt_var_names(self, var_names): | |||
| self._opt_var_names = var_names | |||
| @staticmethod | |||
| def get_opt_var_name(variable_name): | |||
| """ | |||
| @@ -24,7 +24,7 @@ from .pytorch_graph_node import PyTorchGraphNode | |||
| from .pytorch_graph_parser import PyTorchGraphParser | |||
| from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, SCALAR_WITHOUT_SHAPE, \ | |||
| MIN_SCOPE_LENGTH | |||
| MIN_SCOPE_LENGTH, SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT | |||
| from ..constant import LEFT_BUCKET, RIGHT_BUCKET | |||
| NONE_SCOPE_OP = { | |||
| @@ -33,22 +33,32 @@ NONE_SCOPE_OP = { | |||
| "onnx::Concat": "Concat", | |||
| "onnx::Squeeze": "Squeeze", | |||
| "onnx::Unsqueeze": "Unsqueeze", | |||
| "onnx::Split": "Split", | |||
| "onnx::Reshape": "Reshape", | |||
| "onnx::Transpose": "Transpose", | |||
| "onnx::Constant": "Constant", | |||
| "onnx::ReduceMean": "ReduceMean" | |||
| } | |||
| def normalize_scope_name(node): | |||
| def normalize_scope_name(node, scope_name_dict): | |||
| """ | |||
| Rename scope name into uniform. | |||
| Args: | |||
| node (Node): PyTorch node. | |||
| scope_name_dict (dict): Dictionary of scope names with the key node_id. | |||
| Returns: | |||
| str, normalized scope name. | |||
| """ | |||
| global NONE_SCOPE_OP | |||
| name = node.scopeName().replace(SEPARATOR_BTW_NAME_AND_ID, '').split(SEPARATOR_IN_SCOPE) | |||
| scope_name = node.scopeName() | |||
| if not scope_name: | |||
| name = [retrieve_scope_name(node, scope_name_dict)] | |||
| else: | |||
| name = scope_name.replace(SEPARATOR_BTW_NAME_AND_ID, '').split(SEPARATOR_IN_SCOPE) | |||
| scopes = [] | |||
| for segment in name: | |||
| segment = segment.split(LINK_IN_SCOPE)[0] | |||
| @@ -64,7 +74,43 @@ def normalize_scope_name(node): | |||
| if node.kind() in NONE_SCOPE_OP.keys(): | |||
| scopes.append(NONE_SCOPE_OP[node.kind()]) | |||
| scopes = [s for s in scopes if s] | |||
| return f"{SEPARATOR_IN_SCOPE.join(scopes)}_{PyTorchGraph.get_node_id(node)}" | |||
| node_id = PyTorchGraph.get_node_id(node) | |||
| return f"{SEPARATOR_IN_SCOPE.join(scopes)}_{'&'.join(node_id)}" | |||
| def retrieve_scope_name(node, scope_name_dict): | |||
| """ | |||
| Retrieve scope name from input nodes. | |||
| Args: | |||
| node (Node): PyTorch node. | |||
| scope_name_dict (dict): Dictionary of scope names with the key node_id. | |||
| Return: | |||
| str: Scope name. | |||
| """ | |||
| node_content = \ | |||
| SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT.join(str(node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[1:]) | |||
| node_inputs = re.findall(r"[(](.*?)[)]", node_content)[0] | |||
| node_inputs = re.sub(r"[\s%]", '', node_inputs).split(",") | |||
| scope_name_ipt_nodes = list() | |||
| for node_input in node_inputs: | |||
| if not scope_name_dict.get(node_input, None): | |||
| continue | |||
| scope_name_ipt_nodes.append(scope_name_dict[node_input]) | |||
| scope_name_split = list() | |||
| for idx, _ in enumerate(scope_name_ipt_nodes): | |||
| if not scope_name_split: | |||
| scope_name_split = scope_name_ipt_nodes[idx] | |||
| else: | |||
| scope_name_split = [ | |||
| sub_scope_name | |||
| for sub_scope_name in scope_name_split if sub_scope_name in scope_name_ipt_nodes[idx] | |||
| ] | |||
| scope_name = SEPARATOR_IN_SCOPE.join(scope_name_split) | |||
| return scope_name | |||
| class PyTorchGraph(Graph): | |||
| @@ -179,8 +225,12 @@ class PyTorchGraph(Graph): | |||
| graph = self._trace_torch_graph(feed_forward_ipt_shape) | |||
| nodes = list(graph.nodes()) | |||
| scope_name_dict = dict() | |||
| for node in nodes: | |||
| node_name = normalize_scope_name(node) | |||
| node_name = normalize_scope_name(node, scope_name_dict) | |||
| scope_name_dict[node_name.split(SEPARATOR_BTW_NAME_AND_ID)[-1]] \ | |||
| = list(node_name.split(SEPARATOR_BTW_NAME_AND_ID)[0].split(SEPARATOR_IN_SCOPE)) | |||
| output_shape_str_list = re.findall(r'[^()!]+', str(node)) | |||
| output_shape_str = output_shape_str_list[1] | |||
| output_shape = self._extract_shape(output_shape_str) | |||
| @@ -204,7 +254,7 @@ class PyTorchGraph(Graph): | |||
| if nd_id and nd_scope_name: | |||
| node_input_name = normalize_scope_name( | |||
| node_input.node() | |||
| node_input.node(), scope_name_dict | |||
| ) | |||
| self.build_connection(node_input_name, node_name) | |||
| @@ -259,12 +309,16 @@ class PyTorchGraph(Graph): | |||
| return module_dict | |||
| def _check_multi_ipt(self): | |||
| def _check_multi_ipt_opt(self): | |||
| """Check whether multi-input exists.""" | |||
| module_dict = self._generate_module() | |||
| for _, nodes_per_module in module_dict.items(): | |||
| prcs_nodes_out_from_module = set() | |||
| for node_name in nodes_per_module: | |||
| if re.search(r"[\d]+[&][\d]+", node_name): | |||
| self._is_multi_opt_graph = True | |||
| return True | |||
| node = self._nodes_collection.get(node_name, None) | |||
| if node: | |||
| prcs_nodes = node.precursor_nodes | |||
| @@ -284,11 +338,13 @@ class PyTorchGraph(Graph): | |||
| def _unmerge_multi_ipt_opt_script(self): | |||
| """Unmerge all submodule.""" | |||
| if self._check_multi_ipt(): | |||
| if self._check_multi_ipt_opt(): | |||
| for node_key, node_inst in deepcopy(self._nodes_collection).items(): | |||
| prsc_nodes = node_inst.precursor_nodes | |||
| scsr_nodes = node_inst.successor_nodes | |||
| node_inst.is_in_multi_opt_graph = self._is_multi_opt_graph | |||
| node_inst.precursor_nodes = [SEPARATOR_IN_SCOPE.join((prsc_node.split(SEPARATOR_IN_SCOPE)[0], | |||
| prsc_node.split(SEPARATOR_IN_SCOPE)[-1])) | |||
| for prsc_node in deepcopy(prsc_nodes)] | |||
| @@ -382,5 +438,6 @@ class PyTorchGraph(Graph): | |||
| Returns: | |||
| str, node id. | |||
| """ | |||
| node_id = re.search(r"[\d]+", str(node)) | |||
| return node_id.group() | |||
| node_title = str(node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[0] | |||
| node_id = re.findall(r"[%](.*?) [:]", node_title) | |||
| return node_id | |||
| @@ -13,11 +13,13 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Define PyTorch graph node.""" | |||
| import re | |||
| from .base import GraphNode | |||
| from ..common.utils import is_converted | |||
| from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \ | |||
| SEPARATOR_IN_ONNX_OP | |||
| SEPARATOR_IN_ONNX_OP, SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT | |||
| class PyTorchGraphNode(GraphNode): | |||
| @@ -38,6 +40,19 @@ class PyTorchGraphNode(GraphNode): | |||
| self._op_name = node.kind() if node else None | |||
| self._scope_name = node.scopeName() if node else None | |||
| self._weight = weight | |||
| self._ipt_var_names, self._opt_var_names \ | |||
| = self._extract_ipt_opt_var_names() if node else (list(), list()) | |||
| def _extract_ipt_opt_var_names(self): | |||
| """Extract ipt and opt var names.""" | |||
| node_content = SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT.join( | |||
| str(self._src_node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[1:] | |||
| ) | |||
| node_inputs = re.findall(r"[(](.*?)[)]", node_content)[0] | |||
| node_inputs = re.sub(r"[\s%]", '', node_inputs).split(",") | |||
| node_title = str(self._src_node).split(SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT)[0] | |||
| node_outputs = re.findall(r"[%](.*?) [:]", node_title) | |||
| return node_inputs, node_outputs | |||
| def clear_args_of_declaration(self): | |||
| """ | |||
| @@ -57,6 +72,14 @@ class PyTorchGraphNode(GraphNode): | |||
| """ | |||
| return f"{arg}_{variable_name}" | |||
| @property | |||
| def is_in_multi_opt_graph(self): | |||
| return self._is_in_multi_opt_graph | |||
| @is_in_multi_opt_graph.setter | |||
| def is_in_multi_opt_graph(self, multi_opt_state): | |||
| self._is_in_multi_opt_graph = multi_opt_state | |||
| @property | |||
| def hash_key(self): | |||
| """ | |||
| @@ -119,14 +142,14 @@ class PyTorchGraphNode(GraphNode): | |||
| self._ipt_shape = input_shape | |||
| self._opt_shape = output_shape | |||
| def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: str, code_fragment): | |||
| def to_code(self, ipt_args_in_construct: str, variable_name: str, output_var: list, code_fragment): | |||
| """ | |||
| Generate statements. | |||
| Args: | |||
| variable_name (str): Variable name. | |||
| ipt_args_in_construct (str): Args of input. | |||
| output_var (str): Output variable name in construct. | |||
| output_var (list): Output variable names in construct. | |||
| code_fragment (CodeFragment): CodeFragment instance. | |||
| Returns: | |||
| @@ -157,7 +180,8 @@ class PyTorchGraphNode(GraphNode): | |||
| operator = operator.replace(SEPARATOR_IN_ONNX_OP, ".") | |||
| declare = f"self.{variable_name} = {operator}({expr})" | |||
| call = f"{output_var} = self.{variable_name}({ipt_args_settings_in_construct})" | |||
| call = f"{', '.join([output for output in output_var])}" \ | |||
| f" = self.{variable_name}({ipt_args_settings_in_construct})" | |||
| return declare, call | |||