diff --git a/mindinsight/mindconverter/graph_based_converter/common/global_context.py b/mindinsight/mindconverter/graph_based_converter/common/global_context.py
index 64a723d8..0ab74cf4 100644
--- a/mindinsight/mindconverter/graph_based_converter/common/global_context.py
+++ b/mindinsight/mindconverter/graph_based_converter/common/global_context.py
@@ -26,6 +26,11 @@ class Singleton(type):
             cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
         return cls._instances[cls]
 
+    @classmethod
+    def release(mcs):
+        """Clear singleton object."""
+        mcs._instances.clear()
+
 
 class GlobalContext(metaclass=Singleton):
     """
@@ -110,7 +115,7 @@ class GlobalContext(metaclass=Singleton):
         if isinstance(arg, OrderedDict):
             self._onnx_nodes_collection = arg  # arg must be nodes_dict in OnnxDataLoader
         else:
-            raise TypeError("GlobalContext received an unsupport variable to assign to onnx_nodes_collection.")
+            raise TypeError("GlobalContext received an unsupported variable to assign to onnx_nodes_collection.")
 
     @property
     def onnx_nodes_topo_index(self) -> dict:
@@ -149,7 +154,7 @@ class GlobalContext(metaclass=Singleton):
         if isinstance(arg, dict):
             self._onnx_tensors_collection = arg  # arg must be tensors_dict in OnnxDataLoader
         else:
-            raise TypeError("GlobalContext received an unsupport variable to assign to onnx_tensors_collection.")
+            raise TypeError("GlobalContext received an unsupported variable to assign to onnx_tensors_collection.")
 
     @property
     def latest_node_struct_count(self):
@@ -237,3 +242,8 @@ class GlobalContext(metaclass=Singleton):
             self.module_structs[pattern_id] = [module_struct]
         else:
             self.module_structs[pattern_id].append(module_struct)
+
+    @classmethod
+    def release(cls):
+        """Clear singleton object."""
+        Singleton.release()
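Editor's note (not part of the patch): a minimal sketch of how the two-level release is meant to behave. `GlobalContext.release()` delegates to the `Singleton` metaclass, which drops every cached instance, so the next `GlobalContext()` call builds a fresh context for the following conversion run. The usage at the bottom is illustrative only; the real `GlobalContext` carries converter state beyond what is shown here.

```python
class Singleton(type):
    """Metaclass that caches exactly one instance per class."""
    _instances = {}

    def __call__(cls, *args, **kwargs):
        # Reuse the cached instance if one exists, otherwise create and cache it.
        if cls not in cls._instances:
            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]

    @classmethod
    def release(mcs):
        """Clear singleton object."""
        mcs._instances.clear()


class GlobalContext(metaclass=Singleton):
    """Stand-in for the converter's global context."""

    @classmethod
    def release(cls):
        """Clear singleton object."""
        Singleton.release()


first = GlobalContext()
assert GlobalContext() is first       # the same cached instance is returned
GlobalContext.release()               # drop all cached singleton instances
assert GlobalContext() is not first   # a fresh context is built afterwards
```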
diff --git a/mindinsight/mindconverter/graph_based_converter/common/yapf_config.py b/mindinsight/mindconverter/graph_based_converter/common/yapf_config.py
new file mode 100644
index 00000000..ce27ce07
--- /dev/null
+++ b/mindinsight/mindconverter/graph_based_converter/common/yapf_config.py
@@ -0,0 +1,37 @@
+# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Define code format configuration."""
+from yapf.yapflib.style import CreatePEP8Style
+
+
+def mindspore_yapf_config():
+    """Create the PEP8 formatting style."""
+    style = CreatePEP8Style()
+    style['ALLOW_SPLIT_BEFORE_DEFAULT_OR_NAMED_ASSIGNS'] = False
+    style['ALLOW_MULTILINE_LAMBDAS'] = True
+    style['ALLOW_SPLIT_BEFORE_DICT_VALUE'] = False
+    style['COLUMN_LIMIT'] = 120
+    style['COALESCE_BRACKETS'] = True
+    style['FORCE_MULTILINE_DICT'] = True
+    style['DISABLE_ENDING_COMMA_HEURISTIC'] = True
+    style['INDENT_DICTIONARY_VALUE'] = True
+    style['JOIN_MULTIPLE_LINES'] = False
+    style['SPACES_BEFORE_COMMENT'] = 2
+    style['SPLIT_PENALTY_AFTER_OPENING_BRACKET'] = 30
+    style['SPLIT_PENALTY_BEFORE_IF_EXPR'] = 30
+    style['SPLIT_PENALTY_FOR_ADDED_LINE_SPLIT'] = 30
+    style['SPLIT_BEFORE_LOGICAL_OPERATOR'] = False
+    style['SPLIT_BEFORE_BITWISE_OPERATOR'] = False
+    return style
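Editor's note (assumed usage, not part of the patch): yapf's `FormatCode` accepts a style dictionary through `style_config`, so the dict returned by `mindspore_yapf_config()` can be passed in directly, replacing the removed `CodeFormatConfig.PEP8` string literal. The sample source string below is made up for illustration.

```python
from yapf.yapflib.yapf_api import FormatCode

from mindinsight.mindconverter.graph_based_converter.common.yapf_config import mindspore_yapf_config

# Deliberately badly formatted input, for illustration only.
source = "def net(x,y):\n    return {'sum':x+y,'diff':x-y}\n"

# FormatCode returns the reformatted source and a flag telling whether anything changed.
formatted_code, changed = FormatCode(source, style_config=mindspore_yapf_config())
print(formatted_code)
```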
diff --git a/mindinsight/mindconverter/graph_based_converter/constant.py b/mindinsight/mindconverter/graph_based_converter/constant.py
index 9041a940..7f138b2e 100644
--- a/mindinsight/mindconverter/graph_based_converter/constant.py
+++ b/mindinsight/mindconverter/graph_based_converter/constant.py
@@ -104,11 +104,6 @@ NO_CONVERTED_OPERATORS = [
 ]
 
 
-@unique
-class CodeFormatConfig(Enum):
-    PEP8 = "pep8"
-
-
 @unique
 class NodeType(Enum):
     MODULE = "module"
diff --git a/mindinsight/mindconverter/graph_based_converter/framework.py b/mindinsight/mindconverter/graph_based_converter/framework.py
index 882b825d..aaeb0d99 100644
--- a/mindinsight/mindconverter/graph_based_converter/framework.py
+++ b/mindinsight/mindconverter/graph_based_converter/framework.py
@@ -20,6 +20,7 @@ from importlib import import_module
 from importlib.util import find_spec
 
 import mindinsight
+from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
 from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, \
     save_code_file_and_report, get_framework_type
 from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \
@@ -199,13 +200,14 @@ def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple,
         output_folder (str): Output folder.
         report_folder (str): Report output folder path.
     """
-
     graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape, input_nodes=input_nodes,
                                   output_nodes=output_nodes)
     generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
     model_name = _extract_model_name(graph_path)
    code_fragments = generator_inst.generate()
     save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)
+    # Release global context.
+    GlobalContext.release()
 
 
 @tf_installation_validation
@@ -238,6 +240,8 @@ def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple,
     model_name = _extract_model_name(graph_path)
     code_fragments = generator_inst.generate()
     save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)
+    # Release global context.
+    GlobalContext.release()
 
 
 @BaseConverterError.uniform_catcher()
""" - GLOBAL_CONTEXT = GlobalContext() NOT_IN_SCOPE_OPT = dict() def __init__(self, struct, repeated_submodules=None): @@ -104,7 +105,7 @@ class CodeStruct: elif isinstance(struct, ModuleStruct): # check if this instance generated CodeStruct - if self.GLOBAL_CONTEXT.code_structs.get(struct.pattern_id) is None: + if GlobalContext().code_structs.get(struct.pattern_id) is None: CodeStruct(struct, repeated_submodules) code_line_init = struct.code_line_in_init() @@ -138,10 +139,10 @@ class CodeStruct: returns = list(set(returns)) else: returns = [code_line_construct[0]] if isinstance(code_line_construct, tuple) \ - else [code_line_construct[-1].replace(' ', '').split('=')[0]] + else [code_line_construct[-1].replace(' ', '').split('=')[0]] self.new_line = f"{SECOND_LEVEL_INDENT}return {', '.join(returns)}" self.new_line = f"{NEW_LINE * 2}" - self.GLOBAL_CONTEXT.code_structs[md_struct.pattern_id] = self + GlobalContext().code_structs[md_struct.pattern_id] = self class Generator: @@ -482,7 +483,7 @@ class Generator: outputs.append(line) formatted_code, _ = FormatCode("\n".join(outputs), - style_config=CodeFormatConfig.PEP8.value) + style_config=mindspore_yapf_config()) report_generator = ReportGenerator() report = report_generator.gen_report(formatted_code) @@ -589,7 +590,7 @@ class Generator: output_obj.idx_in_ms_user[nd_struct.identifier] = idx # set this output to be returned to external - output_obj.to_external = not(nd_struct.check_target_node_internal( + output_obj.to_external = not (nd_struct.check_target_node_internal( self._global_context.outputs_storage.onnx_name(inp) )) diff --git a/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py b/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py index 4f661da2..22ae05ed 100644 --- a/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py +++ b/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py @@ -31,11 +31,10 @@ class ModuleStruct: Define a module struct which stores all info. to generate statement. Args: - args (list): A list of node structs. + nd_struct_list (list): A list of node structs. init_as_parent (bool): Control init method if the ModuleStruct be init as a parent module struct. parent_base (ModuleStruct): The base ModuleStruct the current ModuleStruct to be init as. """ - GLOBAL_CONTEXT_MGR = GlobalContext() def __init__(self, nd_struct_list, init_as_parent=False, parent_base=None): """Init. a module by NodeStructs.""" @@ -247,7 +246,7 @@ class ModuleStruct: self._module_structs += md_structs tail_md = md_structs[-1] else: - raise TypeError("ModuleStruct cannot add an unsupport Type {} to module_structs list.".format( + raise TypeError("ModuleStruct cannot add an unsupported Type {} to module_structs list.".format( type(md_structs))) # update tail node and index if self.tail_nd_struct_index < tail_md.tail_nd_struct_index: @@ -318,12 +317,7 @@ class ModuleStruct: return ret def code_line_in_init(self): - """ - Initialization line of code in module init block. - - Args: - override_formal_val (dict): Indicate which args should be renamed for passing value from upper level. - """ + """Initialization line of code in module init block.""" left = "self.{}".format(self.ms_var_name) args_list = list() # Load args in init statement. 
diff --git a/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py b/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py
index 4f661da2..22ae05ed 100644
--- a/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py
+++ b/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py
@@ -31,11 +31,10 @@ class ModuleStruct:
     Define a module struct which stores all info. to generate statement.
 
     Args:
-        args (list): A list of node structs.
+        nd_struct_list (list): A list of node structs.
         init_as_parent (bool): Control init method if the ModuleStruct be init as a parent module struct.
         parent_base (ModuleStruct): The base ModuleStruct the current ModuleStruct to be init as.
     """
-    GLOBAL_CONTEXT_MGR = GlobalContext()
 
     def __init__(self, nd_struct_list, init_as_parent=False, parent_base=None):
         """Init. a module by NodeStructs."""
@@ -247,7 +246,7 @@ class ModuleStruct:
             self._module_structs += md_structs
             tail_md = md_structs[-1]
         else:
-            raise TypeError("ModuleStruct cannot add an unsupport Type {} to module_structs list.".format(
+            raise TypeError("ModuleStruct cannot add an unsupported Type {} to module_structs list.".format(
                 type(md_structs)))
         # update tail node and index
         if self.tail_nd_struct_index < tail_md.tail_nd_struct_index:
@@ -318,12 +317,7 @@ class ModuleStruct:
         return ret
 
     def code_line_in_init(self):
-        """
-        Initialization line of code in module init block.
-
-        Args:
-            override_formal_val (dict): Indicate which args should be renamed for passing value from upper level.
-        """
+        """Initialization line of code in module init block."""
         left = "self.{}".format(self.ms_var_name)
         args_list = list()
         # Load args in init statement.
@@ -338,7 +332,7 @@ class ModuleStruct:
         else:
             args_list += self._fragment.actual_args
         right = f"{self.class_name}({', '.join(args_list)})"
-        return (left, right)
+        return left, right
 
     def code_line_in_construct(self, inputs=None):
         """Construct line of code in module construct block."""
@@ -356,7 +350,7 @@ class ModuleStruct:
         if isinstance(inputs, str):
             inputs = [inputs]
         right = f"self.{self.ms_var_name}({', '.join(inputs)})"
-        return (left, right)
+        return left, right
 
     @property
     def node_structs(self):
@@ -463,8 +457,8 @@ class ModuleStruct:
         """Return the class name for generating code of this module."""
         if self.pattern_id == -1:
             return "Model"
-        if self.GLOBAL_CONTEXT_MGR.known_module_name.get("Module{}".format(self.pattern_id)) is not None:
-            class_name = self.GLOBAL_CONTEXT_MGR.known_module_name.get("Module{}".format(self.pattern_id))
+        if GlobalContext().known_module_name.get("Module{}".format(self.pattern_id)) is not None:
+            class_name = GlobalContext().known_module_name.get("Module{}".format(self.pattern_id))
         else:
             class_name = "Module{}".format(self.pattern_id)
         return class_name
""" - if isinstance(arg, OnnxGraphNode): self._update_from_onnx_gn(arg) elif isinstance(arg, NewFragment): @@ -168,7 +168,7 @@ class NodeStruct: self._identifier = s self._set_scope_from_identifier() self.topo_idx = self.ori_topo_idx() - self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map[self.onnx_name] = self + GlobalContext().onnx_node_name_to_node_struct_map[self.onnx_name] = self @property def fragment(self): @@ -181,7 +181,7 @@ class NodeStruct: Set the Node fragment. Args: - s (NodeFragment): The node identifier string. + frag (NodeFragment): The node identifier string. """ self._fragment = frag @@ -198,7 +198,7 @@ class NodeStruct: @property def onnx_node(self): """Return the original onnx node reference.""" - return self.GLOBAL_CONTEXT_MGR.onnx_nodes_collection.get(self.onnx_name) + return GlobalContext().onnx_nodes_collection.get(self.onnx_name) @property def ms_op(self): @@ -241,7 +241,7 @@ class NodeStruct: ret = [] precursor_nodes_names = self.precursor_nodes_names for pre_node_name in precursor_nodes_names: - nd_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(pre_node_name) + nd_struct = GlobalContext().onnx_node_name_to_node_struct_map.get(pre_node_name) ret.append(nd_struct) return ret @@ -255,7 +255,7 @@ class NodeStruct: """Return the node struct instances of successor nodes.""" ret = [] for pre_node_name in self.successor_nodes_names: - nd_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(pre_node_name) + nd_struct = GlobalContext().onnx_node_name_to_node_struct_map.get(pre_node_name) ret.append(nd_struct) return ret @@ -312,11 +312,11 @@ class NodeStruct: inputs = self.matched_inputs # Check original onnx node's input to ensure double inputs are not ignored - original_inputs = self.GLOBAL_CONTEXT_MGR.onnx_node_inputs.get(self.onnx_name) + original_inputs = GlobalContext().onnx_node_inputs.get(self.onnx_name) new_inputs = [] for idx, prec_node in enumerate(self.precursor_nodes_names): - occurence = original_inputs.count(prec_node) - for _ in range(occurence): + occurrence = original_inputs.count(prec_node) + for _ in range(occurrence): new_inputs.append(inputs[idx]) inputs = new_inputs @@ -360,12 +360,12 @@ class NodeStruct: Args: name (str): Can accept both node identifier or original onnx node name. """ - target_nd_struct = self.GLOBAL_CONTEXT_MGR.node_struct_collections.get(name) \ - or self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(name) + target_nd_struct = GlobalContext().node_struct_collections.get(name) \ + or GlobalContext().onnx_node_name_to_node_struct_map.get(name) if target_nd_struct is None and self.topo_idx == 0: # First node always has external input return False - if target_nd_struct is None and (name in self.GLOBAL_CONTEXT_MGR.onnx_graph_info.get('graph_inputs')): + if target_nd_struct is None and (name in GlobalContext().onnx_graph_info.get('graph_inputs')): return False if target_nd_struct is None: diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py index 27ca4426..125c3875 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py @@ -34,7 +34,7 @@ class Pattern: # If pattern in BUILD_IN_MODULE_NAME or BUILD_IN_PATTERN, # the pattern will get additional score. 
diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py
index 27ca4426..125c3875 100644
--- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py
+++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py
@@ -34,7 +34,7 @@ class Pattern:
         # If pattern in BUILD_IN_MODULE_NAME or BUILD_IN_PATTERN,
         # the pattern will get additional score.
         self.additional_score = 0
-        self.know_module_name = None
+        self.known_module_name = None
 
     def insert(self, idx, seq_len):
         """