|
|
|
@@ -0,0 +1,202 @@ |
|
|
|
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. |
|
|
|
# |
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
# you may not use this file except in compliance with the License. |
|
|
|
# You may obtain a copy of the License at |
|
|
|
# |
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
# |
|
|
|
# Unless required by applicable law or agreed to in writing, software |
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
# See the License for the specific language governing permissions and |
|
|
|
# limitations under the License. |
|
|
|
# ============================================================================== |
|
|
|
"""Define GlobalContext class to save required resources during whole conversion procedure.""" |
|
|
|
from collections import OrderedDict |
|
|
|
|
|
|
|
|
|
|
|
class Singleton(type): |
|
|
|
"""Metaclass to make the globalcontext single instance.""" |
|
|
|
_instances = {} |
|
|
|
|
|
|
|
def __call__(cls, *args, **kwargs): |
|
|
|
if cls not in cls._instances: |
|
|
|
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) |
|
|
|
return cls._instances[cls] |
|
|
|
|
|
|
|
|
|
|
|
class GlobalContext(metaclass=Singleton): |
|
|
|
""" |
|
|
|
A universal global context library for easy data exchanging in MindConverter. |
|
|
|
|
|
|
|
Note: |
|
|
|
In order to avoid reference loops, it is unable to check functions |
|
|
|
arguments' type in GlobalContext. You MUST check all inputs |
|
|
|
have its correct type before calling functions. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
# Define data stored from onnx_utils |
|
|
|
# Key as Onnx Name |
|
|
|
self._onnx_nodes_collection = OrderedDict() |
|
|
|
# key is topo_idx, value is onnx_node_name. |
|
|
|
self._onnx_nodes_topo_index = dict() |
|
|
|
self._onnx_tensors_collection = dict() |
|
|
|
|
|
|
|
# Define data stored from generator |
|
|
|
# Key as Node Identifier |
|
|
|
self.node_struct_collections = OrderedDict() |
|
|
|
self.node_struct_adder_counter = 0 |
|
|
|
# Define onnx_utils <---> generator mapping |
|
|
|
self.node_struct_to_onnx_node_map = dict() |
|
|
|
self.onnx_node_to_node_struct_map = dict() |
|
|
|
|
|
|
|
# Define Module pattern to customize name mapping |
|
|
|
self.module_customized_name = dict() |
|
|
|
|
|
|
|
# Define Fragments |
|
|
|
self.node_fragments = OrderedDict() |
|
|
|
self.module_fragments = OrderedDict() |
|
|
|
|
|
|
|
# Define Structs |
|
|
|
# key is pattern_id, value is [ModuleStructs] |
|
|
|
self.module_structs = dict() |
|
|
|
self.code_structs = dict() |
|
|
|
|
|
|
|
# Define extra inputs |
|
|
|
# key is target node (which use this opt), value is opt_var_name |
|
|
|
self.extra_input_dict = dict() |
|
|
|
|
|
|
|
def get_onnx_node_from_identifier(self, identifier): |
|
|
|
"""Return an OnnxUtils defined node by its identifier.""" |
|
|
|
onnx_node_name = self.node_struct_to_onnx_node_map.get(identifier) |
|
|
|
return self.onnx_nodes_collection.get(onnx_node_name) |
|
|
|
|
|
|
|
def get_onnx_node_from_onnx_topo_idx(self, idx): |
|
|
|
"""Return an OnnxUtils defined node name by its topological index.""" |
|
|
|
return self._onnx_nodes_topo_index.get(idx) |
|
|
|
|
|
|
|
def get_onnx_tensor(self, tensor_name): |
|
|
|
"""Return an OnnxUtils defined tensor.""" |
|
|
|
return self.onnx_tensors_collection.get(tensor_name) |
|
|
|
|
|
|
|
def get_identifier_from_onnx_node_name(self, node_name): |
|
|
|
"""Return the node identifier by Onnx Node name.""" |
|
|
|
identifier = self.onnx_node_to_node_struct_map.get(node_name) |
|
|
|
return identifier |
|
|
|
|
|
|
|
@property |
|
|
|
def onnx_nodes_collection(self) -> OrderedDict: |
|
|
|
""" |
|
|
|
Return the onnx nodes collections. |
|
|
|
|
|
|
|
Returns: |
|
|
|
dict, dictionary contains all OnnxUtils defined onnx nodes. |
|
|
|
""" |
|
|
|
return self._onnx_nodes_collection |
|
|
|
|
|
|
|
@onnx_nodes_collection.setter |
|
|
|
def onnx_nodes_collection(self, arg): |
|
|
|
""" |
|
|
|
Set the onnx nodes collection. |
|
|
|
""" |
|
|
|
if isinstance(arg, OrderedDict): |
|
|
|
self._onnx_nodes_collection = arg # arg must be nodes_dict in OnnxDataLoader |
|
|
|
else: |
|
|
|
raise TypeError("GlobalContext received an unsupport variable to assign to onnx_nodes_collection.") |
|
|
|
|
|
|
|
@property |
|
|
|
def onnx_nodes_topo_index(self) -> dict: |
|
|
|
"Return the onnx nodes and topological index." |
|
|
|
return self._onnx_nodes_topo_index |
|
|
|
|
|
|
|
@onnx_nodes_topo_index.setter |
|
|
|
def onnx_nodes_topo_index(self, index_list): |
|
|
|
if not isinstance(index_list, list): |
|
|
|
raise TypeError("The argument index_list must be a list of tuple (index, onnx_node_name).") |
|
|
|
if not isinstance(index_list[0], tuple): |
|
|
|
raise TypeError("The item in index_list must by a tuple of (index, onnx_node_name)") |
|
|
|
for (topo_idx, onnx_node_name) in index_list: |
|
|
|
self._onnx_nodes_topo_index[topo_idx] = onnx_node_name |
|
|
|
|
|
|
|
@property |
|
|
|
def onnx_tensors_collection(self): |
|
|
|
return self.onnx_tensors_collection |
|
|
|
|
|
|
|
@onnx_tensors_collection.setter |
|
|
|
def onnx_tensors_collection(self, arg): |
|
|
|
if isinstance(arg, dict): |
|
|
|
self._onnx_tensors_collection = arg # arg must be tensors_dict in OnnxDataLoader |
|
|
|
else: |
|
|
|
raise TypeError("GlobalContext received an unsupport variable to assign to onnx_tensors_collection.") |
|
|
|
|
|
|
|
@property |
|
|
|
def latest_node_struct_count(self): |
|
|
|
ret = self.node_struct_adder_counter |
|
|
|
self.node_struct_adder_counter += 1 |
|
|
|
return ret |
|
|
|
|
|
|
|
def get_extra_input(self, topo_idx) -> list: |
|
|
|
""" |
|
|
|
Get the extra input of the node topological index provided. |
|
|
|
|
|
|
|
Args: |
|
|
|
topo_idx (int): The topological index of the node required extra input. |
|
|
|
""" |
|
|
|
return self.extra_input_dict.get(topo_idx) |
|
|
|
|
|
|
|
def add_extra_input(self, target_topo_idx, opt_var_name): |
|
|
|
""" |
|
|
|
Add the extra input(s) required for the target node. |
|
|
|
|
|
|
|
Args: |
|
|
|
target_topo_idx (int): The index of node which requires the input. |
|
|
|
opt_var_name (Union[str, list]): The output(s) name the target node will use. |
|
|
|
""" |
|
|
|
if isinstance(opt_var_name, str): |
|
|
|
opt_var_name = [opt_var_name] |
|
|
|
if isinstance(opt_var_name, list): |
|
|
|
self.extra_input_dict[target_topo_idx] = opt_var_name |
|
|
|
else: |
|
|
|
raise TypeError("Global Context does not support the type {} of opt_var_name.".format(type(opt_var_name))) |
|
|
|
|
|
|
|
def get_module_customized_name(self, pattern_id) -> str: |
|
|
|
""" |
|
|
|
Get the customized name of the module with pattern id provied. |
|
|
|
|
|
|
|
Args: |
|
|
|
pattern_id (int): The pattern the module belongs to. |
|
|
|
|
|
|
|
Returns, |
|
|
|
str, the customized name of the module. |
|
|
|
""" |
|
|
|
return self.module_customized_name.get(pattern_id) |
|
|
|
|
|
|
|
def set_module_customized_name(self, pattern_id, customized_name): |
|
|
|
""" |
|
|
|
Set the customized name of the module with pattern id provided. |
|
|
|
|
|
|
|
Args: |
|
|
|
pattern_id (int): The pattern id the module has. |
|
|
|
customized_name (str): The customized name of the module. |
|
|
|
""" |
|
|
|
self.module_customized_name[pattern_id] = customized_name |
|
|
|
|
|
|
|
def get_node_fragment(self, identifier): |
|
|
|
return self.node_fragments.get(identifier) |
|
|
|
|
|
|
|
def add_code_fragment(self, identifier, frag): |
|
|
|
self.node_fragments[identifier] = frag |
|
|
|
|
|
|
|
def get_module_fragment(self, identifier): |
|
|
|
return self.module_fragments.get(identifier) |
|
|
|
|
|
|
|
def add_module_fragment(self, identifier, frag): |
|
|
|
self.module_fragments[identifier] = frag |
|
|
|
|
|
|
|
def add_module_struct(self, pattern_id, module_struct): |
|
|
|
if self.module_structs.get(pattern_id) is None: |
|
|
|
self.module_structs[pattern_id] = [module_struct] |
|
|
|
else: |
|
|
|
self.module_structs[pattern_id].append(module_struct) |