From: @liangtianshu Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -40,8 +40,10 @@ class GlobalContext(metaclass=Singleton): | |||||
| # Define data stored from onnx_utils | # Define data stored from onnx_utils | ||||
| # Key as Onnx Name | # Key as Onnx Name | ||||
| self._onnx_nodes_collection = OrderedDict() | self._onnx_nodes_collection = OrderedDict() | ||||
| # key is topo_idx, value is onnx_node_name. | |||||
| # key is topo_idx, value is onnx_node_name | |||||
| self._onnx_nodes_topo_index = dict() | self._onnx_nodes_topo_index = dict() | ||||
| self.onnx_node_name_to_topo_idx = dict() | |||||
| self.onnx_node_inputs = dict() | |||||
| self._onnx_tensors_collection = dict() | self._onnx_tensors_collection = dict() | ||||
| # Define data stored from generator | # Define data stored from generator | ||||
| @@ -50,7 +52,7 @@ class GlobalContext(metaclass=Singleton): | |||||
| self.node_struct_adder_counter = 0 | self.node_struct_adder_counter = 0 | ||||
| # Define onnx_utils <---> generator mapping | # Define onnx_utils <---> generator mapping | ||||
| self.node_struct_to_onnx_node_map = dict() | self.node_struct_to_onnx_node_map = dict() | ||||
| self.onnx_node_to_node_struct_map = dict() | |||||
| self.onnx_node_name_to_node_struct_map = dict() | |||||
| # Define Module pattern to customize name mapping | # Define Module pattern to customize name mapping | ||||
| self.module_customized_name = dict() | self.module_customized_name = dict() | ||||
| @@ -59,6 +61,8 @@ class GlobalContext(metaclass=Singleton): | |||||
| self.node_fragments = OrderedDict() | self.node_fragments = OrderedDict() | ||||
| self.module_fragments = OrderedDict() | self.module_fragments = OrderedDict() | ||||
| # Define Known module mapping | |||||
| self.known_module_name = dict() | |||||
| # Define Structs | # Define Structs | ||||
| # key is pattern_id, value is [ModuleStructs] | # key is pattern_id, value is [ModuleStructs] | ||||
| self.module_structs = dict() | self.module_structs = dict() | ||||
| @@ -83,7 +87,7 @@ class GlobalContext(metaclass=Singleton): | |||||
| def get_identifier_from_onnx_node_name(self, node_name): | def get_identifier_from_onnx_node_name(self, node_name): | ||||
| """Return the node identifier by Onnx Node name.""" | """Return the node identifier by Onnx Node name.""" | ||||
| identifier = self.onnx_node_to_node_struct_map.get(node_name) | |||||
| identifier = self.onnx_node_name_to_node_struct_map.get(node_name) | |||||
| return identifier | return identifier | ||||
| @property | @property | ||||
| @@ -98,9 +102,7 @@ class GlobalContext(metaclass=Singleton): | |||||
| @onnx_nodes_collection.setter | @onnx_nodes_collection.setter | ||||
| def onnx_nodes_collection(self, arg): | def onnx_nodes_collection(self, arg): | ||||
| """ | |||||
| Set the onnx nodes collection. | |||||
| """ | |||||
| """Set the onnx nodes collection.""" | |||||
| if isinstance(arg, OrderedDict): | if isinstance(arg, OrderedDict): | ||||
| self._onnx_nodes_collection = arg # arg must be nodes_dict in OnnxDataLoader | self._onnx_nodes_collection = arg # arg must be nodes_dict in OnnxDataLoader | ||||
| else: | else: | ||||
| @@ -108,11 +110,18 @@ class GlobalContext(metaclass=Singleton): | |||||
| @property | @property | ||||
| def onnx_nodes_topo_index(self) -> dict: | def onnx_nodes_topo_index(self) -> dict: | ||||
| "Return the onnx nodes and topological index." | |||||
| """Return the onnx nodes and topological index.""" | |||||
| return self._onnx_nodes_topo_index | return self._onnx_nodes_topo_index | ||||
| @onnx_nodes_topo_index.setter | @onnx_nodes_topo_index.setter | ||||
| def onnx_nodes_topo_index(self, index_list): | def onnx_nodes_topo_index(self, index_list): | ||||
| """ | |||||
| Set the onnx nodes and topological index. | |||||
| Args: | |||||
| index_list (list[tuple[int, str]]): a list of tuple contains the topological index and onnx node name. | |||||
| """ | |||||
| if not isinstance(index_list, list): | if not isinstance(index_list, list): | ||||
| raise TypeError("The argument index_list must be a list of tuple (index, onnx_node_name).") | raise TypeError("The argument index_list must be a list of tuple (index, onnx_node_name).") | ||||
| if not isinstance(index_list[0], tuple): | if not isinstance(index_list[0], tuple): | ||||
| @@ -122,10 +131,17 @@ class GlobalContext(metaclass=Singleton): | |||||
| @property | @property | ||||
| def onnx_tensors_collection(self): | def onnx_tensors_collection(self): | ||||
| """Return the onnx tensors collection.""" | |||||
| return self.onnx_tensors_collection | return self.onnx_tensors_collection | ||||
| @onnx_tensors_collection.setter | @onnx_tensors_collection.setter | ||||
| def onnx_tensors_collection(self, arg): | def onnx_tensors_collection(self, arg): | ||||
| """ | |||||
| Set the onnx tensors collection by OnnxDataLoader. | |||||
| Args: | |||||
| arg (dict): The OnnxDataLoader generated tensors_dict. | |||||
| """ | |||||
| if isinstance(arg, dict): | if isinstance(arg, dict): | ||||
| self._onnx_tensors_collection = arg # arg must be tensors_dict in OnnxDataLoader | self._onnx_tensors_collection = arg # arg must be tensors_dict in OnnxDataLoader | ||||
| else: | else: | ||||
| @@ -133,6 +149,12 @@ class GlobalContext(metaclass=Singleton): | |||||
| @property | @property | ||||
| def latest_node_struct_count(self): | def latest_node_struct_count(self): | ||||
| """ | |||||
| Return the latest node struct count. | |||||
| Note: | |||||
| The counter will increase by 1 to tracking the number of nodes added. | |||||
| """ | |||||
| ret = self.node_struct_adder_counter | ret = self.node_struct_adder_counter | ||||
| self.node_struct_adder_counter += 1 | self.node_struct_adder_counter += 1 | ||||
| return ret | return ret | ||||
| @@ -184,18 +206,29 @@ class GlobalContext(metaclass=Singleton): | |||||
| self.module_customized_name[pattern_id] = customized_name | self.module_customized_name[pattern_id] = customized_name | ||||
| def get_node_fragment(self, identifier): | def get_node_fragment(self, identifier): | ||||
| """Return the node fragment by identifier.""" | |||||
| return self.node_fragments.get(identifier) | return self.node_fragments.get(identifier) | ||||
| def add_code_fragment(self, identifier, frag): | def add_code_fragment(self, identifier, frag): | ||||
| """Add the node fragment by identifier.""" | |||||
| self.node_fragments[identifier] = frag | self.node_fragments[identifier] = frag | ||||
| def get_module_fragment(self, identifier): | def get_module_fragment(self, identifier): | ||||
| """Return the module fragment by identifier.""" | |||||
| return self.module_fragments.get(identifier) | return self.module_fragments.get(identifier) | ||||
| def add_module_fragment(self, identifier, frag): | def add_module_fragment(self, identifier, frag): | ||||
| """Add the module fragment by identifier.""" | |||||
| self.module_fragments[identifier] = frag | self.module_fragments[identifier] = frag | ||||
| def add_module_struct(self, pattern_id, module_struct): | def add_module_struct(self, pattern_id, module_struct): | ||||
| """ | |||||
| Add module struct by its pattern_id. | |||||
| Args: | |||||
| pattern_id (int): The pattern which represents the structure of the module. | |||||
| module_struct (ModuleStruct): The ModuleStruct instance. | |||||
| """ | |||||
| if self.module_structs.get(pattern_id) is None: | if self.module_structs.get(pattern_id) is None: | ||||
| self.module_structs[pattern_id] = [module_struct] | self.module_structs[pattern_id] = [module_struct] | ||||
| else: | else: | ||||
| @@ -135,3 +135,18 @@ def lib_version_satisfied(current_ver: str, mini_ver_limited: str, | |||||
| if current_ver < mini_ver_limited or (newest_ver_limited and current_ver > newest_ver_limited): | if current_ver < mini_ver_limited or (newest_ver_limited and current_ver > newest_ver_limited): | ||||
| return False | return False | ||||
| return True | return True | ||||
| def get_dict_key_by_value(val, dic): | |||||
| """ | |||||
| Return the first appeared key of a dictionay by given value. | |||||
| Args: | |||||
| val (Any): Value of the key. | |||||
| dic (dict): Dictionary to be checked. | |||||
| Returns: | |||||
| Any, key of the given value. | |||||
| """ | |||||
| for d_key, d_val in dic.items(): | |||||
| if d_val == val: | |||||
| return d_key | |||||
| return None | |||||
| @@ -0,0 +1,111 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """Generator module.""" | |||||
| __all__ = ["batch_add_nodes"] | |||||
| import re | |||||
| import copy | |||||
| from .generator import Generator, CodeStruct | |||||
| from ..common.code_fragment import CodeFragment | |||||
| def _tf_model_node_name_reformat(node, node_name): | |||||
| """ | |||||
| Rename the node name by combining scope name and its original name. | |||||
| Args: | |||||
| node (OnnxGraphNode): OnnxGraphNode instance. | |||||
| node_name (str): node name saved in Graph. | |||||
| Returns: | |||||
| str, re-formatted node name. | |||||
| """ | |||||
| scope_name = node.scope_name | |||||
| new_name = None | |||||
| regex = r"(?P<parent>.+/)(?P<op>\w+)" | |||||
| match = re.match(regex, scope_name) | |||||
| parent = match.group("parent") | |||||
| node_name = '$' + node_name.replace('/', '::') + '$' | |||||
| if scope_name: | |||||
| new_name = parent + node_name | |||||
| return new_name | |||||
| return node_name | |||||
| def batch_add_nodes(graph_obj, mapper) -> Generator: | |||||
| """ | |||||
| Add nodes to Generator in batch mode. | |||||
| Args: | |||||
| graph_obj (Graph): Graph obj. | |||||
| mapper (Mapper): Mapper of third party framework and MindSpore. | |||||
| """ | |||||
| generator_inst = Generator() | |||||
| for node_name in graph_obj.nodes_in_topological_order: | |||||
| node_inst = graph_obj.get_node(node_name) | |||||
| node_input = graph_obj.get_input_shape(node_name) | |||||
| node_output = graph_obj.get_output_shape(node_name) | |||||
| if not node_input: | |||||
| raise ValueError("Unable to get the node's inputs from Graph object.") | |||||
| node_name_with_scope = _tf_model_node_name_reformat(node_inst, node_name) | |||||
| node_name = node_name_with_scope | |||||
| node_inst.add_input_and_output_shape(node_input, node_output) | |||||
| op_name, params, settings, weights = _convert_params(node_inst, mapper) | |||||
| generator_inst.add_node( | |||||
| node_name, | |||||
| node_instance=node_inst, | |||||
| node_fragment=CodeFragment(op_name, params, | |||||
| settings, | |||||
| node_inst.input_shape, | |||||
| node_inst.output_shape, | |||||
| weights) | |||||
| ) | |||||
| return generator_inst | |||||
| def _convert_params(node, mapper): | |||||
| """ | |||||
| Call mapper to convert node's params from ONNX to MindSpore. | |||||
| Args: | |||||
| node (GraphNode): Our defined GraphNode instance. | |||||
| mapper (Mapper): The mapper instance which indicating conversion method. | |||||
| Returns: | |||||
| str, op name in MindSpore | |||||
| dict, MindSpore parameters | |||||
| dict, MindSpore settings | |||||
| dict, weights of the node | |||||
| """ | |||||
| params = copy.deepcopy(node.node_params) | |||||
| params.update({"input_shape": node.input_shape, | |||||
| "output_shape": node.output_shape}) | |||||
| op_in_ms, ms_params, ms_settings, weights = mapper.convert(op_name=node.op_name, | |||||
| params=params, | |||||
| weights=node.weight) | |||||
| if "input_shape" in ms_params: | |||||
| ms_params.pop("input_shape") | |||||
| if "output_shape" in ms_params: | |||||
| ms_params.pop("output_shape") | |||||
| if op_in_ms: | |||||
| return op_in_ms, ms_params, ms_settings, weights | |||||
| return node.op_name, node.node_params, dict(), dict() | |||||
| @@ -0,0 +1,248 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """Define arguments translation related operations for params name changing.""" | |||||
| import copy | |||||
| class ArgsTranslation: | |||||
| """Define a universal arguments translation manager.""" | |||||
| def __init__(self, original_actual_args: dict, var_name: str, translated_args: list): | |||||
| """ | |||||
| Init the ArgsTranslation. | |||||
| Args: | |||||
| original_actual_args (dict): The full original args from fragments. | |||||
| var_name (str): The var name for current Node / Module. | |||||
| translated_args (list): The list of args need to translate to formal args. | |||||
| """ | |||||
| if not var_name: | |||||
| raise ValueError("Initialize ArgsTranslation requires the var_name.") | |||||
| self.var_name = var_name | |||||
| self.actual_args = dict() # e.g. key is 'num_features', value is 2048 | |||||
| self.formal_args = dict() # e.g. key is 'num_features', value is 'var_name_num_features'} | |||||
| self.formal_args_values = dict() # e.g. key 'var_name_num_features', value 2048. Value use for up-level | |||||
| self.actual_args_backup = dict() # backup actual args before translation | |||||
| self.actual_args_to_str_list = list() | |||||
| self.formal_args_to_str_list = list() | |||||
| self.formal_args_values_to_str_list = list() | |||||
| self.actual_args_backup_to_str_list = list() | |||||
| if all([original_actual_args, translated_args]): | |||||
| # MUST ensure only one var_name in a scope. | |||||
| for arg_name, arg_value in original_actual_args.items(): | |||||
| if arg_name in translated_args: | |||||
| formal_arg_name = '_'.join([var_name, arg_name]) | |||||
| self.formal_args[arg_name] = formal_arg_name | |||||
| self.formal_args_values[formal_arg_name] = arg_value | |||||
| else: | |||||
| self.actual_args[arg_name] = arg_value | |||||
| self.make_str() | |||||
| @staticmethod | |||||
| def dict_data_to_args_str_list(any_dict): | |||||
| """ | |||||
| Output a list of string of dict data by "key=value" format. | |||||
| Args: | |||||
| any_dict (dict): Any dictionary | |||||
| Returns: | |||||
| list, the list of strings showing dictionary as "key=value" format. | |||||
| """ | |||||
| ret = [] | |||||
| for key, val in any_dict.items(): | |||||
| ret.append('='.join([key, str(val)])) | |||||
| return ret | |||||
| def make_str(self): | |||||
| """Make string used in code generation.""" | |||||
| self.actual_args_to_str_list = list() | |||||
| self.formal_args_to_str_list = list() | |||||
| self.formal_args_values_to_str_list = list() | |||||
| self.actual_args_backup_to_str_list = list() | |||||
| if self.actual_args: | |||||
| self.actual_args_to_str_list = ArgsTranslation.dict_data_to_args_str_list(self.actual_args) | |||||
| if self.formal_args: | |||||
| self.formal_args_to_str_list = ArgsTranslation.dict_data_to_args_str_list(self.formal_args) | |||||
| if self.formal_args_values: | |||||
| self.formal_args_values_to_str_list = ArgsTranslation.dict_data_to_args_str_list(self.formal_args_values) | |||||
| if self.actual_args_backup: | |||||
| self.actual_args_backup_to_str_list = ArgsTranslation.dict_data_to_args_str_list(self.actual_args_backup) | |||||
| def __repr__(self): | |||||
| return str({ | |||||
| "address": hex(id(self)), | |||||
| "var_name": self.var_name, | |||||
| "actual_args": self.actual_args, | |||||
| "actual_bak": self.actual_args_backup, | |||||
| "formal_args": self.formal_args, | |||||
| "formal_val ": self.formal_args_values | |||||
| }) | |||||
| def set_actual_args_backup(self): | |||||
| """Backup the actual args before translating to formal.""" | |||||
| self.actual_args_backup = copy.deepcopy(self.actual_args) | |||||
| def deepcopy(self): | |||||
| """Return a deepcopy of self.""" | |||||
| return copy.deepcopy(self) | |||||
| def make_actual_arg_to_formal(self, actual_arg_name): | |||||
| """ | |||||
| Make the actual arg to a formal arg. | |||||
| Args: | |||||
| actual_arg_name (str): The name of the actual arg to be formal. | |||||
| """ | |||||
| val = self.actual_args.get(actual_arg_name) | |||||
| if val is None: | |||||
| raise ValueError("Unable to convert the actual arg to formal due to missing arg.") | |||||
| formal_arg_name = ('_').join([self.var_name, actual_arg_name]) | |||||
| self.actual_args.pop(actual_arg_name) | |||||
| self.formal_args[actual_arg_name] = formal_arg_name | |||||
| self.formal_args_values[formal_arg_name] = val | |||||
| self.make_str() | |||||
| def _update_dict_for_upper_level(self, d, upper_level_var_name): | |||||
| """Add upper level var name to key name of selected dictionary.""" | |||||
| new_d = dict() | |||||
| for arg_name, val in d.items(): | |||||
| new_arg_name = '_'.join([upper_level_var_name, arg_name]) # e.g. conv2d_0_in_channels_Module_3_0 | |||||
| new_d[new_arg_name] = val | |||||
| return new_d | |||||
| def escalate_to_upper_level(self, upper_level_var_name): | |||||
| """ | |||||
| Escalate this args translator for upper level module use. | |||||
| Note: | |||||
| You MUST deepcopy this translator first to avoid editing values in the original translator. | |||||
| """ | |||||
| # update all args name by adding upper_level_var_name. | |||||
| tmp_actual_args = self._update_dict_for_upper_level(self.actual_args, upper_level_var_name) | |||||
| tmp_formal_args = self._update_dict_for_upper_level(self.formal_args, upper_level_var_name) | |||||
| tmp_formal_args_values = self._update_dict_for_upper_level(self.formal_args_values, upper_level_var_name) | |||||
| self.actual_args = tmp_actual_args | |||||
| self.formal_args = tmp_formal_args | |||||
| self.formal_args_values = tmp_formal_args_values | |||||
| self.make_str() | |||||
| def make_formal_args_back_to_actual(self, formal_arg): | |||||
| """ | |||||
| Move the formal arg back to actual arg. | |||||
| Note: | |||||
| This does not reset the formal arg name back, | |||||
| Only used for module init statement. | |||||
| Args: | |||||
| formal_arg (str): formal argument name. | |||||
| """ | |||||
| if isinstance(formal_arg, str): | |||||
| val = self.formal_args_values.pop(formal_arg) | |||||
| self.actual_args[formal_arg] = val | |||||
| if isinstance(formal_arg, list): | |||||
| for arg in formal_arg: | |||||
| val = self.formal_args_values.pop(arg) | |||||
| self.actual_args[formal_arg] = val | |||||
| self.make_str() | |||||
| def take_formal_args_from_args_translator(self, args_translator, escalate_sub=False): | |||||
| """ | |||||
| Add submodule's or node's args translator to this translator. | |||||
| Args: | |||||
| args_translator (ArgsTranslation): submodule's or node's args translator. | |||||
| """ | |||||
| if escalate_sub: | |||||
| sub_args_translator = args_translator.deepcopy() | |||||
| sub_args_translator.escalate_to_upper_level(self.var_name) | |||||
| else: | |||||
| sub_args_translator = args_translator | |||||
| original_actual_args = sub_args_translator.formal_args_values | |||||
| self.actual_args.update(original_actual_args) | |||||
| self.make_str() | |||||
| def take_formal_args_from_nodes_and_submodules(self, args_translators: list, escalate_sub=False): | |||||
| """ | |||||
| Take all formal args from nodes and submodules from passed in args_translators. | |||||
| Args: | |||||
| args_translators (ArgsTranslation): A list of ArgsTranslation instances. | |||||
| escalate_sub (Bool): should escalate all formal args. Default: False | |||||
| """ | |||||
| for arg_t in args_translators: | |||||
| self.take_formal_args_from_args_translator(arg_t, escalate_sub=escalate_sub) | |||||
| class ArgsTranslationHelper: | |||||
| """Define operations related to ArgsTranslation instances.""" | |||||
| @staticmethod | |||||
| def find_formal_args_in_modules(args_translators): | |||||
| """ | |||||
| Find formal args among multiple args translators. | |||||
| Args: | |||||
| args_translators(list[ArgsTranslation]): a list of args translator to be checked. | |||||
| Returns: | |||||
| list, name of args to be formal. | |||||
| """ | |||||
| if len(args_translators) < 2: | |||||
| # only one args_translator provided, no formal args. | |||||
| return None | |||||
| ret = [] | |||||
| base_args_t = args_translators[0] | |||||
| for arg_name, arg_val in base_args_t.actual_args.items(): | |||||
| for args_t in args_translators[1:]: | |||||
| val = args_t.actual_args.get(arg_name) | |||||
| if val is None: | |||||
| raise ValueError("Unable to find the given args as the args translator is not consistent.") | |||||
| if val != arg_val: # val not equal | |||||
| ret.append(arg_name) | |||||
| break | |||||
| return ret | |||||
| @staticmethod | |||||
| def change_args_to_formal_for_all_translators(args_name, args_translators): | |||||
| """ | |||||
| Change args to formal for all translators provided. | |||||
| Args: | |||||
| args_name (str): The name of args to be changing. | |||||
| args_translators (ArgsTranslation): The args to be changed in args translators. | |||||
| """ | |||||
| if isinstance(args_name, str): | |||||
| args_name = [args_name] | |||||
| if isinstance(args_translators, ArgsTranslation): | |||||
| args_translators = [args_translators] | |||||
| for arg in args_name: | |||||
| for args_t in args_translators: | |||||
| args_t.set_actual_args_backup() | |||||
| args_t.make_actual_arg_to_formal(arg) | |||||
| @@ -0,0 +1,630 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """Main Generator module.""" | |||||
| import copy | |||||
| from collections import OrderedDict | |||||
| from .scope_utils import Scope | |||||
| from .node_struct import NodeStruct | |||||
| from .module_struct import ModuleStruct | |||||
| from .args_translator import ArgsTranslationHelper | |||||
| from ..common.global_context import GlobalContext | |||||
| from ..hierarchical_tree.name_mgr import GlobalVarNameMgr | |||||
| from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT | |||||
| class Singleton(type): | |||||
| """Metaclass to make the generator to be single instance.""" | |||||
| _instances = {} | |||||
| def __call__(cls, *args, **kwargs): | |||||
| if cls not in cls._instances: | |||||
| cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) | |||||
| return cls._instances[cls] | |||||
| class CodeStruct: | |||||
| """ | |||||
| Define the Code template for each module generated in the final output. | |||||
| Each module has only one CodeStruct to its pattern. | |||||
| """ | |||||
| GLOBAL_CONTEXT = GlobalContext() | |||||
| NOT_IN_SCOPE_OPT = dict() | |||||
| def __init__(self, struct, repeated_submodules=None): | |||||
| """Initialize the CodeStruct.""" | |||||
| self.output_order = None # output order | |||||
| self.input = None # opt_var_name for prev. node | |||||
| self.extra_input = list() # extra_input(s) at construct method args | |||||
| self.output = None # opt_var_name for next node | |||||
| self.extra_output = list() # extra_output(s) | |||||
| self.extra_comment = None # comments for this code line / block. | |||||
| self.code_line_list = list() # list of code line, a item is a line. | |||||
| self._global_var_mgr = GlobalVarNameMgr() # var name procs within same module | |||||
| self.formal_args_collections = None | |||||
| if isinstance(struct, NodeStruct): | |||||
| self.output_order = struct.topo_idx | |||||
| if isinstance(struct, ModuleStruct): | |||||
| self.output_order = struct.head_nd_struct_index | |||||
| self._generate_from_module_struct(struct, repeated_submodules) | |||||
| def _add_line(self, s): | |||||
| """Add line of code.""" | |||||
| self.code_line_list.append(s) | |||||
| @property | |||||
| def new_line(self): | |||||
| """Return last generated line.""" | |||||
| try: | |||||
| return self.code_line_list[-1] | |||||
| except IndexError: | |||||
| return "" | |||||
| @new_line.setter | |||||
| def new_line(self, s): | |||||
| """Make a new line.""" | |||||
| self._add_line(s) | |||||
| def _generate_from_module_struct(self, md_struct, repeated_submodules): | |||||
| """ | |||||
| Generate the code of current Module Struct, collecting data from submodules. | |||||
| Args: | |||||
| md_struct (ModuleStruct): The ModuleStruct which generates codes. | |||||
| repeated_submodules (dict): The dict contains all submodules which use repeatedly. | |||||
| Can get this dict from generator. | |||||
| """ | |||||
| # Define tmp var for code generation. | |||||
| opt_var_name_records = dict() # now only support multiple outputs within same scope. | |||||
| return_value_records = dict() # save returned values for successor nodes/modules use. | |||||
| # Define Module header code line below | |||||
| if md_struct.pattern_id != -1: | |||||
| class_name = f"Module{md_struct.pattern_id}" | |||||
| else: | |||||
| class_name = "Model" | |||||
| # define a class declaration | |||||
| self.new_line = f"class {class_name}(nn.Cell):" | |||||
| # Get all formal args from nodes | |||||
| module_def_args = ['self'] | |||||
| if md_struct.args_translator.actual_args: | |||||
| for actual in md_struct.args_translator.actual_args.keys(): | |||||
| module_def_args.append(actual) | |||||
| if md_struct.args_translator.formal_args: | |||||
| for formal in md_struct.args_translator.formal_args.keys(): | |||||
| module_def_args.append(formal) | |||||
| # Collect extra inputs and outputs | |||||
| # For code line in init & construct blocks | |||||
| init_lines = list() | |||||
| cons_lines = list() | |||||
| for (_, struct) in md_struct.get_generate_order(): | |||||
| if isinstance(struct, NodeStruct): # Generate code line for Node. | |||||
| code_line_init = struct.code_line_in_init() | |||||
| code_line_construct = struct.code_line_in_construct(in_module_returns=return_value_records) | |||||
| init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_init)}") | |||||
| cons_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_construct)}") | |||||
| # add extra tensor | |||||
| if struct.fragment.code_setting and struct.fragment.code_setting.op_extra_tensor: | |||||
| code_extra_tensor = struct.add_extra_tensor() | |||||
| init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_extra_tensor)}") | |||||
| # record opt_var_name for succ nodes input in same scope. | |||||
| target_onnx_name = struct.graph_node_ref.successor_nodes | |||||
| for name in target_onnx_name: | |||||
| if opt_var_name_records.get(name): | |||||
| opt_var_name_records.get(name).append(code_line_construct[0]) | |||||
| else: | |||||
| opt_var_name_records[name] = [code_line_construct[0]] | |||||
| if struct.successor_nodes_names_external: | |||||
| for ret_user in struct.successor_nodes_names_external: | |||||
| if return_value_records.get(ret_user) is not None: | |||||
| return_value_records[ret_user].append((struct.onnx_name, code_line_construct[0])) | |||||
| else: | |||||
| return_value_records[ret_user] = [(struct.onnx_name, code_line_construct[0])] | |||||
| elif isinstance(struct, ModuleStruct): | |||||
| # check if this instance generated CodeStruct | |||||
| if self.GLOBAL_CONTEXT.code_structs.get(struct.pattern_id) is None: | |||||
| CodeStruct(struct, repeated_submodules) | |||||
| code_line_init = struct.code_line_in_init() | |||||
| code_line_construct = struct.code_line_in_construct(inputs=struct.matched_inputs) | |||||
| init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_init)}") | |||||
| cons_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_construct)}") | |||||
| # record opt_var_name for succ nodes input in same scope. | |||||
| target_onnx_name = struct.tail_nd_struct.graph_node_ref.successor_nodes | |||||
| for name in target_onnx_name: | |||||
| if opt_var_name_records.get(name): | |||||
| opt_var_name_records.get(name).append(code_line_construct[0]) | |||||
| else: | |||||
| opt_var_name_records[name] = [code_line_construct[0]] | |||||
| # record submodule's local return map for following nodes / submodules use | |||||
| if struct.external_successor_local_returns_map: | |||||
| for ret_user, _ in struct.external_successor_local_returns_map.items(): | |||||
| if return_value_records.get(ret_user) is not None: | |||||
| # mulitple returns of a node may need modifiy the index. | |||||
| return_value_records[ret_user].append((struct.identifier, code_line_construct[0])) | |||||
| else: | |||||
| return_value_records[ret_user] = [(struct.identifier, code_line_construct[0])] | |||||
| else: | |||||
| raise TypeError("Unable to generate code from args are not ModuleStruct or NodeStruct.") | |||||
| # define header of init block | |||||
| self.new_line = f"{FIRST_LEVEL_INDENT}def __init__({', '.join(module_def_args)}):" | |||||
| self.new_line = f"{SECOND_LEVEL_INDENT}super({class_name}, self).__init__()" | |||||
| # add init code lines to code line list. | |||||
| self.code_line_list += init_lines | |||||
| self.new_line = f"{NEW_LINE * 2}" | |||||
| # define header of construct block | |||||
| inputs = ['self'] + list(md_struct.construct_header_x.keys()) | |||||
| self.new_line = f"{FIRST_LEVEL_INDENT}def construct({', '.join(inputs)}):" | |||||
| # add construct code lines to code line list. | |||||
| self.code_line_list += cons_lines | |||||
| # define returns | |||||
| returns = [] | |||||
| if md_struct.external_successor_local_returns_map: | |||||
| ret = list(md_struct.external_successor_local_returns_map.values()) | |||||
| for r in ret: | |||||
| if isinstance(r, tuple): # results return with index nth output | |||||
| returns.append(r[0]) | |||||
| else: | |||||
| returns.append(r) | |||||
| returns = list(set(returns)) | |||||
| else: | |||||
| returns = [code_line_construct[0]] | |||||
| self.new_line = f"{SECOND_LEVEL_INDENT}return {', '.join(returns)}" | |||||
| self.new_line = f"{NEW_LINE * 2}" | |||||
| self.GLOBAL_CONTEXT.code_structs[md_struct.pattern_id] = self | |||||
| class Generator(metaclass=Singleton): | |||||
| """The generator controls all routines of code generation.""" | |||||
| def __init__(self): | |||||
| """Init the generator.""" | |||||
| # define basic attributes | |||||
| self.framework = None | |||||
| # define MUST have params | |||||
| self._node_struct_collections = OrderedDict() | |||||
| self._module_struct_collections = OrderedDict() | |||||
| self._module_depth_max = 0 | |||||
| self._module_depth_min = 0 | |||||
| # define intermediate var. during conversion | |||||
| self._module_map = OrderedDict() | |||||
| self._global_context = GlobalContext() | |||||
| self._global_context.node_struct_collections = self._node_struct_collections | |||||
| self._repeated_submodules = set() | |||||
| def _form_bottom_submodule(self): | |||||
| """Form the basic submodules, which only contains nodes.""" | |||||
| # Form module map | |||||
| curr_scope_path = None | |||||
| nd_struct_list_in_submodule = [] | |||||
| for nd_struct in self.node_structs.values(): | |||||
| idx = nd_struct.topo_idx | |||||
| if curr_scope_path is None: | |||||
| curr_scope_path = nd_struct.scope.path | |||||
| nd_struct_list_in_submodule.append((idx, nd_struct)) | |||||
| elif curr_scope_path == nd_struct.scope.path: | |||||
| nd_struct_list_in_submodule.append((idx, nd_struct)) | |||||
| else: # curr_scope_path changed | |||||
| # save this submodule | |||||
| if self._module_map.get(str(curr_scope_path)) is not None: | |||||
| self._module_map[str(curr_scope_path)] += nd_struct_list_in_submodule | |||||
| else: | |||||
| self._module_map[str(curr_scope_path)] = nd_struct_list_in_submodule | |||||
| # create a new one | |||||
| curr_scope_path = nd_struct.scope.path | |||||
| nd_struct_list_in_submodule = [(idx, nd_struct)] | |||||
| # save last submodule | |||||
| if self._module_map.get(str(curr_scope_path)) is not None: | |||||
| self._module_map[str(curr_scope_path)] += nd_struct_list_in_submodule | |||||
| else: | |||||
| self._module_map[str(curr_scope_path)] = nd_struct_list_in_submodule | |||||
| # Form bottom modules' ModuleStruct | |||||
| for scope_path_str, nd_struct_list in self._module_map.items(): | |||||
| self._module_struct_collections[scope_path_str] = ModuleStruct(nd_struct_list) | |||||
| def _list_repeated_submodules(self) -> OrderedDict: | |||||
| """ | |||||
| Return the repeated submodules by its depth and num. | |||||
| For example, "Model/Module3_3" will return {1:(3)} | |||||
| Return: | |||||
| OrderedDict, a dict contains collections of repeated submodules. | |||||
| """ | |||||
| ret = OrderedDict() | |||||
| for depth_control in range(self._module_depth_max, 0, -1): | |||||
| repeated_submodules_at_this_depth = set() | |||||
| for scope_path in self._module_map.keys(): | |||||
| path = Scope.path_str_to_list(scope_path) | |||||
| if len(path) < depth_control: | |||||
| continue | |||||
| else: # depth control within path length | |||||
| module_num = path[depth_control - 1][0] | |||||
| repeated_submodules_at_this_depth.add(module_num) | |||||
| ret[depth_control] = repeated_submodules_at_this_depth | |||||
| self._repeated_submodules = ret | |||||
| return ret | |||||
| def _compare_with_base_parameters(self, nd_struct_list): | |||||
| """ | |||||
| Compare the parameter to check if it should be a formal args. | |||||
| Args: | |||||
| nd_struct_list (list): A list of NodeStructs which contains | |||||
| all same nodes in repeated submodules. | |||||
| Return: | |||||
| set, a set of all formal args in this node. | |||||
| """ | |||||
| formal_args = set() | |||||
| if len(nd_struct_list) < 2: | |||||
| return formal_args | |||||
| (_, base_nd_struct) = nd_struct_list[0] | |||||
| for (base_parameter, base_value) in base_nd_struct.fragment.actual_args.items(): # for each param | |||||
| for (_, nd_struct) in nd_struct_list[1:]: | |||||
| compared_value = nd_struct.fragment.actual_args.get(base_parameter) | |||||
| if compared_value == base_value: | |||||
| continue | |||||
| else: | |||||
| formal_args.add(base_parameter) | |||||
| break | |||||
| return formal_args | |||||
| def _list_formal_parameters_in_a_module(self, module_filter_return): | |||||
| """ | |||||
| Find all formal args / params from nodes in a module. | |||||
| Args: | |||||
| module_filter_return (dict): The filtered results from the module_map_filter. | |||||
| Return: | |||||
| list, a list of sets or None indicates all formal args of each node in the module in order. | |||||
| """ | |||||
| formal_params_list = list() | |||||
| transposed = [list(e) for e in zip(*module_filter_return)] | |||||
| for operation in transposed: | |||||
| formal_parameters = self._compare_with_base_parameters(operation) | |||||
| if formal_parameters: | |||||
| formal_params_list.append(formal_parameters) | |||||
| else: | |||||
| formal_params_list.append(None) | |||||
| return formal_params_list | |||||
| def _list_formal_parameters(self, repeated_submodules) -> dict: | |||||
| """ | |||||
| Return a list of formal parameters in each submodule. | |||||
| Args: | |||||
| repeated_submodules (dict): A dict which contains repeated submodules, | |||||
| acquire this dict from _list_repeated_submodules() | |||||
| Return: | |||||
| OrderedDict, a dict with each submodule's formal args. | |||||
| Example: | |||||
| A return for ResNet50 could be: | |||||
| {0: # submoodule 0 | |||||
| [set('stride', 'in_channels', 'out_channels'), # args of the first node to be set as formal | |||||
| set('num_features'), # args of the second node to be set as formal | |||||
| None, # args of third node to be set as formal, which does not have | |||||
| set('in_channels', 'out_channels'), | |||||
| set('num_features'), | |||||
| None | |||||
| ]}, | |||||
| {3: # submodule 3 | |||||
| [...], | |||||
| {5: # submodule 5 | |||||
| []} # empty returns means no nodes or it's a parent module of submodules. | |||||
| } | |||||
| """ | |||||
| formal_args_in_each_submodule = OrderedDict() | |||||
| checked_module = set() | |||||
| # filter module_map by submodule_num (without depth) | |||||
| for _, module_nums in repeated_submodules.items(): | |||||
| for module_num in module_nums: | |||||
| if module_num in checked_module: # module already checked | |||||
| continue | |||||
| else: | |||||
| checked_module.add(module_num) | |||||
| map_filtered = self.module_map_filter(module_num=module_num) | |||||
| formal_args_in_this_module = self._list_formal_parameters_in_a_module(map_filtered) | |||||
| formal_args_in_each_submodule[module_num] = formal_args_in_this_module | |||||
| return formal_args_in_each_submodule | |||||
| def _add_submodule_to_parent(self): | |||||
| """ | |||||
| Recursively add all submodule to its parent module until Main module. | |||||
| Note: | |||||
| This function deepcopy the first node of the submodule, and reset its params as parent module. | |||||
| """ | |||||
| depth = self._module_depth_max | |||||
| while depth > 0: | |||||
| for (scope_path_str, md_struct) in self.module_structs.copy().items(): | |||||
| if scope_path_str == '[]': | |||||
| continue # is main module, skip | |||||
| if md_struct.scope_depth != depth: | |||||
| continue # skip all submodules not at current depth | |||||
| md_struct_scope = copy.deepcopy(md_struct.identifier) | |||||
| md_struct_scope.pop() | |||||
| parent_scope = md_struct_scope | |||||
| # 1. check if this module has parent module | |||||
| parent_md_struct = self.module_structs.get(str(parent_scope)) | |||||
| if parent_md_struct is not None: | |||||
| # 1A. has parent, directly add md_struct to its parent ModuleStruct. | |||||
| parent_md_struct.add_submodule(md_struct) | |||||
| self.module_structs[str(parent_scope)] = parent_md_struct | |||||
| else: | |||||
| # 1B. not has parent, generate a new ModuleStruct | |||||
| parent_md_struct = copy.deepcopy(md_struct) # use this submodule to create a parent module | |||||
| # rewrite parent md struct | |||||
| parent_md_struct.reset_as_parent() | |||||
| parent_md_struct.add_submodule(md_struct) | |||||
| self.module_structs[str(parent_scope)] = parent_md_struct | |||||
| sub = self.module_structs.pop(scope_path_str) # remove this submodule from collections | |||||
| self._global_context.add_module_struct(sub.pattern_id, sub) | |||||
| depth -= 1 | |||||
| def _recursive_form_module(self): | |||||
| """Main routine in generator to build modules from bottom to top.""" | |||||
| # 1. List repeated submodules | |||||
| repeated_submodules = self._list_repeated_submodules() | |||||
| # 2. List reused parameters | |||||
| formal_parameters = self._list_formal_parameters(repeated_submodules) | |||||
| # 3. Build base subdmodules and set in/ext params translation | |||||
| for module_struct in self.module_structs.values(): | |||||
| if module_struct.pattern_id == -1: # is main module | |||||
| continue | |||||
| formal_args = formal_parameters.get(module_struct.pattern_id) | |||||
| module_struct.update_args_translation_list(formal_args) | |||||
| # 4. Form parent modules | |||||
| md_collection_len = len(self.module_structs.keys()) | |||||
| len_changes = True | |||||
| while len_changes: | |||||
| self._add_submodule_to_parent() | |||||
| new_len = len(self.module_structs.keys()) | |||||
| if md_collection_len != new_len: | |||||
| md_collection_len = new_len | |||||
| else: | |||||
| len_changes = False | |||||
| # 5. Update all translated args from module map | |||||
| self._update_all_modules_args_translator() | |||||
| # 6. Update all nodes and moudles input/output | |||||
| self.module_structs.get('[]').allocate_construct_header_x() | |||||
| self.module_structs.get('[]').collect_returns() | |||||
| def _update_all_modules_args_translator(self): | |||||
| """Update all modules' args translators.""" | |||||
| done_submodule = set() | |||||
| for depth in range(self._module_depth_max, 0, -1): | |||||
| # check modules from bottom to top | |||||
| repeated_submodules = copy.deepcopy(self._repeated_submodules) | |||||
| repeated_modules = repeated_submodules.get(depth) | |||||
| if depth is None: | |||||
| continue | |||||
| for pattern_id in repeated_modules: | |||||
| if pattern_id in done_submodule: | |||||
| continue | |||||
| # get all md_structs by same pattern | |||||
| md_list = self._global_context.module_structs.get(pattern_id) | |||||
| self._take_formal_args_from_updated_submodules(md_list) | |||||
| args_translators = self.get_args_translator_from_module_structs_list(md_list) | |||||
| formal_args_list = ArgsTranslationHelper.find_formal_args_in_modules(args_translators) | |||||
| changed_args_translators = self.get_args_translator_from_module_structs_list( | |||||
| md_list, exclude_root_son=True) | |||||
| ArgsTranslationHelper.change_args_to_formal_for_all_translators( | |||||
| formal_args_list, changed_args_translators) | |||||
| done_submodule.add(pattern_id) | |||||
| def _take_formal_args_from_updated_submodules(self, md_list): | |||||
| """ | |||||
| Take formal args from provided modules' nodes and submodules. | |||||
| Args: | |||||
| md_list (list): A list of ModuleStruct. | |||||
| """ | |||||
| if isinstance(md_list, ModuleStruct): | |||||
| md_list = [md_list] | |||||
| for md in md_list: | |||||
| md.args_translator.take_formal_args_from_nodes_and_submodules(md.get_all_sub_translators()) | |||||
| def _update_module_depth_max(self, nd_struct: NodeStruct): | |||||
| """ | |||||
| Update the Generator attribute module_depth_max, which is the maximum depth in the Model. | |||||
| Args: | |||||
| nd_struct (NodeStruct): NodeStruct to be checked its depth. | |||||
| """ | |||||
| depth = nd_struct.scope.depth | |||||
| if isinstance(depth, int): | |||||
| if depth > self._module_depth_max: | |||||
| self._module_depth_max = depth | |||||
| else: | |||||
| raise TypeError("Unable to update global depth due to TypeError in NodeStruct.scope.depth") | |||||
| def add_node(self, node_identifier, node_instance=None, node_fragment=None, mapper_dict=None): | |||||
| """ | |||||
| Add Node information to the generator. | |||||
| Args: | |||||
| node_identifier (str): The unique identifier for the node passed in. | |||||
| node_instance (GraphNode): The GraphNode instance of each node. | |||||
| node_fragment (NodeFragment): The NodeFragment instance of this node passed in. | |||||
| mapper_dict (dict): The dict contains converted attributes from mapper. | |||||
| """ | |||||
| if node_identifier is None: | |||||
| raise ValueError("Node Identifier should not be None.") | |||||
| self._global_context.node_fragments[node_identifier] = node_fragment | |||||
| args = [] | |||||
| if node_instance is not None: | |||||
| args.append(node_instance) | |||||
| if mapper_dict is not None: | |||||
| args.append(mapper_dict) | |||||
| if node_fragment is not None: | |||||
| args.append(node_fragment) | |||||
| nd_struct = self.node_structs.get(node_identifier) | |||||
| if nd_struct: # NodeStruct already exists | |||||
| nd_struct.update(args) | |||||
| else: # create new Node Struct | |||||
| nd_struct = NodeStruct(args) | |||||
| nd_struct.identifier = node_identifier | |||||
| self._update_module_depth_max(nd_struct) | |||||
| self.node_structs[node_identifier] = nd_struct | |||||
| @property | |||||
| def node_structs(self): | |||||
| """Return all NodeStructs in this model.""" | |||||
| return self._node_struct_collections | |||||
| @property | |||||
| def module_structs(self): | |||||
| """Return all ModuleStructs in this model.""" | |||||
| return self._module_struct_collections | |||||
| def generate(self): | |||||
| """ | |||||
| Generate the final script file. | |||||
| Returns: | |||||
| list, a list of each line in script file. | |||||
| """ | |||||
| self._form_bottom_submodule() | |||||
| self._recursive_form_module() | |||||
| code = CodeStruct(self.module_structs.get('[]'), self._repeated_submodules) | |||||
| return code.code_line_list | |||||
| def get_node_struct(self, node_identifier): | |||||
| """ | |||||
| Get specific NodeStruct by node_identifier. | |||||
| Args: | |||||
| node_identifier (str): The node unique identifier. | |||||
| Return: | |||||
| NodeStruct, the node's NodeStruct. | |||||
| """ | |||||
| return self._node_struct_collections.get(node_identifier, None) | |||||
| def get_module_struct(self, module_identifier): | |||||
| """ | |||||
| Get specific ModuleStruct by module_identifier. | |||||
| Args: | |||||
| module_identifier (str): The module unique identifier. | |||||
| Return: | |||||
| ModuleStruct, the node's ModuleStruct. | |||||
| """ | |||||
| return self._module_struct_collections.get(module_identifier, None) | |||||
| def get_module_structs_by_pattern_under_same_parent_pattern(self, pattern_id, under_parent_pattern_id): | |||||
| """ | |||||
| Return a list of ModuleStruct by conditions of pattern and their parent parent's pattern. | |||||
| Args: | |||||
| pattern_id (int): The pattern id the returned ModuleSturct is. | |||||
| under_parent_pattern_id (int): The pattern id the returned ModuleStruct's parent is. | |||||
| Returns: | |||||
| list, a list of MoudleStructs has the same pattern_id and the same parents' pattern. | |||||
| """ | |||||
| if not pattern_id: | |||||
| raise ValueError("pattern_id is necessary to get the module struct.") | |||||
| if not under_parent_pattern_id: | |||||
| raise ValueError("under_parent_pattern_id is necessary to get the module struct.") | |||||
| ret = [] | |||||
| md_list = self._global_context.module_structs.get(pattern_id) | |||||
| for md in md_list: | |||||
| if md.parent_id == under_parent_pattern_id: | |||||
| ret.append(md) | |||||
| return ret | |||||
| def get_args_translator_from_module_structs_list(self, md_list, exclude_root_son=False): | |||||
| """ | |||||
| Return a list of args translators which belongs to given module structs. | |||||
| Args: | |||||
| md_list (list): A list of ModuleStruct. | |||||
| exclude_root_son (Bool): If the returned result should include args translator belongs to | |||||
| modules under the Main module. | |||||
| Returns: | |||||
| list, a list of args translators which belongs to given module structs. | |||||
| """ | |||||
| ret = [] | |||||
| for md in md_list: | |||||
| if exclude_root_son and md.parent_id == -1: | |||||
| continue | |||||
| if md.args_translator is not None: | |||||
| ret.append(md.args_translator) | |||||
| return ret | |||||
| def module_map_filter(self, depth=None, module_num=None, uid=None): | |||||
| """ | |||||
| Filter the module map by given conditions. | |||||
| Args: | |||||
| depth (int): Scope depth. | |||||
| module_num (int): The submodule number. | |||||
| uid (int): The unique identifier of a submodule. | |||||
| Return: | |||||
| list, list of NodeStruct list of each submodule. | |||||
| """ | |||||
| ret = list() | |||||
| for scope_path, nd_struct_list in self._module_map.items(): | |||||
| path = Scope.path_str_to_list(scope_path) | |||||
| if not path: # skip main | |||||
| continue | |||||
| # if depth not equals to the indicated depth, skip | |||||
| if depth is not None and len(path) != depth: | |||||
| continue | |||||
| scope_at_depth = path[-1] | |||||
| (m_num, m_uid) = scope_at_depth | |||||
| if uid is not None: | |||||
| if m_num == module_num and m_uid == uid: | |||||
| ret.append(nd_struct_list) | |||||
| else: | |||||
| if m_num == module_num: | |||||
| ret.append(nd_struct_list) | |||||
| return ret | |||||
| @@ -0,0 +1,710 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """Define a struct for module converted and save all required information here.""" | |||||
| from collections import OrderedDict | |||||
| from .node_struct import NodeStruct | |||||
| from .scope_utils import Scope | |||||
| from ..common.utils import get_dict_key_by_value | |||||
| from .args_translator import ArgsTranslation | |||||
| from ..common.code_fragment import ModuleFragment | |||||
| from ..common.global_context import GlobalContext | |||||
| from ..hierarchical_tree.name_mgr import LocalVarNameMgr | |||||
| class ModuleStruct: | |||||
| """ | |||||
| Define a module struct which stores all info. to generate statement. | |||||
| Args: | |||||
| args (list): A list of node structs. | |||||
| """ | |||||
| GLOBAL_CONTEXT_MGR = GlobalContext() | |||||
| def __init__(self, nd_struct_list): | |||||
| """Init. a module by NodeStructs.""" | |||||
| self.pattern_id = -1 # pattern num, -1 as Main module | |||||
| self.pattern_uid = -1 # unique module id for this pattern | |||||
| self.parent_id = None # parent's pattern num | |||||
| self.parent_uid = None # parent's pattern module unique id | |||||
| self.initialized = False | |||||
| self.identifier = None | |||||
| self.module_name = None | |||||
| self.scope_depth = None | |||||
| self.head_nd_struct = None | |||||
| self.head_nd_struct_index = None | |||||
| self.tail_nd_struct = None | |||||
| self.tail_nd_struct_index = None | |||||
| self._node_structs = list() | |||||
| self._module_structs = list() | |||||
| self._fragment = None | |||||
| self._args_translator = None | |||||
| self._setting = None | |||||
| self._parent_module_struct = None | |||||
| # only store original formal args name, not global | |||||
| self._nodes_structs_formal_args_list = list() | |||||
| # only store translated (globalized) formal args | |||||
| self._nodes_structs_formal_args_translated_list = list() | |||||
| # define other settings here | |||||
| self._node_args_translation_list = list() | |||||
| self._var_name_mgr = LocalVarNameMgr() | |||||
| self.construct_header_x = OrderedDict() # key is header x, value is precursors onnx name | |||||
| self.inputs_in_construct_header = OrderedDict() # key is precursors onnx name, value is x in parent construct | |||||
| self.inputs_in_parent_module = OrderedDict() # key is prec_node_name, value is its closet opt_var_name | |||||
| # key is node's onnx name(output provider), value is (provider_succ_name, opt_var_name) | |||||
| self.outputs_collection = dict() | |||||
| self.matched_inputs = list() # Matched inputs will can be directly used by code line generation | |||||
| # key is ext. succ node onnx name, value is local opt_var | |||||
| self.external_successor_local_returns_map = OrderedDict() | |||||
| # key is node's onnx_name, value is (successor_name, opt_var_name) <- node's level | |||||
| self.outputs_collection = dict() | |||||
| # start initialization | |||||
| if not self.initialized: | |||||
| self._init_module(nd_struct_list) | |||||
| else: | |||||
| self._update_module(nd_struct_list) | |||||
| # assign this module reference to node | |||||
| for (_, nd_struct) in nd_struct_list: | |||||
| nd_struct.parent_module_struct = self | |||||
| def reset_as_parent(self): | |||||
| """ | |||||
| Reset all attributes and filled as a parent module of this module. | |||||
| Note: | |||||
| This function must be called only after a deepcopy of this instance! | |||||
| """ | |||||
| self.identifier.pop() | |||||
| self.scope_depth = self.scope_depth - 1 | |||||
| self._set_pattern_id() | |||||
| self._find_parent_module() | |||||
| self.module_name = Scope.scope_to_module_name(self.identifier) | |||||
| self._node_structs = list() | |||||
| self._module_structs = list() | |||||
| self._fragment = None | |||||
| self._args_translator = None | |||||
| self.init_args_translator() | |||||
| self._setting = None | |||||
| self._parent_module_struct = None | |||||
| self._nodes_structs_formal_args_list = list() | |||||
| self._node_args_translation_list = list() | |||||
| def _set_pattern_id(self): | |||||
| """Set pattern id which matches the module fragment pattern.""" | |||||
| if not self.initialized: | |||||
| return | |||||
| if self.scope_depth < 1: | |||||
| self.pattern_id = -1 | |||||
| self.pattern_uid = -1 | |||||
| return | |||||
| self.pattern_id = self.identifier[-1][0] | |||||
| self.pattern_uid = self.identifier[-1][1] | |||||
| def _init_module(self, nd_struct_list): | |||||
| """Init this ModuleStruct by a list of Nodes.""" | |||||
| (nd_topo_idx, nd_struct) = nd_struct_list[0] | |||||
| self.identifier = nd_struct.scope.path | |||||
| self.module_name = nd_struct.scope.to_str | |||||
| self.scope_depth = nd_struct.scope.depth | |||||
| self.head_nd_struct = nd_struct | |||||
| self.head_nd_struct_index = nd_topo_idx | |||||
| self.tail_nd_struct = nd_struct_list[-1][1] | |||||
| self.tail_nd_struct_index = nd_struct_list[-1][0] | |||||
| self._node_structs = nd_struct_list | |||||
| self.initialized = True | |||||
| self._set_pattern_id() | |||||
| self._find_parent_module() | |||||
| self.init_args_translator() | |||||
| def _update_module(self, nd_struct_list): | |||||
| """Update the ModuleStruct attributes from a list of Nodes.""" | |||||
| (nd_topo_idx_head, nd_struct_head) = nd_struct_list[0] | |||||
| (nd_topo_idx_tail, nd_struct_tail) = nd_struct_list[-1] | |||||
| if self.identifier != nd_struct_head.scope.path: | |||||
| raise ValueError("Unable to update this module struct {} due to different identifier {}".format( | |||||
| self.identifier, nd_struct_head.scope.path)) | |||||
| if nd_topo_idx_head < self.head_nd_struct_index: | |||||
| self.head_nd_struct_index = nd_topo_idx_head | |||||
| self.head_nd_struct = nd_struct_head | |||||
| if nd_topo_idx_tail > self.tail_nd_struct_index: | |||||
| self.tail_nd_struct_index = nd_topo_idx_tail | |||||
| self.tail_nd_struct = nd_struct_tail | |||||
| self._node_structs += nd_struct_list | |||||
| def _find_parent_module(self): | |||||
| """Set the parent's module pattern and uid.""" | |||||
| if not self.initialized: | |||||
| return | |||||
| if self.scope_depth == 0: # is Main Module | |||||
| pass | |||||
| elif self.scope_depth == 1: # parent pattern is Main module | |||||
| self.parent_id = -1 | |||||
| self.parent_uid = -1 | |||||
| else: # this is a submodule in a module | |||||
| (self.parent_id, self.parent_uid) = Scope.get_parent_module_num_and_uid( | |||||
| self.identifier) | |||||
| def __repr__(self): | |||||
| return str({ | |||||
| "address": hex(id(self)), | |||||
| "identifier": self.identifier, | |||||
| "parent": (self.parent_id, self.parent_uid), | |||||
| "name": self.module_name, | |||||
| "pattern": self.pattern_id, | |||||
| "scope_depth": self.scope_depth, | |||||
| "nd_idx_range": "{} -> {}".format(self.head_nd_struct_index, self.tail_nd_struct_index), | |||||
| "initialized": self.initialized | |||||
| }) | |||||
| def init_module_fragment(self): | |||||
| """Init the module fragment.""" | |||||
| if not self.initialized: | |||||
| return | |||||
| # check if fragment exists in global context | |||||
| op = "Module{}".format(self.pattern_id) | |||||
| if op == "Module-1": # reset as Main Model's op name | |||||
| op = "Model" | |||||
| frag = GlobalContext().get_module_fragment(op) | |||||
| if frag is not None: # use exists fragment | |||||
| self._fragment = frag | |||||
| else: | |||||
| frag = ModuleFragment(operation=op, | |||||
| actual_args=None, | |||||
| input_shape=None, | |||||
| output_shape=None, | |||||
| settings=None) | |||||
| self._fragment = frag | |||||
| # set fragment pattern | |||||
| self._fragment.pattern = self._node_structs | |||||
| GlobalContext().add_module_fragment(op, frag) | |||||
| def init_args_translator(self): | |||||
| """Initialize the Args Translator for the module.""" | |||||
| var_name = "Module{}_{}".format(self.pattern_id, self.pattern_uid) | |||||
| self._args_translator = ArgsTranslation(None, var_name, None) | |||||
| def update_module_fragment(self): | |||||
| """Update this module's fragment.""" | |||||
| if self._fragment is None: | |||||
| return | |||||
| # update input output shape | |||||
| self._fragment.input_shape = self.head_nd_struct.fragment.input_shape | |||||
| self._fragment.output_shape = self.tail_nd_struct.fragment.output_shape | |||||
| # update formal args | |||||
| self._fragment.formal_args.update(self._args_translator.formal_args) | |||||
| self._fragment.formal_args_value.update(self._args_translator.formal_args_values) | |||||
| # update actual args | |||||
| self._fragment.actual_args.update(self._args_translator.actual_args) | |||||
| # update others.. | |||||
| def add_submodule(self, md_structs): | |||||
| """ | |||||
| Add another module struct(s) to this ModuleStruct. | |||||
| Args: | |||||
| md_structs ([ModuleStruct, list]): a (list) ModuleStruct to be added in this ModuleStruct. | |||||
| """ | |||||
| tail_md = md_structs | |||||
| if isinstance(md_structs, ModuleStruct): | |||||
| md_structs.args_translator.take_formal_args_from_nodes_and_submodules(md_structs.get_all_sub_translators()) | |||||
| self._module_structs.append(md_structs) | |||||
| md_structs.parent_module_struct = self | |||||
| elif isinstance(md_structs, list): | |||||
| for md_s in md_structs: | |||||
| md_s.args_translator.take_formal_args_from_nodes_and_submodules(md_s.get_all_sub_translators()) | |||||
| md_s.parent_module_struct = self | |||||
| self._module_structs += md_structs | |||||
| tail_md = md_structs[-1] | |||||
| else: | |||||
| raise TypeError("ModuleStruct cannot add an unsupport Type {} to module_structs list.".format( | |||||
| type(md_structs))) | |||||
| # update tail node and index | |||||
| if self.tail_nd_struct_index < tail_md.tail_nd_struct_index: | |||||
| self.tail_nd_struct = tail_md.tail_nd_struct | |||||
| self.tail_nd_struct_index = tail_md.tail_nd_struct_index | |||||
| def _update_formal_args_for_all_nd_structs(self): | |||||
| """ | |||||
| Init nodes' args translator and find formal args. | |||||
| And collect nodes' formal args. | |||||
| """ | |||||
| if len(self._node_args_translation_list) != len(self._node_structs): | |||||
| raise ValueError( | |||||
| "ModuleStruct cannot update nodes' formal args due to length inconsistent.") | |||||
| for idx, (_, nd_struct) in enumerate(self._node_structs): | |||||
| formal_arg_of_this_node = self._node_args_translation_list[idx] | |||||
| # update var_name to ensure all node names' are unique in a module. | |||||
| nd_struct.update_var_name(idx) | |||||
| nd_struct.init_args_translator(formal_arg_of_this_node) | |||||
| if nd_struct.args_translator is not None: | |||||
| self._nodes_structs_formal_args_list.append( | |||||
| nd_struct.args_translator.formal_args_values) | |||||
| else: | |||||
| self._nodes_structs_formal_args_list.append(None) | |||||
| def update_args_translation_list(self, formal_args): | |||||
| """ | |||||
| Receive a list of args name to be changed to formal args, and change them. | |||||
| Args: | |||||
| formal_args (list[str]): a list of args name to be changed to formal args. | |||||
| """ | |||||
| self._node_args_translation_list = formal_args | |||||
| self._update_formal_args_for_all_nd_structs() | |||||
| def get_all_sub_translators(self): | |||||
| """ | |||||
| Return a list of args_translators of submodules / nodes affiliated to this module. | |||||
| Note: | |||||
| The order of returned list is followed by the actual topological order. | |||||
| Returns: | |||||
| list, a list of args_translators. | |||||
| """ | |||||
| ret = [] | |||||
| for (_, struct) in self.get_generate_order(): | |||||
| if struct.args_translator is not None: | |||||
| ret.append(struct.args_translator) | |||||
| return ret | |||||
| def get_generate_order(self): | |||||
| """ | |||||
| Return the order of generated code by index. | |||||
| Return: | |||||
| list, a list of reference of node_struct or module_struct. | |||||
| """ | |||||
| ret = list() | |||||
| if not self._module_structs: | |||||
| return self._node_structs | |||||
| # Generate a list of tuple (idx, module_structs) | |||||
| for md_struct in self._module_structs: | |||||
| ret.append((md_struct.head_nd_struct_index, md_struct)) | |||||
| if self.node_structs: | |||||
| ret += self.node_structs | |||||
| ret.sort(key=lambda x: x[0]) | |||||
| return ret | |||||
| def code_line_in_init(self): | |||||
| """ | |||||
| Initialization line of code in module init block. | |||||
| Args: | |||||
| override_formal_val (dict): Indicate which args should be renamed for passing value from upper level. | |||||
| """ | |||||
| left = "self.{}".format(self.ms_var_name) | |||||
| args_list = list() | |||||
| # Load args in init statement. | |||||
| if self._args_translator is not None: # from args_translator | |||||
| if self._args_translator.actual_args: # load actual args | |||||
| args_list += self._args_translator.actual_args_to_str_list | |||||
| elif self._args_translator.actual_args_backup and self.parent_id == -1: | |||||
| # For modules repeated in multiple levels, the module under main model should | |||||
| # not use formal args as it is unnecessary -> load from actual args backup | |||||
| args_list += self._args_translator.actual_args_backup_to_str_list | |||||
| args_list += self._args_translator.formal_args_to_str_list # load from formal args | |||||
| else: | |||||
| args_list += self._fragment.actual_args | |||||
| right = f"{self.class_name}({', '.join(args_list)})" | |||||
| return (left, right) | |||||
| def code_line_in_construct(self, inputs=None): | |||||
| """Construct line of code in module construct block.""" | |||||
| # check number of outputs this module has | |||||
| opt_var_name_in_module = list(self.external_successor_local_returns_map.values()) | |||||
| num_output = len(set(opt_var_name_in_module)) | |||||
| if num_output == 1: # single output | |||||
| left = f"{self.ms_opt_var_name}" | |||||
| else: | |||||
| left = [f"{self.ms_opt_var_name}_{num}" for num in range(num_output)] | |||||
| if inputs is None and self.matched_inputs: | |||||
| inputs = self.matched_inputs | |||||
| if isinstance(inputs, str): | |||||
| inputs = [inputs] | |||||
| right = f"self.{self.ms_var_name}({', '.join(inputs)})" | |||||
| return (left, right) | |||||
| @property | |||||
| def node_structs(self): | |||||
| """Return all node structs in this module.""" | |||||
| return self._node_structs | |||||
| @property | |||||
| def module_structs(self): | |||||
| """Return all module structs in this module.""" | |||||
| return self._module_structs | |||||
| @property | |||||
| def parent_module_struct(self): | |||||
| """Return this module's parent module struct.""" | |||||
| return self._parent_module_struct | |||||
| @parent_module_struct.setter | |||||
| def parent_module_struct(self, ref): | |||||
| """Set this modu;e's parent module struct.""" | |||||
| self._parent_module_struct = ref | |||||
| @property | |||||
| def args_translator(self): | |||||
| """Return the args translator.""" | |||||
| return self._args_translator | |||||
| @property | |||||
| def head_nd_struct_precursor_nodes_names(self) -> list: | |||||
| """Return head node's precursor nodes names.""" | |||||
| return self.head_nd_struct.precursor_nodes_names | |||||
| @property | |||||
| def head_nd_struct_precursor_nodes_structs(self) -> list: | |||||
| """Return head node's precursor nodes structs.""" | |||||
| return self.head_nd_struct.precursor_nodes_structs | |||||
| @property | |||||
| def tail_nd_struct_successor_nodes_names(self) -> list: | |||||
| """Return tail node's successor nodes names.""" | |||||
| return self.tail_nd_struct.successor_nodes_names | |||||
| @property | |||||
| def tail_nd_struct_successor_nodes_structs(self) -> list: | |||||
| """Return tail node's successor nodes structs.""" | |||||
| return self.tail_nd_struct.successor_nodes_structs | |||||
| @property | |||||
| def onnx_names_from_nodes(self) -> list: | |||||
| """Return all nodes onnx names in this module.""" | |||||
| ret = [] | |||||
| for (_, node) in self.node_structs: | |||||
| ret.append(node.onnx_name) | |||||
| return ret | |||||
| @property | |||||
| def onnx_names_from_submodules(self) -> list: | |||||
| """Return all nodes onnx names in submodules of this module.""" | |||||
| ret = [] | |||||
| for md_struct in self.module_structs: | |||||
| ret += md_struct.onnx_names | |||||
| return ret | |||||
| @property | |||||
| def onnx_names(self) -> list: | |||||
| """Return all nodes' onnx names which contained by this module.""" | |||||
| return self.onnx_names_from_nodes + self.onnx_names_from_submodules | |||||
| @property | |||||
| def external_precursor_nodes_names(self) -> list: | |||||
| """Return all precursors nodes names not in this module.""" | |||||
| ret = [] | |||||
| for _, struct in self.get_generate_order(): | |||||
| if isinstance(struct, NodeStruct): | |||||
| precursor_nodes_names = struct.precursor_nodes_names | |||||
| if isinstance(struct, ModuleStruct): | |||||
| precursor_nodes_names = struct.external_precursor_nodes_names | |||||
| for p_name in precursor_nodes_names: | |||||
| if p_name in self.onnx_names: | |||||
| continue | |||||
| ret.append(p_name) | |||||
| return ret | |||||
| @property | |||||
| def external_successor_nodes_names(self) -> list: | |||||
| """Return all precursors nodes names not in this module.""" | |||||
| ret = [] | |||||
| for _, struct in self.get_generate_order(): | |||||
| if isinstance(struct, NodeStruct): | |||||
| successor_nodes_names = struct.successor_nodes_names | |||||
| if isinstance(struct, ModuleStruct): | |||||
| successor_nodes_names = struct.external_successor_nodes_names | |||||
| for s_name in successor_nodes_names: | |||||
| if s_name in self.onnx_names: | |||||
| continue | |||||
| ret.append(s_name) | |||||
| return ret | |||||
| @property | |||||
| def class_name(self) -> str: | |||||
| """Return the class name for generating code of this module.""" | |||||
| if self.pattern_id == -1: | |||||
| return "Model" | |||||
| return "Module{}".format(self.pattern_id) | |||||
| @property | |||||
| def ms_var_name(self) -> str: | |||||
| """Return the variable name for generated code statement of this module.""" | |||||
| if self.pattern_id == -1: | |||||
| return "Model" | |||||
| return "Module{}_{}".format(self.pattern_id, self.pattern_uid).lower() | |||||
| @property | |||||
| def ms_opt_var_name(self) -> str: | |||||
| """Return the variable name for generated code statement of the output of this module.""" | |||||
| return "{}_opt".format(self.ms_var_name).lower() | |||||
| # The following part will be resetting nodes' external inputs for supporting multi-in/out | |||||
| # and should be called after generator.recursive_form_modules() | |||||
| def set_inputs_in_construct_header(self, header_x, onnx_precursor_node_name): | |||||
| """ | |||||
| Mark the registered external inputs for code generation. | |||||
| Note: | |||||
| This function to be called by its parent (ModuleStruct). | |||||
| Args: | |||||
| header_x (str): The `x` in module construct header. | |||||
| onnx_precursor_node_name (str): The original onnx node name. | |||||
| """ | |||||
| if self.inputs_in_construct_header.get(onnx_precursor_node_name) is not None: | |||||
| raise ValueError("The input from {} has already registered. Check this Module \ | |||||
| {} has duplicate inputs or not.".format(onnx_precursor_node_name, self.identifier)) | |||||
| self.inputs_in_construct_header[onnx_precursor_node_name] = header_x | |||||
| def allocate_construct_header_x(self, force_x=None): | |||||
| """ | |||||
| Allocate the x in construct header for each external input. | |||||
| Args: | |||||
| force_x (str): Force the arg name to customized. | |||||
| """ | |||||
| local_x_name = 'x' | |||||
| if force_x: # name of x indicated by external | |||||
| local_x_name = force_x | |||||
| # set construct_header_x for current module | |||||
| allocated = set() | |||||
| for prec_name in self.external_precursor_nodes_names: | |||||
| if prec_name in allocated: | |||||
| continue | |||||
| x_name_in_construct_header = self._var_name_mgr.get_name(local_x_name) | |||||
| self.construct_header_x[x_name_in_construct_header] = prec_name | |||||
| allocated.add(prec_name) | |||||
| # Assign these inputs to nodes and submodules | |||||
| for _, struct in self.get_generate_order(): | |||||
| if isinstance(struct, NodeStruct): # register node's ext input | |||||
| self.reset_node_external_input_to_local(struct) | |||||
| self.register_node_output_to_module(struct) | |||||
| if isinstance(struct, ModuleStruct): # reg module's ext input | |||||
| if not struct.construct_header_x: | |||||
| struct.allocate_construct_header_x() | |||||
| self.reset_submodule_external_input_to_local(struct) | |||||
| self.register_submodule_output_to_module(struct) | |||||
| # remove parent module's ext. map if ext nodes in this module (no need return) | |||||
| for user_name in self.external_successor_local_returns_map.copy().keys(): | |||||
| if user_name in self.onnx_names: | |||||
| self.external_successor_local_returns_map.pop(user_name) | |||||
| def _match_node_inputs(self, struct): | |||||
| """Match node's inputs with its precursor nodes.""" | |||||
| for output_provider in struct.precursor_nodes_names: | |||||
| output_list = self.outputs_collection.get(output_provider) | |||||
| if output_list is None: | |||||
| # not in this module, check construct header | |||||
| for (self_x_name, self_output_provider) in self.construct_header_x.items(): | |||||
| if self_output_provider == output_provider: | |||||
| struct.matched_inputs.append(self_x_name) | |||||
| continue | |||||
| for output in output_list: | |||||
| (provider_succ, provider_closet_opt_var) = output | |||||
| if provider_closet_opt_var in struct.matched_inputs: | |||||
| continue # skip repeat | |||||
| if provider_succ == struct.onnx_name: | |||||
| struct.matched_inputs.append(provider_closet_opt_var) | |||||
| def _match_sub_modules_inputs(self): | |||||
| """ | |||||
| Match current module's submodules' inputs with corresponding outputs registered in current module. | |||||
| Description: | |||||
| The function matches these inputs by the following steps: | |||||
| 1. For each submodule in the current module, take submodule's construct header | |||||
| 2. Check submodule's construct header element requires an input from current module's | |||||
| construct header or outputs from other submodules. | |||||
| 3. If from current module's construct header, assign corresponding x to the submodule. | |||||
| If from other submodules, assign required submodule output name to the submodule. | |||||
| """ | |||||
| if not self.outputs_collection: | |||||
| return # skip first node | |||||
| for (_, struct) in self.get_generate_order(): | |||||
| if isinstance(struct, NodeStruct): | |||||
| self._match_node_inputs(struct) | |||||
| continue # skip node | |||||
| sub_construct_header = struct.construct_header_x | |||||
| for (_, output_provider) in sub_construct_header.items(): | |||||
| # check from outputs collection | |||||
| output_list = self.outputs_collection.get(output_provider) | |||||
| if output_list is None: | |||||
| # not in this module, need from current module construct header | |||||
| for (self_x_name, self_output_provider) in self.construct_header_x.items(): | |||||
| if self_output_provider == output_provider: | |||||
| struct.matched_inputs.append(self_x_name) | |||||
| continue | |||||
| for output in output_list: | |||||
| (provider_succ, provider_closet_opt_var) = output | |||||
| if provider_closet_opt_var in struct.matched_inputs: | |||||
| continue # skip repeat | |||||
| if provider_succ in struct.onnx_names: | |||||
| struct.matched_inputs.append(provider_closet_opt_var) | |||||
| def _append_to_outputs_collection(self, provider_name, val): | |||||
| """ | |||||
| Helper function to add a nodes or submodules outputs to current module return statement. | |||||
| Args: | |||||
| provider_name (str): The onnx name of the output provider. | |||||
| val (list[tuple]): A list of tuple which contains | |||||
| the output provider's successor name and its opt_var_name. | |||||
| """ | |||||
| exist_output = self.outputs_collection.get(provider_name) | |||||
| if isinstance(val, tuple): | |||||
| val = [val] | |||||
| if exist_output is None: # add new entry | |||||
| exist_output = list() | |||||
| exist_output += (val) | |||||
| self.outputs_collection[provider_name] = exist_output | |||||
| def collect_returns(self): | |||||
| """ | |||||
| Collect all nodes and submodules' returns in the module. | |||||
| Note: | |||||
| The logic is to collect the return from nodes and submodules by the order | |||||
| of topological index. | |||||
| For returns from a node, it will check if the return will be used externally. | |||||
| If external (external means the successor a.k.a the return user has different scope with the node), | |||||
| add this return to current module's outputs_collection, where | |||||
| key is this node's original onnx_name, and value is a list of | |||||
| tuple(successor_name, this node's opt_var_name) | |||||
| For returns from a submodule, it will check if the submodule has already collected returns, | |||||
| If not, do it and then continue the following procedures. | |||||
| Now we will check each element in submodule's outputs_collection. Note that we DO NOT check submodule's | |||||
| returns should be continued returning, but just return them. | |||||
| All these returns from submodules will be changes their original nodes output (a.k.a outputs provider) | |||||
| `opt_var_name` to submodules' `opt_var_name`. | |||||
| Finally, we match the outputs and inputs in the current module level. | |||||
| """ | |||||
| for (_, struct) in self.get_generate_order(): | |||||
| if isinstance(struct, NodeStruct): | |||||
| outputs_list = [] | |||||
| # add these successor nodes name to collection for future use | |||||
| for succ in struct.successor_nodes_names: | |||||
| outputs_list.append((succ, struct.ms_opt_var_name)) | |||||
| if outputs_list: | |||||
| self._append_to_outputs_collection(struct.onnx_name, outputs_list) | |||||
| if isinstance(struct, ModuleStruct): | |||||
| # Remove unnecessary returns, succ are all inside current | |||||
| if not struct.outputs_collection: | |||||
| struct.collect_returns() | |||||
| sub_outputs_collection = struct.outputs_collection | |||||
| # check each returns in sub | |||||
| for provider_name, outputs_list in sub_outputs_collection.items(): | |||||
| for output in outputs_list: | |||||
| (succ, _) = output # (succ, provider_opt_var_name) in output | |||||
| new_output = (succ, struct.ms_opt_var_name) | |||||
| self._append_to_outputs_collection(provider_name, new_output) | |||||
| self._match_sub_modules_inputs() | |||||
| def get_returned_opt_var_name(self) -> list: | |||||
| """Return a list of returned output var of this module.""" | |||||
| idx = 0 | |||||
| added_to_return = set() | |||||
| ret = [] | |||||
| for ext_successor_requested, opt_var_name_in_this_module in self.external_successor_local_returns_map.items(): | |||||
| if ext_successor_requested in added_to_return: | |||||
| continue | |||||
| ret.append((ext_successor_requested, opt_var_name_in_this_module, idx)) | |||||
| added_to_return.add(ext_successor_requested) | |||||
| return ret | |||||
| def reset_node_external_input_to_local(self, nd_struct): | |||||
| """ | |||||
| Reset node's input to module's construct args | |||||
| """ | |||||
| for prec_node_name in nd_struct.precursor_nodes_names_external: | |||||
| if prec_node_name in self.onnx_names: # prec node in current module's. | |||||
| continue | |||||
| if prec_node_name in self.construct_header_x.values(): | |||||
| # prec node assigned to construct header to passed in. | |||||
| local_x = get_dict_key_by_value(prec_node_name, self.construct_header_x) | |||||
| nd_struct.set_inputs_in_construct_header(local_x, prec_node_name) | |||||
| else: # Extra precursor nodes, raise error | |||||
| raise ValueError("Found external inputs of the Node but the module does not have it.") | |||||
| def reset_submodule_external_input_to_local(self, md_struct): | |||||
| """ | |||||
| Reset submodule's external input to current module. | |||||
| Args: | |||||
| md_struct (ModuleStruct): The submodule in the current module. | |||||
| """ | |||||
| # check submodule's input | |||||
| for _, submodule_precursor in md_struct.construct_header_x.items(): | |||||
| if submodule_precursor in self.onnx_names: # if internal, match with local nodes/submodules return | |||||
| # but do nothing here | |||||
| continue | |||||
| else: # if external, match with current module construct header x | |||||
| if submodule_precursor in self.construct_header_x.values(): | |||||
| local_x = get_dict_key_by_value(submodule_precursor, self.construct_header_x) | |||||
| md_struct.set_inputs_in_construct_header(local_x, submodule_precursor) | |||||
| else: # Extra precursor nodes, raise error | |||||
| raise ValueError("Found external inputs of the submodule but the module does not have it.") | |||||
| def register_node_output_to_module(self, nd_struct): | |||||
| """Register nodes outputs to this module's return.""" | |||||
| for succ_node_name in nd_struct.successor_nodes_names_external: | |||||
| self.external_successor_local_returns_map[succ_node_name] = nd_struct.ms_opt_var_name | |||||
| def register_submodule_output_to_module(self, md_struct): | |||||
| """Register submodule outputs to this module's return.""" | |||||
| submodule_returns = md_struct.get_returned_opt_var_name() | |||||
| submodule_opt_var_name = md_struct.ms_opt_var_name | |||||
| for (submodule_ext_succ, opt_var_name_in_this_module, ith_output) in submodule_returns: | |||||
| self.external_successor_local_returns_map[submodule_ext_succ] = (submodule_opt_var_name, ith_output) | |||||
| # edit external succ 's inputs in parent module | |||||
| ext_node = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(submodule_ext_succ) | |||||
| ext_node_parent = ext_node.parent_module_struct | |||||
| while ext_node_parent != self.parent_module_struct: | |||||
| ext_node_parent.inputs_in_parent_module[ext_node.onnx_name] = md_struct.ms_opt_var_name | |||||
| ext_node_parent = ext_node_parent.parent_module_struct | |||||
| # need find the prec_name? | |||||
| for ext_node_prec, opt_var_name in ext_node.inputs_in_parent_module.copy().items(): | |||||
| if isinstance(opt_var_name, str): | |||||
| if opt_var_name == opt_var_name_in_this_module: | |||||
| ext_node.inputs_in_parent_module[ext_node_prec] = (self.ms_opt_var_name, ith_output) | |||||
| if isinstance(opt_var_name, tuple): | |||||
| if opt_var_name[0] == opt_var_name_in_this_module: | |||||
| ext_node.inputs_in_parent_module[ext_node_prec] = (self.ms_opt_var_name, ith_output) | |||||
| @@ -0,0 +1,423 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """Define the NodeStruct which stores all info. of a node.""" | |||||
| from collections import OrderedDict | |||||
| from .scope_utils import Scope | |||||
| from .args_translator import ArgsTranslation | |||||
| from ..common.code_fragment import CodeFragment | |||||
| from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode | |||||
| from ..third_party_graph.onnx_graph_node import OnnxGraphNode | |||||
| from ..common.global_context import GlobalContext | |||||
| from ..constant import InputType | |||||
| class NodeStruct: | |||||
| """ | |||||
| Define a node struct which stores all info. to generate statement. | |||||
| Args: | |||||
| args (Union[PyTorchGraphNode, OnnxGraphNode, dict]): Node related obj. | |||||
| Note: | |||||
| You can pass as many args as possible and the Node Struct will update | |||||
| by arguments order. | |||||
| """ | |||||
| GLOBAL_CONTEXT_MGR = GlobalContext() | |||||
| def __init__(self, args): | |||||
| # define attributes here | |||||
| self._identifier = None | |||||
| self._fragment = None | |||||
| self._args_translator = None | |||||
| self._parent_module_struct = None | |||||
| self.topo_idx = None | |||||
| self.node_type = None | |||||
| self.onnx_name = None | |||||
| self.onnx_op = None | |||||
| self.graph_node_ref = None # Our defined GraphNode | |||||
| self.scope_name = None | |||||
| self.ms_var_name = None | |||||
| self.ms_opt_var_name = None # ms_opt_var_name = self.ms_var_name(...) | |||||
| self.ms_op = None | |||||
| self.ready_to_generate = False | |||||
| self.ms_params = dict() # converted params from mapper | |||||
| self.ms_settings = dict() | |||||
| self.ms_weights = dict() | |||||
| self.ms_inputs = OrderedDict() | |||||
| self.scope = None # Defined Scope class | |||||
| self.inputs_in_construct_header = OrderedDict() # key is prec_node_name, value is x; For code line use | |||||
| self.inputs_in_parent_module = OrderedDict() # key is prec_node_name, value is its closet opt_var_name | |||||
| self.matched_inputs = list() # Matched inputs will can be directly used by code line generation | |||||
| # initialize funcs. | |||||
| for arg in args: | |||||
| self.update(arg) | |||||
| def __repr__(self): | |||||
| return str({ | |||||
| "address": hex(id(self)), | |||||
| "idx": self.topo_idx, | |||||
| "identifier": self.identifier | |||||
| }) | |||||
| def ori_topo_idx(self): | |||||
| """Get the original topological index in the onnx graph.""" | |||||
| ori_name = self.identifier.replace('$', '').split('/')[-1].replace("::", '/') | |||||
| self.onnx_name = ori_name | |||||
| return self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_topo_idx.get(ori_name) | |||||
| def update_var_name(self, idx=None): | |||||
| """ | |||||
| Update the var_name of each node. | |||||
| Args: | |||||
| idx (int): The index of the node in this module. | |||||
| """ | |||||
| if idx is not None: | |||||
| self.ms_var_name = self.ms_op.replace('nn.', '').replace('P.', '').lower() + '_' + str(idx) | |||||
| elif self.topo_idx is not None: | |||||
| self.ms_var_name = self.ms_op.replace('nn.', '').replace('P.', '').lower() + '_' + str(self.topo_idx) | |||||
| else: | |||||
| raise ValueError("Unable to update var name when topo_idx is None.") | |||||
| self.ms_opt_var_name = self.ms_var_name + '_opt' | |||||
| def _update_basics_from_gn(self, gn): | |||||
| """Update basic info from GraphNode.""" | |||||
| self.graph_node_ref = gn | |||||
| self.scope_name = gn.scope_name | |||||
| def _update_from_pytorch_gn(self, gn: PyTorchGraphNode): | |||||
| """Update basic info from PyTorchGraphNode.""" | |||||
| self.node_type = "PyTorchGraphNode" | |||||
| self._update_basics_from_gn(gn) | |||||
| def _update_from_onnx_gn(self, gn: OnnxGraphNode): | |||||
| """Update basic info from OnnxGraphNode.""" | |||||
| self.node_type = "OnnxGraphNode" | |||||
| self._update_basics_from_gn(gn) | |||||
| def _update_from_mapper(self, d): | |||||
| """Update info from mapper.""" | |||||
| if d.get('op_name'): | |||||
| self.ms_op = d.get('op_name') | |||||
| if d.get('params'): | |||||
| self.ms_params = d.get('params') | |||||
| if d.get('settings'): | |||||
| self.ms_settings = d.get('settings') | |||||
| if d.get('weights'): | |||||
| self.ms_weights = d.get('weights') | |||||
| def _update_from_fragment(self, frag: CodeFragment): | |||||
| """Update info from CodeFragment.""" | |||||
| self._fragment = frag | |||||
| if frag.operation: | |||||
| self.ms_op = frag.operation | |||||
| idx = self.GLOBAL_CONTEXT_MGR.latest_node_struct_count | |||||
| self.update_var_name(idx=idx) | |||||
| def _set_scope_from_identifier(self): | |||||
| """Set the Node scope from identifier.""" | |||||
| parsed_scope = Scope.parse_scope_from_node_identifier(self.identifier) | |||||
| self.scope = Scope(parsed_scope) | |||||
| def init_args_translator(self, translated_args: list): | |||||
| """ | |||||
| Initialize the ArgsTranslator for each Node. | |||||
| Args: | |||||
| translated_args (list): The list of args should be translated to formal args. | |||||
| """ | |||||
| if not self._fragment: | |||||
| raise ValueError("Initialize argument translator failed.") | |||||
| if self._fragment.actual_args and translated_args: | |||||
| self._args_translator = ArgsTranslation(self._fragment.actual_args, self.ms_var_name, translated_args) | |||||
| def check_if_generate_ready(self): | |||||
| """Check if the NodeStruct is able to generate code.""" | |||||
| # check essential params exists | |||||
| if all([self.identifier, | |||||
| self.node_type, | |||||
| self.scope_name, | |||||
| self.ms_var_name, | |||||
| self.ms_opt_var_name, | |||||
| self.ms_op]): | |||||
| self.ready_to_generate = True | |||||
| def update(self, arg, force_ready=False): | |||||
| """ | |||||
| Pass Node info. to generator NodeStruct. | |||||
| Args: | |||||
| arg (Union[PyTorchGraphNode, OnnxGraphNode, dict]): Node related obj. | |||||
| force_ready (bool): Force this NodeStruct is ready to generate. | |||||
| """ | |||||
| if isinstance(arg, PyTorchGraphNode): | |||||
| self._update_from_pytorch_gn(arg) | |||||
| elif isinstance(arg, OnnxGraphNode): | |||||
| self._update_from_onnx_gn(arg) | |||||
| elif isinstance(arg, (dict, OrderedDict)): | |||||
| self._update_from_mapper(arg) | |||||
| elif isinstance(arg, CodeFragment): | |||||
| self._update_from_fragment(arg) | |||||
| else: | |||||
| raise TypeError("NodeStruct received an unsupported initializing argument.") | |||||
| if force_ready: | |||||
| self.ready_to_generate = True | |||||
| else: | |||||
| self.check_if_generate_ready() | |||||
| @property | |||||
| def identifier(self): | |||||
| """Return the identifier of the node.""" | |||||
| return self._identifier | |||||
| @identifier.setter | |||||
| def identifier(self, s): | |||||
| """ | |||||
| Set the Node identifier, and update the scope. | |||||
| Args: | |||||
| s (str): The node identifier string. | |||||
| """ | |||||
| self._identifier = s | |||||
| self._set_scope_from_identifier() | |||||
| self.topo_idx = self.ori_topo_idx() | |||||
| self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map[self.onnx_name] = self | |||||
| @property | |||||
| def fragment(self): | |||||
| """Return the fragment of the node.""" | |||||
| return self._fragment | |||||
| @fragment.setter | |||||
| def fragment(self, frag): | |||||
| """ | |||||
| Set the Node fragment. | |||||
| Args: | |||||
| s (NodeFragment): The node identifier string. | |||||
| """ | |||||
| self._fragment = frag | |||||
| @property | |||||
| def graph_node(self): | |||||
| """Return the GraphNode reference.""" | |||||
| return self.graph_node_ref | |||||
| @graph_node.setter | |||||
| def graph_node(self, graphnode): | |||||
| """Set the GraphNode reference.""" | |||||
| self.graph_node_ref = graphnode | |||||
| @property | |||||
| def onnx_node(self): | |||||
| """Return the original onnx node reference.""" | |||||
| return self.GLOBAL_CONTEXT_MGR.onnx_nodes_collection.get(self.onnx_name) | |||||
| @property | |||||
| def args_translator(self): | |||||
| """Return the args translator of this Node.""" | |||||
| return self._args_translator | |||||
| @property | |||||
| def precursor_nodes_names(self) -> list: | |||||
| """Return the names of precursor nodes.""" | |||||
| return self.graph_node_ref.precursor_nodes | |||||
| @property | |||||
| def precursor_nodes_structs(self) -> list: | |||||
| """Return the node struct instances of precursor nodes.""" | |||||
| ret = [] | |||||
| precursor_nodes_names = self.precursor_nodes_names | |||||
| for pre_node_name in precursor_nodes_names: | |||||
| nd_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(pre_node_name) | |||||
| ret.append(nd_struct) | |||||
| return ret | |||||
| @property | |||||
| def successor_nodes_names(self) -> list: | |||||
| """Return the names of successor nodes.""" | |||||
| return self.graph_node_ref.successor_nodes | |||||
| @property | |||||
| def successor_nodes_structs(self) -> list: | |||||
| """Return the node struct instances of successor nodes.""" | |||||
| ret = [] | |||||
| for pre_node_name in self.successor_nodes_names: | |||||
| nd_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(pre_node_name) | |||||
| ret.append(nd_struct) | |||||
| return ret | |||||
| @property | |||||
| def parent_module_struct(self): | |||||
| """Return the parent struct of this node.""" | |||||
| return self._parent_module_struct | |||||
| @parent_module_struct.setter | |||||
| def parent_module_struct(self, ref): | |||||
| self._parent_module_struct = ref | |||||
| # Code Generation funcs below | |||||
| def code_line_in_init(self): | |||||
| """Initialization line of code in module init block.""" | |||||
| left = "self.{}".format(self.ms_var_name) | |||||
| args_list = list() | |||||
| if self._args_translator is not None: | |||||
| args_list += self._args_translator.actual_args_to_str_list | |||||
| args_list += self._args_translator.formal_args_to_str_list | |||||
| else: | |||||
| actual_args_str = ArgsTranslation.dict_data_to_args_str_list(self._fragment.actual_args) | |||||
| args_list += actual_args_str | |||||
| right = f"{self.ms_op}({', '.join(args_list)})" | |||||
| return left, right | |||||
| def _get_correct_in_module_returns(self, prec_node, in_module_return): | |||||
| """ | |||||
| Find the correct precursor node name in return statement of its parent module. | |||||
| Args: | |||||
| prec_node (str): The onnx name of the precursor node given. | |||||
| in_module_return (list[tuple]): The list of outputs which contains parent module identifier | |||||
| and module opt_var_name. | |||||
| Return: | |||||
| str, correct opt_var_name to be passed in current node. | |||||
| """ | |||||
| found_return = False | |||||
| for ret in in_module_return: | |||||
| (md_identifier, input_name_to_use) = ret | |||||
| p_node_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(prec_node) | |||||
| # recursive check the p node parent | |||||
| parent = p_node_struct | |||||
| while not found_return: | |||||
| parent = parent.parent_module_struct | |||||
| if parent is None: | |||||
| break | |||||
| if parent.identifier == md_identifier: | |||||
| return input_name_to_use | |||||
| return None | |||||
| def code_line_in_construct(self, inputs=None, in_module_returns=None): | |||||
| """Construct line of code in module construct block. """ | |||||
| left = self.ms_opt_var_name | |||||
| if inputs is None: | |||||
| inputs = [] | |||||
| for idx, prec_node in enumerate(self.precursor_nodes_names): | |||||
| if self.inputs_in_construct_header.get(prec_node): | |||||
| inputs.append(self.inputs_in_construct_header.get(prec_node)) | |||||
| elif self._check_target_node_internal(prec_node): | |||||
| inputs.append(self.precursor_nodes_structs[idx].ms_opt_var_name) | |||||
| elif self.inputs_in_parent_module.get(prec_node): | |||||
| inputs.append(self.inputs_in_parent_module.get(prec_node)) | |||||
| elif in_module_returns and in_module_returns.get(self.onnx_name) \ | |||||
| and (not self._check_target_node_internal(prec_node)): | |||||
| inputs.append(self._get_correct_in_module_returns(prec_node, in_module_returns.get(self.onnx_name))) | |||||
| else: | |||||
| inputs.append("unk_{}_{}".format(idx, prec_node)) | |||||
| if self.matched_inputs: | |||||
| inputs = self.matched_inputs | |||||
| # Check original onnx node's input to ensure double inputs are not ignored | |||||
| original_inputs = self.GLOBAL_CONTEXT_MGR.onnx_node_inputs.get(self.onnx_name) | |||||
| new_inputs = [] | |||||
| for idx, prec_node in enumerate(self.precursor_nodes_names): | |||||
| occurence = original_inputs.count(prec_node) | |||||
| for _ in range(occurence): | |||||
| new_inputs.append(inputs[idx]) | |||||
| inputs = new_inputs | |||||
| if isinstance(inputs, str): | |||||
| inputs = [inputs] | |||||
| if self._fragment.code_setting and self._fragment.code_setting.op_ipt_type == InputType.LIST.value: | |||||
| inputs = [str(tuple(inputs)).replace("\'", "")] | |||||
| if self._fragment.code_setting and self._fragment.code_setting.op_extra_input: | |||||
| for _, val in self._fragment.code_setting.op_extra_input.items(): | |||||
| inputs.append(str(val)) | |||||
| if self._fragment.code_setting and self._fragment.code_setting.op_extra_tensor: | |||||
| inputs.append(f"self.{self.ms_var_name}_w") | |||||
| right = f"self.{self.ms_var_name}({', '.join(inputs)})" | |||||
| return left, right | |||||
| def add_extra_tensor(self): | |||||
| """ Add extra tensor.""" | |||||
| left = "self.{}_w".format(self.ms_var_name) | |||||
| shape = self._fragment.code_setting.op_extra_tensor.shape | |||||
| right = f"Tensor(np.random.uniform(0, 1, {shape}), mindspore.float32)" | |||||
| return left, right | |||||
| # The following functions are specified for multiple in/out support. | |||||
| # and should be called only after generator._recursive_form_modules() | |||||
| def set_inputs_in_construct_header(self, header_x, onnx_precursor_node_name): | |||||
| """ | |||||
| Mark the registered external inputs for code generation. | |||||
| Note: | |||||
| This function to be called by its parent (ModuleStruct). | |||||
| Args: | |||||
| header_x (str): The `x` in module construct header. | |||||
| onnx_precursor_node_name (str): The original onnx node name. | |||||
| """ | |||||
| if self.inputs_in_construct_header.get(onnx_precursor_node_name) is not None: | |||||
| raise ValueError("The input from {} has already registered. Check this node \ | |||||
| {} has duplicate inputs or not.".format(onnx_precursor_node_name, self.identifier)) | |||||
| self.inputs_in_construct_header[onnx_precursor_node_name] = header_x | |||||
| def _check_target_node_internal(self, name: str) -> bool: | |||||
| """ | |||||
| Check given node under the same scope. | |||||
| Args: | |||||
| name (str): Can accept both node identifier or original onnx node name. | |||||
| """ | |||||
| target_nd_struct = self.GLOBAL_CONTEXT_MGR.node_struct_collections.get(name) \ | |||||
| or self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(name) | |||||
| if target_nd_struct is None and self.topo_idx == 0: # First node always has external input | |||||
| return False | |||||
| if target_nd_struct is None: | |||||
| raise ValueError("Unable to find the NodeStruct of given target node {}.".format(name)) | |||||
| return target_nd_struct.scope.path == self.scope.path | |||||
| @property | |||||
| def has_successor_node_external(self) -> bool: | |||||
| """Check if any successor_node is in external module.""" | |||||
| for name in self.successor_nodes_names: | |||||
| if not self._check_target_node_internal(name): | |||||
| return False | |||||
| return True | |||||
| @property | |||||
| def precursor_nodes_names_external(self) -> list: | |||||
| """Return a list of external precursor nodes names.""" | |||||
| return [name for name in self.precursor_nodes_names \ | |||||
| if not self._check_target_node_internal(name)] | |||||
| @property | |||||
| def successor_nodes_names_external(self) -> list: | |||||
| """Return a list of external successor nodes names.""" | |||||
| return [name for name in self.successor_nodes_names \ | |||||
| if not self._check_target_node_internal(name)] | |||||
| @@ -0,0 +1,157 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """Define a scope class processing all operations related to scope and scope name.""" | |||||
| import re | |||||
| class Scope(): | |||||
| """Define scope related operations.""" | |||||
| def __init__(self, scope_str): | |||||
| scopes = scope_str.split('/') | |||||
| self.module_path = list() | |||||
| self.scope_list = scopes[:-1] | |||||
| self.head = self.scope_list[0] | |||||
| self.tail = self.scope_list[-1] | |||||
| self.initialization() | |||||
| def initialization(self): | |||||
| """Init scope class.""" | |||||
| self._update_module_path_from_scope_list() | |||||
| def _update_module_path_from_scope_list(self): | |||||
| """Update the module scope path from a list of scope.""" | |||||
| self.module_path = list() | |||||
| for scope in self.scope_list: | |||||
| if scope == 'Model': | |||||
| continue | |||||
| if 'Module' in scope: | |||||
| regex = r"Module(?P<num>\d+)_(?P<curr_level_unique_id>\d+)" | |||||
| match = re.match(regex, scope) | |||||
| if match: | |||||
| module_num = match.group('num') | |||||
| uid = match.group('curr_level_unique_id') | |||||
| self.module_path.append((int(module_num), int(uid))) | |||||
| @property | |||||
| def path(self): | |||||
| """Return module scope path.""" | |||||
| return self.module_path | |||||
| def set_path(self, ind, path_tuple: tuple): | |||||
| """ | |||||
| Set the module scope path. | |||||
| Args: | |||||
| ind (int): The index of the scope path to be set. | |||||
| path_tuple ((int, int)): The tuple of the scope path. | |||||
| """ | |||||
| self.module_path[ind] = path_tuple | |||||
| @property | |||||
| def to_str(self): | |||||
| """Return the full module scope as the string format.""" | |||||
| full_str_list = ["Model"] | |||||
| for (num, uid) in self.module_path: | |||||
| local = "Module{}_{}".format(num, uid) | |||||
| full_str_list.append(local) | |||||
| return "/".join(full_str_list) | |||||
| @property | |||||
| def depth(self): | |||||
| """Return the depth of the scope path.""" | |||||
| return len(self.path) | |||||
| @staticmethod | |||||
| def scope_to_module_name(path): | |||||
| """ | |||||
| Helper function to convert any scope path string to the full module scope. | |||||
| Args: | |||||
| path (str): path string like "[(5, 0), (3, 0)]" | |||||
| Returns: | |||||
| str, the full module scope with format like "Model/Module5_0/Module3_0/" | |||||
| """ | |||||
| scope_str_list = ["Model"] | |||||
| if isinstance(path, str): | |||||
| path = Scope.path_str_to_list(path) | |||||
| if isinstance(path, list): | |||||
| for (num, uid) in path: | |||||
| local = "Module{}_{}".format(num, uid) | |||||
| scope_str_list.append(local) | |||||
| return "/".join(scope_str_list) | |||||
| @staticmethod | |||||
| def parse_scope_from_node_identifier(node_identifier: str): | |||||
| """ | |||||
| Helper function to parse the scope string from node identifier. | |||||
| Args: | |||||
| node_identifier (str): The string of the node identifier. | |||||
| Returns: | |||||
| str, parsed scope string from node identifier. | |||||
| """ | |||||
| regex = r"(?P<scope>Model/.*)\$\S+\$" | |||||
| match = re.match(regex, node_identifier) | |||||
| if not match: | |||||
| return None | |||||
| return match.group('scope') | |||||
| @staticmethod | |||||
| def path_str_to_list(scope_path_str: str): | |||||
| """ | |||||
| Helper function to convert the scope path string back to list. | |||||
| Args: | |||||
| scope_path_str (str): The scope path string like "[(5, 0), (3, 0)]". | |||||
| Returns: | |||||
| list, a list of the scope path like [(5, 0), (3, 0)]. | |||||
| """ | |||||
| ret = [] | |||||
| tmp = scope_path_str.strip('[').strip(']') | |||||
| regex = r"\((?P<num>\d+), (?P<uid>\d+)\)" | |||||
| s_all = re.findall(regex, tmp) | |||||
| for (num, uid) in s_all: | |||||
| ret.append((int(num), int(uid))) | |||||
| return ret | |||||
| @staticmethod | |||||
| def get_parent_module_num_and_uid(path): | |||||
| """ | |||||
| Helper function to return its parent's scope tuple. | |||||
| Args: | |||||
| path (Union[str, list]): Module scope path string. e.g. "[(5, 0), (3, 0)]" | |||||
| Returns: | |||||
| tuple, parent's scope level. e.g. [(5, 0)] | |||||
| """ | |||||
| if isinstance(path, str): | |||||
| path = Scope.path_str_to_list(path) | |||||
| if isinstance(path, list): | |||||
| if len(path) == 1: # modules under the main module, (-1, -1) means main module. | |||||
| return (-1, -1) | |||||
| if len(path) > 1: # modules under another non-main module. Return parent's scope. | |||||
| parent = path[-2] | |||||
| return parent | |||||
| return None | |||||
| @@ -106,3 +106,51 @@ class GlobalVarNameMgr: | |||||
| global_var_namespace.add(new_name) | global_var_namespace.add(new_name) | ||||
| return new_name | return new_name | ||||
| class LocalVarNameMgr: | |||||
| """Local variable name mgr.""" | |||||
| def __init__(self): | |||||
| self.local_op_namespace = dict() | |||||
| self.local_var_namespace = set() | |||||
| @staticmethod | |||||
| def _get_name(name): | |||||
| """Deal with op name.""" | |||||
| if "::" in name: | |||||
| return name.split("::")[1] | |||||
| return name | |||||
| def get_name(self, op_type): | |||||
| """ | |||||
| Get module/variable name. | |||||
| If the module already existed, then add a suffix to it. | |||||
| conv1 onnx::conv | |||||
| Args: | |||||
| op_type (str): Operator type in onnx. | |||||
| Returns: | |||||
| str, module name. | |||||
| """ | |||||
| def _gen(t): | |||||
| t = t.lower() | |||||
| if t not in self.local_op_namespace: | |||||
| self.local_op_namespace[t] = START_IDX | |||||
| suffix = "" | |||||
| else: | |||||
| self.local_op_namespace[t] += 1 | |||||
| suffix = f"{self.local_op_namespace[t] - 1}" | |||||
| return f"{self._get_name(t)}{suffix}" | |||||
| new_name = _gen(op_type) | |||||
| while new_name in self.local_var_namespace: | |||||
| new_name = _gen(op_type) | |||||
| self.local_var_namespace.add(new_name) | |||||
| return new_name | |||||
| @@ -151,7 +151,7 @@ class OnnxGraph(Graph): | |||||
| input_shape (tuple): Input shape. | input_shape (tuple): Input shape. | ||||
| """ | """ | ||||
| input_node = InputNode(input_shape) | input_node = InputNode(input_shape) | ||||
| input_node_name = "{}InputNode" | |||||
| input_node_name = self._raw_input_nodes.replace(":0", "") | |||||
| for node_name, node in self._nodes_collection.items(): | for node_name, node in self._nodes_collection.items(): | ||||
| if node_name in self._input_nodes: | if node_name in self._input_nodes: | ||||
| ipt_nd_name = input_node_name.format(input_node.scope_name) | ipt_nd_name = input_node_name.format(input_node.scope_name) | ||||
| @@ -23,6 +23,7 @@ import numpy as np | |||||
| from mindinsight.mindconverter.common.log import logger as log | from mindinsight.mindconverter.common.log import logger as log | ||||
| from ..common.utils import fetch_output_from_onnx_model | from ..common.utils import fetch_output_from_onnx_model | ||||
| from ..common.global_context import GlobalContext | |||||
| from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ | from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ | ||||
| ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL | ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL | ||||
| @@ -110,6 +111,7 @@ class OnnxTensor: | |||||
| self.to_nodes = [] | self.to_nodes = [] | ||||
| def to_array(self): | def to_array(self): | ||||
| """Convert the tensor value from binary to np array.""" | |||||
| onnx = import_module("onnx") | onnx = import_module("onnx") | ||||
| # Convert binary data to np.array | # Convert binary data to np.array | ||||
| if not isinstance(self.raw_tensor, (np.ndarray, list, tuple, int, float)): | if not isinstance(self.raw_tensor, (np.ndarray, list, tuple, int, float)): | ||||
| @@ -264,7 +266,7 @@ class OnnxDataLoader: | |||||
| self.output_nodes = output_nodes if isinstance(output_nodes, list) else [output_nodes] | self.output_nodes = output_nodes if isinstance(output_nodes, list) else [output_nodes] | ||||
| # args for init | # args for init | ||||
| self._is_infer_shape = infer_shape | self._is_infer_shape = infer_shape | ||||
| self._global_context = GlobalContext() | |||||
| # params parsed in init | # params parsed in init | ||||
| self.inferred_model = None | self.inferred_model = None | ||||
| @@ -375,12 +377,19 @@ class OnnxDataLoader: | |||||
| def _parse_nodes(self): | def _parse_nodes(self): | ||||
| """Parse each onnx nodes in the model.""" | """Parse each onnx nodes in the model.""" | ||||
| for node in self.nodes: | |||||
| nodes_topo_idx = [] | |||||
| for idx, node in enumerate(self.nodes): | |||||
| n = OnnxNode(node) | n = OnnxNode(node) | ||||
| self._nodes_dict[n.name] = n | self._nodes_dict[n.name] = n | ||||
| nodes_topo_idx.append((idx, n.name)) | |||||
| if len(node.output) > 1: | if len(node.output) > 1: | ||||
| raise ModelNotSupport(msg=f"{node.name} has multi-outputs which is not supported now.") | raise ModelNotSupport(msg=f"{node.name} has multi-outputs which is not supported now.") | ||||
| self.output_name_to_node_name[node.output[0]] = node.name | self.output_name_to_node_name[node.output[0]] = node.name | ||||
| self._global_context.onnx_node_name_to_topo_idx[n.name] = idx | |||||
| node_inputs = [i.replace(":0", "") for i in node.input] | |||||
| self._global_context.onnx_node_inputs[n.name] = node_inputs | |||||
| self._global_context.onnx_nodes_collection = self._nodes_dict | |||||
| self._global_context.onnx_nodes_topo_index = nodes_topo_idx | |||||
| def _parse_tensors(self): | def _parse_tensors(self): | ||||
| """Parse each onnx tensors in the model.""" | """Parse each onnx tensors in the model.""" | ||||
| @@ -388,6 +397,7 @@ class OnnxDataLoader: | |||||
| for tensor in tensors: | for tensor in tensors: | ||||
| t = OnnxTensor(tensor) | t = OnnxTensor(tensor) | ||||
| self.tensors_dict[t.name] = t | self.tensors_dict[t.name] = t | ||||
| self._global_context.onnx_tensors_collection = self.tensors_dict | |||||
| def _parse_node_output_shape(self): | def _parse_node_output_shape(self): | ||||
| """ | """ | ||||