| @@ -26,6 +26,11 @@ class Singleton(type): | |||||
| cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) | cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) | ||||
| return cls._instances[cls] | return cls._instances[cls] | ||||
| @classmethod | |||||
| def release(mcs): | |||||
| """Clear singleton object.""" | |||||
| mcs._instances.clear() | |||||
| class GlobalContext(metaclass=Singleton): | class GlobalContext(metaclass=Singleton): | ||||
| """ | """ | ||||
| @@ -110,7 +115,7 @@ class GlobalContext(metaclass=Singleton): | |||||
| if isinstance(arg, OrderedDict): | if isinstance(arg, OrderedDict): | ||||
| self._onnx_nodes_collection = arg # arg must be nodes_dict in OnnxDataLoader | self._onnx_nodes_collection = arg # arg must be nodes_dict in OnnxDataLoader | ||||
| else: | else: | ||||
| raise TypeError("GlobalContext received an unsupport variable to assign to onnx_nodes_collection.") | |||||
| raise TypeError("GlobalContext received an unsupported variable to assign to onnx_nodes_collection.") | |||||
| @property | @property | ||||
| def onnx_nodes_topo_index(self) -> dict: | def onnx_nodes_topo_index(self) -> dict: | ||||
| @@ -149,7 +154,7 @@ class GlobalContext(metaclass=Singleton): | |||||
| if isinstance(arg, dict): | if isinstance(arg, dict): | ||||
| self._onnx_tensors_collection = arg # arg must be tensors_dict in OnnxDataLoader | self._onnx_tensors_collection = arg # arg must be tensors_dict in OnnxDataLoader | ||||
| else: | else: | ||||
| raise TypeError("GlobalContext received an unsupport variable to assign to onnx_tensors_collection.") | |||||
| raise TypeError("GlobalContext received an unsupported variable to assign to onnx_tensors_collection.") | |||||
| @property | @property | ||||
| def latest_node_struct_count(self): | def latest_node_struct_count(self): | ||||
| @@ -237,3 +242,8 @@ class GlobalContext(metaclass=Singleton): | |||||
| self.module_structs[pattern_id] = [module_struct] | self.module_structs[pattern_id] = [module_struct] | ||||
| else: | else: | ||||
| self.module_structs[pattern_id].append(module_struct) | self.module_structs[pattern_id].append(module_struct) | ||||
| @classmethod | |||||
| def release(cls): | |||||
| """Clear singleton object.""" | |||||
| Singleton.release() | |||||
| @@ -0,0 +1,37 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """Define code format configuration.""" | |||||
| from yapf.yapflib.style import CreatePEP8Style | |||||
| def mindspore_yapf_config(): | |||||
| """Create the PEP8 formatting style.""" | |||||
| style = CreatePEP8Style() | |||||
| style['ALLOW_SPLIT_BEFORE_DEFAULT_OR_NAMED_ASSIGNS'] = False | |||||
| style['ALLOW_MULTILINE_LAMBDAS'] = True | |||||
| style['ALLOW_SPLIT_BEFORE_DICT_VALUE'] = False | |||||
| style['COLUMN_LIMIT'] = 120 | |||||
| style['COALESCE_BRACKETS'] = True | |||||
| style['FORCE_MULTILINE_DICT'] = True | |||||
| style['DISABLE_ENDING_COMMA_HEURISTIC'] = True | |||||
| style['INDENT_DICTIONARY_VALUE'] = True | |||||
| style['JOIN_MULTIPLE_LINES'] = False | |||||
| style['SPACES_BEFORE_COMMENT'] = 2 | |||||
| style['SPLIT_PENALTY_AFTER_OPENING_BRACKET'] = 30 | |||||
| style['SPLIT_PENALTY_BEFORE_IF_EXPR'] = 30 | |||||
| style['SPLIT_PENALTY_FOR_ADDED_LINE_SPLIT'] = 30 | |||||
| style['SPLIT_BEFORE_LOGICAL_OPERATOR'] = False | |||||
| style['SPLIT_BEFORE_BITWISE_OPERATOR'] = False | |||||
| return style | |||||
| @@ -104,11 +104,6 @@ NO_CONVERTED_OPERATORS = [ | |||||
| ] | ] | ||||
| @unique | |||||
| class CodeFormatConfig(Enum): | |||||
| PEP8 = "pep8" | |||||
| @unique | @unique | ||||
| class NodeType(Enum): | class NodeType(Enum): | ||||
| MODULE = "module" | MODULE = "module" | ||||
| @@ -20,6 +20,7 @@ from importlib import import_module | |||||
| from importlib.util import find_spec | from importlib.util import find_spec | ||||
| import mindinsight | import mindinsight | ||||
| from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext | |||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, \ | from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, \ | ||||
| save_code_file_and_report, get_framework_type | save_code_file_and_report, get_framework_type | ||||
| from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \ | from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \ | ||||
| @@ -199,13 +200,14 @@ def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, | |||||
| output_folder (str): Output folder. | output_folder (str): Output folder. | ||||
| report_folder (str): Report output folder path. | report_folder (str): Report output folder path. | ||||
| """ | """ | ||||
| graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape, | graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape, | ||||
| input_nodes=input_nodes, output_nodes=output_nodes) | input_nodes=input_nodes, output_nodes=output_nodes) | ||||
| generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper) | generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper) | ||||
| model_name = _extract_model_name(graph_path) | model_name = _extract_model_name(graph_path) | ||||
| code_fragments = generator_inst.generate() | code_fragments = generator_inst.generate() | ||||
| save_code_file_and_report(model_name, code_fragments, output_folder, report_folder) | save_code_file_and_report(model_name, code_fragments, output_folder, report_folder) | ||||
| # Release global context. | |||||
| GlobalContext.release() | |||||
| @tf_installation_validation | @tf_installation_validation | ||||
| @@ -238,6 +240,8 @@ def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple, | |||||
| model_name = _extract_model_name(graph_path) | model_name = _extract_model_name(graph_path) | ||||
| code_fragments = generator_inst.generate() | code_fragments = generator_inst.generate() | ||||
| save_code_file_and_report(model_name, code_fragments, output_folder, report_folder) | save_code_file_and_report(model_name, code_fragments, output_folder, report_folder) | ||||
| # Release global context. | |||||
| GlobalContext.release() | |||||
| @BaseConverterError.uniform_catcher() | @BaseConverterError.uniform_catcher() | ||||
| @@ -18,16 +18,18 @@ from collections import OrderedDict | |||||
| from yapf.yapflib.yapf_api import FormatCode | from yapf.yapflib.yapf_api import FormatCode | ||||
| from .scope_utils import Scope | |||||
| from .node_struct import NodeStruct | |||||
| from .module_struct import ModuleStruct | |||||
| from .args_translator import ArgsTranslationHelper | |||||
| from ..common.global_context import GlobalContext | |||||
| from ..common.outputs import BaseOutput, ModuleOutputManager | |||||
| from ...common.exceptions import GeneratorError | |||||
| from ..common.name_mgr import GlobalVarNameMgr | |||||
| from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, CodeFormatConfig, get_imported_module | |||||
| from ..report_generator import ReportGenerator | |||||
| from mindinsight.mindconverter.common.exceptions import GeneratorError | |||||
| from mindinsight.mindconverter.graph_based_converter.generator.scope_utils import Scope | |||||
| from mindinsight.mindconverter.graph_based_converter.generator.node_struct import NodeStruct | |||||
| from mindinsight.mindconverter.graph_based_converter.generator.module_struct import ModuleStruct | |||||
| from mindinsight.mindconverter.graph_based_converter.generator.args_translator import ArgsTranslationHelper | |||||
| from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext | |||||
| from mindinsight.mindconverter.graph_based_converter.common.outputs import BaseOutput, ModuleOutputManager | |||||
| from mindinsight.mindconverter.graph_based_converter.common.yapf_config import mindspore_yapf_config | |||||
| from mindinsight.mindconverter.graph_based_converter.common.name_mgr import GlobalVarNameMgr | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, \ | |||||
| get_imported_module | |||||
| from mindinsight.mindconverter.graph_based_converter.report_generator import ReportGenerator | |||||
| class CodeStruct: | class CodeStruct: | ||||
| @@ -35,7 +37,6 @@ class CodeStruct: | |||||
| Define the Code template for each module generated in the final output. | Define the Code template for each module generated in the final output. | ||||
| Each module has only one CodeStruct to its pattern. | Each module has only one CodeStruct to its pattern. | ||||
| """ | """ | ||||
| GLOBAL_CONTEXT = GlobalContext() | |||||
| NOT_IN_SCOPE_OPT = dict() | NOT_IN_SCOPE_OPT = dict() | ||||
| def __init__(self, struct, repeated_submodules=None): | def __init__(self, struct, repeated_submodules=None): | ||||
| @@ -104,7 +105,7 @@ class CodeStruct: | |||||
| elif isinstance(struct, ModuleStruct): | elif isinstance(struct, ModuleStruct): | ||||
| # check if this instance generated CodeStruct | # check if this instance generated CodeStruct | ||||
| if self.GLOBAL_CONTEXT.code_structs.get(struct.pattern_id) is None: | |||||
| if GlobalContext().code_structs.get(struct.pattern_id) is None: | |||||
| CodeStruct(struct, repeated_submodules) | CodeStruct(struct, repeated_submodules) | ||||
| code_line_init = struct.code_line_in_init() | code_line_init = struct.code_line_in_init() | ||||
| @@ -138,10 +139,10 @@ class CodeStruct: | |||||
| returns = list(set(returns)) | returns = list(set(returns)) | ||||
| else: | else: | ||||
| returns = [code_line_construct[0]] if isinstance(code_line_construct, tuple) \ | returns = [code_line_construct[0]] if isinstance(code_line_construct, tuple) \ | ||||
| else [code_line_construct[-1].replace(' ', '').split('=')[0]] | |||||
| else [code_line_construct[-1].replace(' ', '').split('=')[0]] | |||||
| self.new_line = f"{SECOND_LEVEL_INDENT}return {', '.join(returns)}" | self.new_line = f"{SECOND_LEVEL_INDENT}return {', '.join(returns)}" | ||||
| self.new_line = f"{NEW_LINE * 2}" | self.new_line = f"{NEW_LINE * 2}" | ||||
| self.GLOBAL_CONTEXT.code_structs[md_struct.pattern_id] = self | |||||
| GlobalContext().code_structs[md_struct.pattern_id] = self | |||||
| class Generator: | class Generator: | ||||
| @@ -482,7 +483,7 @@ class Generator: | |||||
| outputs.append(line) | outputs.append(line) | ||||
| formatted_code, _ = FormatCode("\n".join(outputs), | formatted_code, _ = FormatCode("\n".join(outputs), | ||||
| style_config=CodeFormatConfig.PEP8.value) | |||||
| style_config=mindspore_yapf_config()) | |||||
| report_generator = ReportGenerator() | report_generator = ReportGenerator() | ||||
| report = report_generator.gen_report(formatted_code) | report = report_generator.gen_report(formatted_code) | ||||
| @@ -589,7 +590,7 @@ class Generator: | |||||
| output_obj.idx_in_ms_user[nd_struct.identifier] = idx | output_obj.idx_in_ms_user[nd_struct.identifier] = idx | ||||
| # set this output to be returned to external | # set this output to be returned to external | ||||
| output_obj.to_external = not(nd_struct.check_target_node_internal( | |||||
| output_obj.to_external = not (nd_struct.check_target_node_internal( | |||||
| self._global_context.outputs_storage.onnx_name(inp) | self._global_context.outputs_storage.onnx_name(inp) | ||||
| )) | )) | ||||
| @@ -31,11 +31,10 @@ class ModuleStruct: | |||||
| Define a module struct which stores all info. to generate statement. | Define a module struct which stores all info. to generate statement. | ||||
| Args: | Args: | ||||
| args (list): A list of node structs. | |||||
| nd_struct_list (list): A list of node structs. | |||||
| init_as_parent (bool): Control init method if the ModuleStruct be init as a parent module struct. | init_as_parent (bool): Control init method if the ModuleStruct be init as a parent module struct. | ||||
| parent_base (ModuleStruct): The base ModuleStruct the current ModuleStruct to be init as. | parent_base (ModuleStruct): The base ModuleStruct the current ModuleStruct to be init as. | ||||
| """ | """ | ||||
| GLOBAL_CONTEXT_MGR = GlobalContext() | |||||
| def __init__(self, nd_struct_list, init_as_parent=False, parent_base=None): | def __init__(self, nd_struct_list, init_as_parent=False, parent_base=None): | ||||
| """Init. a module by NodeStructs.""" | """Init. a module by NodeStructs.""" | ||||
| @@ -247,7 +246,7 @@ class ModuleStruct: | |||||
| self._module_structs += md_structs | self._module_structs += md_structs | ||||
| tail_md = md_structs[-1] | tail_md = md_structs[-1] | ||||
| else: | else: | ||||
| raise TypeError("ModuleStruct cannot add an unsupport Type {} to module_structs list.".format( | |||||
| raise TypeError("ModuleStruct cannot add an unsupported Type {} to module_structs list.".format( | |||||
| type(md_structs))) | type(md_structs))) | ||||
| # update tail node and index | # update tail node and index | ||||
| if self.tail_nd_struct_index < tail_md.tail_nd_struct_index: | if self.tail_nd_struct_index < tail_md.tail_nd_struct_index: | ||||
| @@ -318,12 +317,7 @@ class ModuleStruct: | |||||
| return ret | return ret | ||||
| def code_line_in_init(self): | def code_line_in_init(self): | ||||
| """ | |||||
| Initialization line of code in module init block. | |||||
| Args: | |||||
| override_formal_val (dict): Indicate which args should be renamed for passing value from upper level. | |||||
| """ | |||||
| """Initialization line of code in module init block.""" | |||||
| left = "self.{}".format(self.ms_var_name) | left = "self.{}".format(self.ms_var_name) | ||||
| args_list = list() | args_list = list() | ||||
| # Load args in init statement. | # Load args in init statement. | ||||
| @@ -338,7 +332,7 @@ class ModuleStruct: | |||||
| else: | else: | ||||
| args_list += self._fragment.actual_args | args_list += self._fragment.actual_args | ||||
| right = f"{self.class_name}({', '.join(args_list)})" | right = f"{self.class_name}({', '.join(args_list)})" | ||||
| return (left, right) | |||||
| return left, right | |||||
| def code_line_in_construct(self, inputs=None): | def code_line_in_construct(self, inputs=None): | ||||
| """Construct line of code in module construct block.""" | """Construct line of code in module construct block.""" | ||||
| @@ -356,7 +350,7 @@ class ModuleStruct: | |||||
| if isinstance(inputs, str): | if isinstance(inputs, str): | ||||
| inputs = [inputs] | inputs = [inputs] | ||||
| right = f"self.{self.ms_var_name}({', '.join(inputs)})" | right = f"self.{self.ms_var_name}({', '.join(inputs)})" | ||||
| return (left, right) | |||||
| return left, right | |||||
| @property | @property | ||||
| def node_structs(self): | def node_structs(self): | ||||
| @@ -463,8 +457,8 @@ class ModuleStruct: | |||||
| """Return the class name for generating code of this module.""" | """Return the class name for generating code of this module.""" | ||||
| if self.pattern_id == -1: | if self.pattern_id == -1: | ||||
| return "Model" | return "Model" | ||||
| if self.GLOBAL_CONTEXT_MGR.known_module_name.get("Module{}".format(self.pattern_id)) is not None: | |||||
| class_name = self.GLOBAL_CONTEXT_MGR.known_module_name.get("Module{}".format(self.pattern_id)) | |||||
| if GlobalContext().known_module_name.get("Module{}".format(self.pattern_id)) is not None: | |||||
| class_name = GlobalContext().known_module_name.get("Module{}".format(self.pattern_id)) | |||||
| else: | else: | ||||
| class_name = "Module{}".format(self.pattern_id) | class_name = "Module{}".format(self.pattern_id) | ||||
| return class_name | return class_name | ||||
| @@ -23,6 +23,7 @@ from ..third_party_graph.onnx_graph_node import OnnxGraphNode | |||||
| from ..common.global_context import GlobalContext | from ..common.global_context import GlobalContext | ||||
| from ...common.exceptions import GeneratorError | from ...common.exceptions import GeneratorError | ||||
| class NodeStruct: | class NodeStruct: | ||||
| """ | """ | ||||
| Define a node struct which stores all info. to generate statement. | Define a node struct which stores all info. to generate statement. | ||||
| @@ -34,10 +35,10 @@ class NodeStruct: | |||||
| You can pass as many args as possible and the Node Struct will update | You can pass as many args as possible and the Node Struct will update | ||||
| by arguments order. | by arguments order. | ||||
| """ | """ | ||||
| GLOBAL_CONTEXT_MGR = GlobalContext() | |||||
| def __init__(self, args): | def __init__(self, args): | ||||
| # define attributes here | # define attributes here | ||||
| self.global_context_mgr = GlobalContext() | |||||
| self._identifier = None | self._identifier = None | ||||
| self._fragment = None | self._fragment = None | ||||
| self._args_translator = None | self._args_translator = None | ||||
| @@ -74,7 +75,7 @@ class NodeStruct: | |||||
| """Get the original topological index in the onnx graph.""" | """Get the original topological index in the onnx graph.""" | ||||
| ori_name = self._fragment.metadata.get('source') | ori_name = self._fragment.metadata.get('source') | ||||
| self.onnx_name = ori_name | self.onnx_name = ori_name | ||||
| return self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_topo_idx.get(ori_name) | |||||
| return GlobalContext().onnx_node_name_to_topo_idx.get(ori_name) | |||||
| def update_var_name(self, idx=None): | def update_var_name(self, idx=None): | ||||
| """ | """ | ||||
| @@ -83,6 +84,7 @@ class NodeStruct: | |||||
| Args: | Args: | ||||
| idx (int): The index of the node in this module. | idx (int): The index of the node in this module. | ||||
| """ | """ | ||||
| def _remove_op_header(op_name): | def _remove_op_header(op_name): | ||||
| """Remove op header which indicating their sources of op set.""" | """Remove op header which indicating their sources of op set.""" | ||||
| op_name = op_name.replace('nn.', '') | op_name = op_name.replace('nn.', '') | ||||
| @@ -112,7 +114,7 @@ class NodeStruct: | |||||
| self._fragment = FragmentHandler(frag) | self._fragment = FragmentHandler(frag) | ||||
| if self.ms_op: | if self.ms_op: | ||||
| idx = self.GLOBAL_CONTEXT_MGR.latest_node_struct_count | |||||
| idx = GlobalContext().latest_node_struct_count | |||||
| self.update_var_name(idx=idx) | self.update_var_name(idx=idx) | ||||
| def _set_scope_from_identifier(self): | def _set_scope_from_identifier(self): | ||||
| @@ -142,9 +144,7 @@ class NodeStruct: | |||||
| Args: | Args: | ||||
| arg (Union[PyTorchGraphNode, OnnxGraphNode, dict]): Node related obj. | arg (Union[PyTorchGraphNode, OnnxGraphNode, dict]): Node related obj. | ||||
| force_ready (bool): Force this NodeStruct is ready to generate. | |||||
| """ | """ | ||||
| if isinstance(arg, OnnxGraphNode): | if isinstance(arg, OnnxGraphNode): | ||||
| self._update_from_onnx_gn(arg) | self._update_from_onnx_gn(arg) | ||||
| elif isinstance(arg, NewFragment): | elif isinstance(arg, NewFragment): | ||||
| @@ -168,7 +168,7 @@ class NodeStruct: | |||||
| self._identifier = s | self._identifier = s | ||||
| self._set_scope_from_identifier() | self._set_scope_from_identifier() | ||||
| self.topo_idx = self.ori_topo_idx() | self.topo_idx = self.ori_topo_idx() | ||||
| self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map[self.onnx_name] = self | |||||
| GlobalContext().onnx_node_name_to_node_struct_map[self.onnx_name] = self | |||||
| @property | @property | ||||
| def fragment(self): | def fragment(self): | ||||
| @@ -181,7 +181,7 @@ class NodeStruct: | |||||
| Set the Node fragment. | Set the Node fragment. | ||||
| Args: | Args: | ||||
| s (NodeFragment): The node identifier string. | |||||
| frag (NodeFragment): The node identifier string. | |||||
| """ | """ | ||||
| self._fragment = frag | self._fragment = frag | ||||
| @@ -198,7 +198,7 @@ class NodeStruct: | |||||
| @property | @property | ||||
| def onnx_node(self): | def onnx_node(self): | ||||
| """Return the original onnx node reference.""" | """Return the original onnx node reference.""" | ||||
| return self.GLOBAL_CONTEXT_MGR.onnx_nodes_collection.get(self.onnx_name) | |||||
| return GlobalContext().onnx_nodes_collection.get(self.onnx_name) | |||||
| @property | @property | ||||
| def ms_op(self): | def ms_op(self): | ||||
| @@ -241,7 +241,7 @@ class NodeStruct: | |||||
| ret = [] | ret = [] | ||||
| precursor_nodes_names = self.precursor_nodes_names | precursor_nodes_names = self.precursor_nodes_names | ||||
| for pre_node_name in precursor_nodes_names: | for pre_node_name in precursor_nodes_names: | ||||
| nd_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(pre_node_name) | |||||
| nd_struct = GlobalContext().onnx_node_name_to_node_struct_map.get(pre_node_name) | |||||
| ret.append(nd_struct) | ret.append(nd_struct) | ||||
| return ret | return ret | ||||
| @@ -255,7 +255,7 @@ class NodeStruct: | |||||
| """Return the node struct instances of successor nodes.""" | """Return the node struct instances of successor nodes.""" | ||||
| ret = [] | ret = [] | ||||
| for pre_node_name in self.successor_nodes_names: | for pre_node_name in self.successor_nodes_names: | ||||
| nd_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(pre_node_name) | |||||
| nd_struct = GlobalContext().onnx_node_name_to_node_struct_map.get(pre_node_name) | |||||
| ret.append(nd_struct) | ret.append(nd_struct) | ||||
| return ret | return ret | ||||
| @@ -312,11 +312,11 @@ class NodeStruct: | |||||
| inputs = self.matched_inputs | inputs = self.matched_inputs | ||||
| # Check original onnx node's input to ensure double inputs are not ignored | # Check original onnx node's input to ensure double inputs are not ignored | ||||
| original_inputs = self.GLOBAL_CONTEXT_MGR.onnx_node_inputs.get(self.onnx_name) | |||||
| original_inputs = GlobalContext().onnx_node_inputs.get(self.onnx_name) | |||||
| new_inputs = [] | new_inputs = [] | ||||
| for idx, prec_node in enumerate(self.precursor_nodes_names): | for idx, prec_node in enumerate(self.precursor_nodes_names): | ||||
| occurence = original_inputs.count(prec_node) | |||||
| for _ in range(occurence): | |||||
| occurrence = original_inputs.count(prec_node) | |||||
| for _ in range(occurrence): | |||||
| new_inputs.append(inputs[idx]) | new_inputs.append(inputs[idx]) | ||||
| inputs = new_inputs | inputs = new_inputs | ||||
| @@ -360,12 +360,12 @@ class NodeStruct: | |||||
| Args: | Args: | ||||
| name (str): Can accept both node identifier or original onnx node name. | name (str): Can accept both node identifier or original onnx node name. | ||||
| """ | """ | ||||
| target_nd_struct = self.GLOBAL_CONTEXT_MGR.node_struct_collections.get(name) \ | |||||
| or self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(name) | |||||
| target_nd_struct = GlobalContext().node_struct_collections.get(name) \ | |||||
| or GlobalContext().onnx_node_name_to_node_struct_map.get(name) | |||||
| if target_nd_struct is None and self.topo_idx == 0: # First node always has external input | if target_nd_struct is None and self.topo_idx == 0: # First node always has external input | ||||
| return False | return False | ||||
| if target_nd_struct is None and (name in self.GLOBAL_CONTEXT_MGR.onnx_graph_info.get('graph_inputs')): | |||||
| if target_nd_struct is None and (name in GlobalContext().onnx_graph_info.get('graph_inputs')): | |||||
| return False | return False | ||||
| if target_nd_struct is None: | if target_nd_struct is None: | ||||
| @@ -34,7 +34,7 @@ class Pattern: | |||||
| # If pattern in BUILD_IN_MODULE_NAME or BUILD_IN_PATTERN, | # If pattern in BUILD_IN_MODULE_NAME or BUILD_IN_PATTERN, | ||||
| # the pattern will get additional score. | # the pattern will get additional score. | ||||
| self.additional_score = 0 | self.additional_score = 0 | ||||
| self.know_module_name = None | |||||
| self.known_module_name = None | |||||
| def insert(self, idx, seq_len): | def insert(self, idx, seq_len): | ||||
| """ | """ | ||||