diff --git a/mindinsight/mindconverter/graph_based_converter/common/global_context.py b/mindinsight/mindconverter/graph_based_converter/common/global_context.py new file mode 100644 index 00000000..c0ce4c71 --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/common/global_context.py @@ -0,0 +1,202 @@ +# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Define GlobalContext class to save required resources during whole conversion procedure.""" +from collections import OrderedDict + + +class Singleton(type): + """Metaclass to make the globalcontext single instance.""" + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] + + +class GlobalContext(metaclass=Singleton): + """ + A universal global context library for easy data exchanging in MindConverter. + + Note: + In order to avoid reference loops, it is unable to check functions + arguments' type in GlobalContext. You MUST check all inputs + have its correct type before calling functions. + """ + + def __init__(self): + # Define data stored from onnx_utils + # Key as Onnx Name + self._onnx_nodes_collection = OrderedDict() + # key is topo_idx, value is onnx_node_name. + self._onnx_nodes_topo_index = dict() + self._onnx_tensors_collection = dict() + + # Define data stored from generator + # Key as Node Identifier + self.node_struct_collections = OrderedDict() + self.node_struct_adder_counter = 0 + # Define onnx_utils <---> generator mapping + self.node_struct_to_onnx_node_map = dict() + self.onnx_node_to_node_struct_map = dict() + + # Define Module pattern to customize name mapping + self.module_customized_name = dict() + + # Define Fragments + self.node_fragments = OrderedDict() + self.module_fragments = OrderedDict() + + # Define Structs + # key is pattern_id, value is [ModuleStructs] + self.module_structs = dict() + self.code_structs = dict() + + # Define extra inputs + # key is target node (which use this opt), value is opt_var_name + self.extra_input_dict = dict() + + def get_onnx_node_from_identifier(self, identifier): + """Return an OnnxUtils defined node by its identifier.""" + onnx_node_name = self.node_struct_to_onnx_node_map.get(identifier) + return self.onnx_nodes_collection.get(onnx_node_name) + + def get_onnx_node_from_onnx_topo_idx(self, idx): + """Return an OnnxUtils defined node name by its topological index.""" + return self._onnx_nodes_topo_index.get(idx) + + def get_onnx_tensor(self, tensor_name): + """Return an OnnxUtils defined tensor.""" + return self.onnx_tensors_collection.get(tensor_name) + + def get_identifier_from_onnx_node_name(self, node_name): + """Return the node identifier by Onnx Node name.""" + identifier = self.onnx_node_to_node_struct_map.get(node_name) + return identifier + + @property + def onnx_nodes_collection(self) -> OrderedDict: + """ + Return the onnx nodes collections. + + Returns: + dict, dictionary contains all OnnxUtils defined onnx nodes. + """ + return self._onnx_nodes_collection + + @onnx_nodes_collection.setter + def onnx_nodes_collection(self, arg): + """ + Set the onnx nodes collection. + """ + if isinstance(arg, OrderedDict): + self._onnx_nodes_collection = arg # arg must be nodes_dict in OnnxDataLoader + else: + raise TypeError("GlobalContext received an unsupport variable to assign to onnx_nodes_collection.") + + @property + def onnx_nodes_topo_index(self) -> dict: + "Return the onnx nodes and topological index." + return self._onnx_nodes_topo_index + + @onnx_nodes_topo_index.setter + def onnx_nodes_topo_index(self, index_list): + if not isinstance(index_list, list): + raise TypeError("The argument index_list must be a list of tuple (index, onnx_node_name).") + if not isinstance(index_list[0], tuple): + raise TypeError("The item in index_list must by a tuple of (index, onnx_node_name)") + for (topo_idx, onnx_node_name) in index_list: + self._onnx_nodes_topo_index[topo_idx] = onnx_node_name + + @property + def onnx_tensors_collection(self): + return self.onnx_tensors_collection + + @onnx_tensors_collection.setter + def onnx_tensors_collection(self, arg): + if isinstance(arg, dict): + self._onnx_tensors_collection = arg # arg must be tensors_dict in OnnxDataLoader + else: + raise TypeError("GlobalContext received an unsupport variable to assign to onnx_tensors_collection.") + + @property + def latest_node_struct_count(self): + ret = self.node_struct_adder_counter + self.node_struct_adder_counter += 1 + return ret + + def get_extra_input(self, topo_idx) -> list: + """ + Get the extra input of the node topological index provided. + + Args: + topo_idx (int): The topological index of the node required extra input. + """ + return self.extra_input_dict.get(topo_idx) + + def add_extra_input(self, target_topo_idx, opt_var_name): + """ + Add the extra input(s) required for the target node. + + Args: + target_topo_idx (int): The index of node which requires the input. + opt_var_name (Union[str, list]): The output(s) name the target node will use. + """ + if isinstance(opt_var_name, str): + opt_var_name = [opt_var_name] + if isinstance(opt_var_name, list): + self.extra_input_dict[target_topo_idx] = opt_var_name + else: + raise TypeError("Global Context does not support the type {} of opt_var_name.".format(type(opt_var_name))) + + def get_module_customized_name(self, pattern_id) -> str: + """ + Get the customized name of the module with pattern id provied. + + Args: + pattern_id (int): The pattern the module belongs to. + + Returns, + str, the customized name of the module. + """ + return self.module_customized_name.get(pattern_id) + + def set_module_customized_name(self, pattern_id, customized_name): + """ + Set the customized name of the module with pattern id provided. + + Args: + pattern_id (int): The pattern id the module has. + customized_name (str): The customized name of the module. + """ + self.module_customized_name[pattern_id] = customized_name + + def get_node_fragment(self, identifier): + return self.node_fragments.get(identifier) + + def add_code_fragment(self, identifier, frag): + self.node_fragments[identifier] = frag + + def get_module_fragment(self, identifier): + return self.module_fragments.get(identifier) + + def add_module_fragment(self, identifier, frag): + self.module_fragments[identifier] = frag + + def add_module_struct(self, pattern_id, module_struct): + if self.module_structs.get(pattern_id) is None: + self.module_structs[pattern_id] = [module_struct] + else: + self.module_structs[pattern_id].append(module_struct)