| @@ -26,6 +26,11 @@ class Singleton(type): | |||
| cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) | |||
| return cls._instances[cls] | |||
| @classmethod | |||
| def release(mcs): | |||
| """Clear singleton object.""" | |||
| mcs._instances.clear() | |||
| class GlobalContext(metaclass=Singleton): | |||
| """ | |||
| @@ -110,7 +115,7 @@ class GlobalContext(metaclass=Singleton): | |||
| if isinstance(arg, OrderedDict): | |||
| self._onnx_nodes_collection = arg # arg must be nodes_dict in OnnxDataLoader | |||
| else: | |||
| raise TypeError("GlobalContext received an unsupport variable to assign to onnx_nodes_collection.") | |||
| raise TypeError("GlobalContext received an unsupported variable to assign to onnx_nodes_collection.") | |||
| @property | |||
| def onnx_nodes_topo_index(self) -> dict: | |||
| @@ -149,7 +154,7 @@ class GlobalContext(metaclass=Singleton): | |||
| if isinstance(arg, dict): | |||
| self._onnx_tensors_collection = arg # arg must be tensors_dict in OnnxDataLoader | |||
| else: | |||
| raise TypeError("GlobalContext received an unsupport variable to assign to onnx_tensors_collection.") | |||
| raise TypeError("GlobalContext received an unsupported variable to assign to onnx_tensors_collection.") | |||
| @property | |||
| def latest_node_struct_count(self): | |||
| @@ -237,3 +242,8 @@ class GlobalContext(metaclass=Singleton): | |||
| self.module_structs[pattern_id] = [module_struct] | |||
| else: | |||
| self.module_structs[pattern_id].append(module_struct) | |||
| @classmethod | |||
| def release(cls): | |||
| """Clear singleton object.""" | |||
| Singleton.release() | |||
| @@ -0,0 +1,37 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Define code format configuration.""" | |||
| from yapf.yapflib.style import CreatePEP8Style | |||
| def mindspore_yapf_config(): | |||
| """Create the PEP8 formatting style.""" | |||
| style = CreatePEP8Style() | |||
| style['ALLOW_SPLIT_BEFORE_DEFAULT_OR_NAMED_ASSIGNS'] = False | |||
| style['ALLOW_MULTILINE_LAMBDAS'] = True | |||
| style['ALLOW_SPLIT_BEFORE_DICT_VALUE'] = False | |||
| style['COLUMN_LIMIT'] = 120 | |||
| style['COALESCE_BRACKETS'] = True | |||
| style['FORCE_MULTILINE_DICT'] = True | |||
| style['DISABLE_ENDING_COMMA_HEURISTIC'] = True | |||
| style['INDENT_DICTIONARY_VALUE'] = True | |||
| style['JOIN_MULTIPLE_LINES'] = False | |||
| style['SPACES_BEFORE_COMMENT'] = 2 | |||
| style['SPLIT_PENALTY_AFTER_OPENING_BRACKET'] = 30 | |||
| style['SPLIT_PENALTY_BEFORE_IF_EXPR'] = 30 | |||
| style['SPLIT_PENALTY_FOR_ADDED_LINE_SPLIT'] = 30 | |||
| style['SPLIT_BEFORE_LOGICAL_OPERATOR'] = False | |||
| style['SPLIT_BEFORE_BITWISE_OPERATOR'] = False | |||
| return style | |||
| @@ -104,11 +104,6 @@ NO_CONVERTED_OPERATORS = [ | |||
| ] | |||
| @unique | |||
| class CodeFormatConfig(Enum): | |||
| PEP8 = "pep8" | |||
| @unique | |||
| class NodeType(Enum): | |||
| MODULE = "module" | |||
| @@ -20,6 +20,7 @@ from importlib import import_module | |||
| from importlib.util import find_spec | |||
| import mindinsight | |||
| from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, \ | |||
| save_code_file_and_report, get_framework_type | |||
| from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \ | |||
| @@ -199,13 +200,14 @@ def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, | |||
| output_folder (str): Output folder. | |||
| report_folder (str): Report output folder path. | |||
| """ | |||
| graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape, | |||
| input_nodes=input_nodes, output_nodes=output_nodes) | |||
| generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper) | |||
| model_name = _extract_model_name(graph_path) | |||
| code_fragments = generator_inst.generate() | |||
| save_code_file_and_report(model_name, code_fragments, output_folder, report_folder) | |||
| # Release global context. | |||
| GlobalContext.release() | |||
| @tf_installation_validation | |||
| @@ -238,6 +240,8 @@ def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple, | |||
| model_name = _extract_model_name(graph_path) | |||
| code_fragments = generator_inst.generate() | |||
| save_code_file_and_report(model_name, code_fragments, output_folder, report_folder) | |||
| # Release global context. | |||
| GlobalContext.release() | |||
| @BaseConverterError.uniform_catcher() | |||
| @@ -18,16 +18,18 @@ from collections import OrderedDict | |||
| from yapf.yapflib.yapf_api import FormatCode | |||
| from .scope_utils import Scope | |||
| from .node_struct import NodeStruct | |||
| from .module_struct import ModuleStruct | |||
| from .args_translator import ArgsTranslationHelper | |||
| from ..common.global_context import GlobalContext | |||
| from ..common.outputs import BaseOutput, ModuleOutputManager | |||
| from ...common.exceptions import GeneratorError | |||
| from ..common.name_mgr import GlobalVarNameMgr | |||
| from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, CodeFormatConfig, get_imported_module | |||
| from ..report_generator import ReportGenerator | |||
| from mindinsight.mindconverter.common.exceptions import GeneratorError | |||
| from mindinsight.mindconverter.graph_based_converter.generator.scope_utils import Scope | |||
| from mindinsight.mindconverter.graph_based_converter.generator.node_struct import NodeStruct | |||
| from mindinsight.mindconverter.graph_based_converter.generator.module_struct import ModuleStruct | |||
| from mindinsight.mindconverter.graph_based_converter.generator.args_translator import ArgsTranslationHelper | |||
| from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext | |||
| from mindinsight.mindconverter.graph_based_converter.common.outputs import BaseOutput, ModuleOutputManager | |||
| from mindinsight.mindconverter.graph_based_converter.common.yapf_config import mindspore_yapf_config | |||
| from mindinsight.mindconverter.graph_based_converter.common.name_mgr import GlobalVarNameMgr | |||
| from mindinsight.mindconverter.graph_based_converter.constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, \ | |||
| get_imported_module | |||
| from mindinsight.mindconverter.graph_based_converter.report_generator import ReportGenerator | |||
| class CodeStruct: | |||
| @@ -35,7 +37,6 @@ class CodeStruct: | |||
| Define the Code template for each module generated in the final output. | |||
| Each module has only one CodeStruct to its pattern. | |||
| """ | |||
| GLOBAL_CONTEXT = GlobalContext() | |||
| NOT_IN_SCOPE_OPT = dict() | |||
| def __init__(self, struct, repeated_submodules=None): | |||
| @@ -104,7 +105,7 @@ class CodeStruct: | |||
| elif isinstance(struct, ModuleStruct): | |||
| # check if this instance generated CodeStruct | |||
| if self.GLOBAL_CONTEXT.code_structs.get(struct.pattern_id) is None: | |||
| if GlobalContext().code_structs.get(struct.pattern_id) is None: | |||
| CodeStruct(struct, repeated_submodules) | |||
| code_line_init = struct.code_line_in_init() | |||
| @@ -138,10 +139,10 @@ class CodeStruct: | |||
| returns = list(set(returns)) | |||
| else: | |||
| returns = [code_line_construct[0]] if isinstance(code_line_construct, tuple) \ | |||
| else [code_line_construct[-1].replace(' ', '').split('=')[0]] | |||
| else [code_line_construct[-1].replace(' ', '').split('=')[0]] | |||
| self.new_line = f"{SECOND_LEVEL_INDENT}return {', '.join(returns)}" | |||
| self.new_line = f"{NEW_LINE * 2}" | |||
| self.GLOBAL_CONTEXT.code_structs[md_struct.pattern_id] = self | |||
| GlobalContext().code_structs[md_struct.pattern_id] = self | |||
| class Generator: | |||
| @@ -482,7 +483,7 @@ class Generator: | |||
| outputs.append(line) | |||
| formatted_code, _ = FormatCode("\n".join(outputs), | |||
| style_config=CodeFormatConfig.PEP8.value) | |||
| style_config=mindspore_yapf_config()) | |||
| report_generator = ReportGenerator() | |||
| report = report_generator.gen_report(formatted_code) | |||
| @@ -589,7 +590,7 @@ class Generator: | |||
| output_obj.idx_in_ms_user[nd_struct.identifier] = idx | |||
| # set this output to be returned to external | |||
| output_obj.to_external = not(nd_struct.check_target_node_internal( | |||
| output_obj.to_external = not (nd_struct.check_target_node_internal( | |||
| self._global_context.outputs_storage.onnx_name(inp) | |||
| )) | |||
| @@ -31,11 +31,10 @@ class ModuleStruct: | |||
| Define a module struct which stores all info. to generate statement. | |||
| Args: | |||
| args (list): A list of node structs. | |||
| nd_struct_list (list): A list of node structs. | |||
| init_as_parent (bool): Control init method if the ModuleStruct be init as a parent module struct. | |||
| parent_base (ModuleStruct): The base ModuleStruct the current ModuleStruct to be init as. | |||
| """ | |||
| GLOBAL_CONTEXT_MGR = GlobalContext() | |||
| def __init__(self, nd_struct_list, init_as_parent=False, parent_base=None): | |||
| """Init. a module by NodeStructs.""" | |||
| @@ -247,7 +246,7 @@ class ModuleStruct: | |||
| self._module_structs += md_structs | |||
| tail_md = md_structs[-1] | |||
| else: | |||
| raise TypeError("ModuleStruct cannot add an unsupport Type {} to module_structs list.".format( | |||
| raise TypeError("ModuleStruct cannot add an unsupported Type {} to module_structs list.".format( | |||
| type(md_structs))) | |||
| # update tail node and index | |||
| if self.tail_nd_struct_index < tail_md.tail_nd_struct_index: | |||
| @@ -318,12 +317,7 @@ class ModuleStruct: | |||
| return ret | |||
| def code_line_in_init(self): | |||
| """ | |||
| Initialization line of code in module init block. | |||
| Args: | |||
| override_formal_val (dict): Indicate which args should be renamed for passing value from upper level. | |||
| """ | |||
| """Initialization line of code in module init block.""" | |||
| left = "self.{}".format(self.ms_var_name) | |||
| args_list = list() | |||
| # Load args in init statement. | |||
| @@ -338,7 +332,7 @@ class ModuleStruct: | |||
| else: | |||
| args_list += self._fragment.actual_args | |||
| right = f"{self.class_name}({', '.join(args_list)})" | |||
| return (left, right) | |||
| return left, right | |||
| def code_line_in_construct(self, inputs=None): | |||
| """Construct line of code in module construct block.""" | |||
| @@ -356,7 +350,7 @@ class ModuleStruct: | |||
| if isinstance(inputs, str): | |||
| inputs = [inputs] | |||
| right = f"self.{self.ms_var_name}({', '.join(inputs)})" | |||
| return (left, right) | |||
| return left, right | |||
| @property | |||
| def node_structs(self): | |||
| @@ -463,8 +457,8 @@ class ModuleStruct: | |||
| """Return the class name for generating code of this module.""" | |||
| if self.pattern_id == -1: | |||
| return "Model" | |||
| if self.GLOBAL_CONTEXT_MGR.known_module_name.get("Module{}".format(self.pattern_id)) is not None: | |||
| class_name = self.GLOBAL_CONTEXT_MGR.known_module_name.get("Module{}".format(self.pattern_id)) | |||
| if GlobalContext().known_module_name.get("Module{}".format(self.pattern_id)) is not None: | |||
| class_name = GlobalContext().known_module_name.get("Module{}".format(self.pattern_id)) | |||
| else: | |||
| class_name = "Module{}".format(self.pattern_id) | |||
| return class_name | |||
| @@ -23,6 +23,7 @@ from ..third_party_graph.onnx_graph_node import OnnxGraphNode | |||
| from ..common.global_context import GlobalContext | |||
| from ...common.exceptions import GeneratorError | |||
| class NodeStruct: | |||
| """ | |||
| Define a node struct which stores all info. to generate statement. | |||
| @@ -34,10 +35,10 @@ class NodeStruct: | |||
| You can pass as many args as possible and the Node Struct will update | |||
| by arguments order. | |||
| """ | |||
| GLOBAL_CONTEXT_MGR = GlobalContext() | |||
| def __init__(self, args): | |||
| # define attributes here | |||
| self.global_context_mgr = GlobalContext() | |||
| self._identifier = None | |||
| self._fragment = None | |||
| self._args_translator = None | |||
| @@ -74,7 +75,7 @@ class NodeStruct: | |||
| """Get the original topological index in the onnx graph.""" | |||
| ori_name = self._fragment.metadata.get('source') | |||
| self.onnx_name = ori_name | |||
| return self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_topo_idx.get(ori_name) | |||
| return GlobalContext().onnx_node_name_to_topo_idx.get(ori_name) | |||
| def update_var_name(self, idx=None): | |||
| """ | |||
| @@ -83,6 +84,7 @@ class NodeStruct: | |||
| Args: | |||
| idx (int): The index of the node in this module. | |||
| """ | |||
| def _remove_op_header(op_name): | |||
| """Remove op header which indicating their sources of op set.""" | |||
| op_name = op_name.replace('nn.', '') | |||
| @@ -112,7 +114,7 @@ class NodeStruct: | |||
| self._fragment = FragmentHandler(frag) | |||
| if self.ms_op: | |||
| idx = self.GLOBAL_CONTEXT_MGR.latest_node_struct_count | |||
| idx = GlobalContext().latest_node_struct_count | |||
| self.update_var_name(idx=idx) | |||
| def _set_scope_from_identifier(self): | |||
| @@ -142,9 +144,7 @@ class NodeStruct: | |||
| Args: | |||
| arg (Union[PyTorchGraphNode, OnnxGraphNode, dict]): Node related obj. | |||
| force_ready (bool): Force this NodeStruct is ready to generate. | |||
| """ | |||
| if isinstance(arg, OnnxGraphNode): | |||
| self._update_from_onnx_gn(arg) | |||
| elif isinstance(arg, NewFragment): | |||
| @@ -168,7 +168,7 @@ class NodeStruct: | |||
| self._identifier = s | |||
| self._set_scope_from_identifier() | |||
| self.topo_idx = self.ori_topo_idx() | |||
| self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map[self.onnx_name] = self | |||
| GlobalContext().onnx_node_name_to_node_struct_map[self.onnx_name] = self | |||
| @property | |||
| def fragment(self): | |||
| @@ -181,7 +181,7 @@ class NodeStruct: | |||
| Set the Node fragment. | |||
| Args: | |||
| s (NodeFragment): The node identifier string. | |||
| frag (NodeFragment): The node identifier string. | |||
| """ | |||
| self._fragment = frag | |||
| @@ -198,7 +198,7 @@ class NodeStruct: | |||
| @property | |||
| def onnx_node(self): | |||
| """Return the original onnx node reference.""" | |||
| return self.GLOBAL_CONTEXT_MGR.onnx_nodes_collection.get(self.onnx_name) | |||
| return GlobalContext().onnx_nodes_collection.get(self.onnx_name) | |||
| @property | |||
| def ms_op(self): | |||
| @@ -241,7 +241,7 @@ class NodeStruct: | |||
| ret = [] | |||
| precursor_nodes_names = self.precursor_nodes_names | |||
| for pre_node_name in precursor_nodes_names: | |||
| nd_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(pre_node_name) | |||
| nd_struct = GlobalContext().onnx_node_name_to_node_struct_map.get(pre_node_name) | |||
| ret.append(nd_struct) | |||
| return ret | |||
| @@ -255,7 +255,7 @@ class NodeStruct: | |||
| """Return the node struct instances of successor nodes.""" | |||
| ret = [] | |||
| for pre_node_name in self.successor_nodes_names: | |||
| nd_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(pre_node_name) | |||
| nd_struct = GlobalContext().onnx_node_name_to_node_struct_map.get(pre_node_name) | |||
| ret.append(nd_struct) | |||
| return ret | |||
| @@ -312,11 +312,11 @@ class NodeStruct: | |||
| inputs = self.matched_inputs | |||
| # Check original onnx node's input to ensure double inputs are not ignored | |||
| original_inputs = self.GLOBAL_CONTEXT_MGR.onnx_node_inputs.get(self.onnx_name) | |||
| original_inputs = GlobalContext().onnx_node_inputs.get(self.onnx_name) | |||
| new_inputs = [] | |||
| for idx, prec_node in enumerate(self.precursor_nodes_names): | |||
| occurence = original_inputs.count(prec_node) | |||
| for _ in range(occurence): | |||
| occurrence = original_inputs.count(prec_node) | |||
| for _ in range(occurrence): | |||
| new_inputs.append(inputs[idx]) | |||
| inputs = new_inputs | |||
| @@ -360,12 +360,12 @@ class NodeStruct: | |||
| Args: | |||
| name (str): Can accept both node identifier or original onnx node name. | |||
| """ | |||
| target_nd_struct = self.GLOBAL_CONTEXT_MGR.node_struct_collections.get(name) \ | |||
| or self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(name) | |||
| target_nd_struct = GlobalContext().node_struct_collections.get(name) \ | |||
| or GlobalContext().onnx_node_name_to_node_struct_map.get(name) | |||
| if target_nd_struct is None and self.topo_idx == 0: # First node always has external input | |||
| return False | |||
| if target_nd_struct is None and (name in self.GLOBAL_CONTEXT_MGR.onnx_graph_info.get('graph_inputs')): | |||
| if target_nd_struct is None and (name in GlobalContext().onnx_graph_info.get('graph_inputs')): | |||
| return False | |||
| if target_nd_struct is None: | |||
| @@ -34,7 +34,7 @@ class Pattern: | |||
| # If pattern in BUILD_IN_MODULE_NAME or BUILD_IN_PATTERN, | |||
| # the pattern will get additional score. | |||
| self.additional_score = 0 | |||
| self.know_module_name = None | |||
| self.known_module_name = None | |||
| def insert(self, idx, seq_len): | |||
| """ | |||