From: @liangtianshu Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -40,8 +40,10 @@ class GlobalContext(metaclass=Singleton): | |||
| # Define data stored from onnx_utils | |||
| # Key as Onnx Name | |||
| self._onnx_nodes_collection = OrderedDict() | |||
| # key is topo_idx, value is onnx_node_name. | |||
| # key is topo_idx, value is onnx_node_name | |||
| self._onnx_nodes_topo_index = dict() | |||
| self.onnx_node_name_to_topo_idx = dict() | |||
| self.onnx_node_inputs = dict() | |||
| self._onnx_tensors_collection = dict() | |||
| # Define data stored from generator | |||
| @@ -50,7 +52,7 @@ class GlobalContext(metaclass=Singleton): | |||
| self.node_struct_adder_counter = 0 | |||
| # Define onnx_utils <---> generator mapping | |||
| self.node_struct_to_onnx_node_map = dict() | |||
| self.onnx_node_to_node_struct_map = dict() | |||
| self.onnx_node_name_to_node_struct_map = dict() | |||
| # Define Module pattern to customize name mapping | |||
| self.module_customized_name = dict() | |||
| @@ -59,6 +61,8 @@ class GlobalContext(metaclass=Singleton): | |||
| self.node_fragments = OrderedDict() | |||
| self.module_fragments = OrderedDict() | |||
| # Define Known module mapping | |||
| self.known_module_name = dict() | |||
| # Define Structs | |||
| # key is pattern_id, value is [ModuleStructs] | |||
| self.module_structs = dict() | |||
| @@ -83,7 +87,7 @@ class GlobalContext(metaclass=Singleton): | |||
| def get_identifier_from_onnx_node_name(self, node_name): | |||
| """Return the node identifier by Onnx Node name.""" | |||
| identifier = self.onnx_node_to_node_struct_map.get(node_name) | |||
| identifier = self.onnx_node_name_to_node_struct_map.get(node_name) | |||
| return identifier | |||
| @property | |||
| @@ -98,9 +102,7 @@ class GlobalContext(metaclass=Singleton): | |||
| @onnx_nodes_collection.setter | |||
| def onnx_nodes_collection(self, arg): | |||
| """ | |||
| Set the onnx nodes collection. | |||
| """ | |||
| """Set the onnx nodes collection.""" | |||
| if isinstance(arg, OrderedDict): | |||
| self._onnx_nodes_collection = arg # arg must be nodes_dict in OnnxDataLoader | |||
| else: | |||
| @@ -108,11 +110,18 @@ class GlobalContext(metaclass=Singleton): | |||
| @property | |||
| def onnx_nodes_topo_index(self) -> dict: | |||
| "Return the onnx nodes and topological index." | |||
| """Return the onnx nodes and topological index.""" | |||
| return self._onnx_nodes_topo_index | |||
| @onnx_nodes_topo_index.setter | |||
| def onnx_nodes_topo_index(self, index_list): | |||
| """ | |||
| Set the onnx nodes and topological index. | |||
| Args: | |||
| index_list (list[tuple[int, str]]): a list of tuple contains the topological index and onnx node name. | |||
| """ | |||
| if not isinstance(index_list, list): | |||
| raise TypeError("The argument index_list must be a list of tuple (index, onnx_node_name).") | |||
| if not isinstance(index_list[0], tuple): | |||
| @@ -122,10 +131,17 @@ class GlobalContext(metaclass=Singleton): | |||
| @property | |||
| def onnx_tensors_collection(self): | |||
| """Return the onnx tensors collection.""" | |||
| return self.onnx_tensors_collection | |||
| @onnx_tensors_collection.setter | |||
| def onnx_tensors_collection(self, arg): | |||
| """ | |||
| Set the onnx tensors collection by OnnxDataLoader. | |||
| Args: | |||
| arg (dict): The OnnxDataLoader generated tensors_dict. | |||
| """ | |||
| if isinstance(arg, dict): | |||
| self._onnx_tensors_collection = arg # arg must be tensors_dict in OnnxDataLoader | |||
| else: | |||
| @@ -133,6 +149,12 @@ class GlobalContext(metaclass=Singleton): | |||
| @property | |||
| def latest_node_struct_count(self): | |||
| """ | |||
| Return the latest node struct count. | |||
| Note: | |||
| The counter will increase by 1 to tracking the number of nodes added. | |||
| """ | |||
| ret = self.node_struct_adder_counter | |||
| self.node_struct_adder_counter += 1 | |||
| return ret | |||
| @@ -184,18 +206,29 @@ class GlobalContext(metaclass=Singleton): | |||
| self.module_customized_name[pattern_id] = customized_name | |||
| def get_node_fragment(self, identifier): | |||
| """Return the node fragment by identifier.""" | |||
| return self.node_fragments.get(identifier) | |||
| def add_code_fragment(self, identifier, frag): | |||
| """Add the node fragment by identifier.""" | |||
| self.node_fragments[identifier] = frag | |||
| def get_module_fragment(self, identifier): | |||
| """Return the module fragment by identifier.""" | |||
| return self.module_fragments.get(identifier) | |||
| def add_module_fragment(self, identifier, frag): | |||
| """Add the module fragment by identifier.""" | |||
| self.module_fragments[identifier] = frag | |||
| def add_module_struct(self, pattern_id, module_struct): | |||
| """ | |||
| Add module struct by its pattern_id. | |||
| Args: | |||
| pattern_id (int): The pattern which represents the structure of the module. | |||
| module_struct (ModuleStruct): The ModuleStruct instance. | |||
| """ | |||
| if self.module_structs.get(pattern_id) is None: | |||
| self.module_structs[pattern_id] = [module_struct] | |||
| else: | |||
| @@ -135,3 +135,18 @@ def lib_version_satisfied(current_ver: str, mini_ver_limited: str, | |||
| if current_ver < mini_ver_limited or (newest_ver_limited and current_ver > newest_ver_limited): | |||
| return False | |||
| return True | |||
| def get_dict_key_by_value(val, dic): | |||
| """ | |||
| Return the first appeared key of a dictionay by given value. | |||
| Args: | |||
| val (Any): Value of the key. | |||
| dic (dict): Dictionary to be checked. | |||
| Returns: | |||
| Any, key of the given value. | |||
| """ | |||
| for d_key, d_val in dic.items(): | |||
| if d_val == val: | |||
| return d_key | |||
| return None | |||
| @@ -0,0 +1,111 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Generator module.""" | |||
| __all__ = ["batch_add_nodes"] | |||
| import re | |||
| import copy | |||
| from .generator import Generator, CodeStruct | |||
| from ..common.code_fragment import CodeFragment | |||
| def _tf_model_node_name_reformat(node, node_name): | |||
| """ | |||
| Rename the node name by combining scope name and its original name. | |||
| Args: | |||
| node (OnnxGraphNode): OnnxGraphNode instance. | |||
| node_name (str): node name saved in Graph. | |||
| Returns: | |||
| str, re-formatted node name. | |||
| """ | |||
| scope_name = node.scope_name | |||
| new_name = None | |||
| regex = r"(?P<parent>.+/)(?P<op>\w+)" | |||
| match = re.match(regex, scope_name) | |||
| parent = match.group("parent") | |||
| node_name = '$' + node_name.replace('/', '::') + '$' | |||
| if scope_name: | |||
| new_name = parent + node_name | |||
| return new_name | |||
| return node_name | |||
| def batch_add_nodes(graph_obj, mapper) -> Generator: | |||
| """ | |||
| Add nodes to Generator in batch mode. | |||
| Args: | |||
| graph_obj (Graph): Graph obj. | |||
| mapper (Mapper): Mapper of third party framework and MindSpore. | |||
| """ | |||
| generator_inst = Generator() | |||
| for node_name in graph_obj.nodes_in_topological_order: | |||
| node_inst = graph_obj.get_node(node_name) | |||
| node_input = graph_obj.get_input_shape(node_name) | |||
| node_output = graph_obj.get_output_shape(node_name) | |||
| if not node_input: | |||
| raise ValueError("Unable to get the node's inputs from Graph object.") | |||
| node_name_with_scope = _tf_model_node_name_reformat(node_inst, node_name) | |||
| node_name = node_name_with_scope | |||
| node_inst.add_input_and_output_shape(node_input, node_output) | |||
| op_name, params, settings, weights = _convert_params(node_inst, mapper) | |||
| generator_inst.add_node( | |||
| node_name, | |||
| node_instance=node_inst, | |||
| node_fragment=CodeFragment(op_name, params, | |||
| settings, | |||
| node_inst.input_shape, | |||
| node_inst.output_shape, | |||
| weights) | |||
| ) | |||
| return generator_inst | |||
| def _convert_params(node, mapper): | |||
| """ | |||
| Call mapper to convert node's params from ONNX to MindSpore. | |||
| Args: | |||
| node (GraphNode): Our defined GraphNode instance. | |||
| mapper (Mapper): The mapper instance which indicating conversion method. | |||
| Returns: | |||
| str, op name in MindSpore | |||
| dict, MindSpore parameters | |||
| dict, MindSpore settings | |||
| dict, weights of the node | |||
| """ | |||
| params = copy.deepcopy(node.node_params) | |||
| params.update({"input_shape": node.input_shape, | |||
| "output_shape": node.output_shape}) | |||
| op_in_ms, ms_params, ms_settings, weights = mapper.convert(op_name=node.op_name, | |||
| params=params, | |||
| weights=node.weight) | |||
| if "input_shape" in ms_params: | |||
| ms_params.pop("input_shape") | |||
| if "output_shape" in ms_params: | |||
| ms_params.pop("output_shape") | |||
| if op_in_ms: | |||
| return op_in_ms, ms_params, ms_settings, weights | |||
| return node.op_name, node.node_params, dict(), dict() | |||
| @@ -0,0 +1,248 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Define arguments translation related operations for params name changing.""" | |||
| import copy | |||
| class ArgsTranslation: | |||
| """Define a universal arguments translation manager.""" | |||
| def __init__(self, original_actual_args: dict, var_name: str, translated_args: list): | |||
| """ | |||
| Init the ArgsTranslation. | |||
| Args: | |||
| original_actual_args (dict): The full original args from fragments. | |||
| var_name (str): The var name for current Node / Module. | |||
| translated_args (list): The list of args need to translate to formal args. | |||
| """ | |||
| if not var_name: | |||
| raise ValueError("Initialize ArgsTranslation requires the var_name.") | |||
| self.var_name = var_name | |||
| self.actual_args = dict() # e.g. key is 'num_features', value is 2048 | |||
| self.formal_args = dict() # e.g. key is 'num_features', value is 'var_name_num_features'} | |||
| self.formal_args_values = dict() # e.g. key 'var_name_num_features', value 2048. Value use for up-level | |||
| self.actual_args_backup = dict() # backup actual args before translation | |||
| self.actual_args_to_str_list = list() | |||
| self.formal_args_to_str_list = list() | |||
| self.formal_args_values_to_str_list = list() | |||
| self.actual_args_backup_to_str_list = list() | |||
| if all([original_actual_args, translated_args]): | |||
| # MUST ensure only one var_name in a scope. | |||
| for arg_name, arg_value in original_actual_args.items(): | |||
| if arg_name in translated_args: | |||
| formal_arg_name = '_'.join([var_name, arg_name]) | |||
| self.formal_args[arg_name] = formal_arg_name | |||
| self.formal_args_values[formal_arg_name] = arg_value | |||
| else: | |||
| self.actual_args[arg_name] = arg_value | |||
| self.make_str() | |||
| @staticmethod | |||
| def dict_data_to_args_str_list(any_dict): | |||
| """ | |||
| Output a list of string of dict data by "key=value" format. | |||
| Args: | |||
| any_dict (dict): Any dictionary | |||
| Returns: | |||
| list, the list of strings showing dictionary as "key=value" format. | |||
| """ | |||
| ret = [] | |||
| for key, val in any_dict.items(): | |||
| ret.append('='.join([key, str(val)])) | |||
| return ret | |||
| def make_str(self): | |||
| """Make string used in code generation.""" | |||
| self.actual_args_to_str_list = list() | |||
| self.formal_args_to_str_list = list() | |||
| self.formal_args_values_to_str_list = list() | |||
| self.actual_args_backup_to_str_list = list() | |||
| if self.actual_args: | |||
| self.actual_args_to_str_list = ArgsTranslation.dict_data_to_args_str_list(self.actual_args) | |||
| if self.formal_args: | |||
| self.formal_args_to_str_list = ArgsTranslation.dict_data_to_args_str_list(self.formal_args) | |||
| if self.formal_args_values: | |||
| self.formal_args_values_to_str_list = ArgsTranslation.dict_data_to_args_str_list(self.formal_args_values) | |||
| if self.actual_args_backup: | |||
| self.actual_args_backup_to_str_list = ArgsTranslation.dict_data_to_args_str_list(self.actual_args_backup) | |||
| def __repr__(self): | |||
| return str({ | |||
| "address": hex(id(self)), | |||
| "var_name": self.var_name, | |||
| "actual_args": self.actual_args, | |||
| "actual_bak": self.actual_args_backup, | |||
| "formal_args": self.formal_args, | |||
| "formal_val ": self.formal_args_values | |||
| }) | |||
| def set_actual_args_backup(self): | |||
| """Backup the actual args before translating to formal.""" | |||
| self.actual_args_backup = copy.deepcopy(self.actual_args) | |||
| def deepcopy(self): | |||
| """Return a deepcopy of self.""" | |||
| return copy.deepcopy(self) | |||
| def make_actual_arg_to_formal(self, actual_arg_name): | |||
| """ | |||
| Make the actual arg to a formal arg. | |||
| Args: | |||
| actual_arg_name (str): The name of the actual arg to be formal. | |||
| """ | |||
| val = self.actual_args.get(actual_arg_name) | |||
| if val is None: | |||
| raise ValueError("Unable to convert the actual arg to formal due to missing arg.") | |||
| formal_arg_name = ('_').join([self.var_name, actual_arg_name]) | |||
| self.actual_args.pop(actual_arg_name) | |||
| self.formal_args[actual_arg_name] = formal_arg_name | |||
| self.formal_args_values[formal_arg_name] = val | |||
| self.make_str() | |||
| def _update_dict_for_upper_level(self, d, upper_level_var_name): | |||
| """Add upper level var name to key name of selected dictionary.""" | |||
| new_d = dict() | |||
| for arg_name, val in d.items(): | |||
| new_arg_name = '_'.join([upper_level_var_name, arg_name]) # e.g. conv2d_0_in_channels_Module_3_0 | |||
| new_d[new_arg_name] = val | |||
| return new_d | |||
| def escalate_to_upper_level(self, upper_level_var_name): | |||
| """ | |||
| Escalate this args translator for upper level module use. | |||
| Note: | |||
| You MUST deepcopy this translator first to avoid editing values in the original translator. | |||
| """ | |||
| # update all args name by adding upper_level_var_name. | |||
| tmp_actual_args = self._update_dict_for_upper_level(self.actual_args, upper_level_var_name) | |||
| tmp_formal_args = self._update_dict_for_upper_level(self.formal_args, upper_level_var_name) | |||
| tmp_formal_args_values = self._update_dict_for_upper_level(self.formal_args_values, upper_level_var_name) | |||
| self.actual_args = tmp_actual_args | |||
| self.formal_args = tmp_formal_args | |||
| self.formal_args_values = tmp_formal_args_values | |||
| self.make_str() | |||
| def make_formal_args_back_to_actual(self, formal_arg): | |||
| """ | |||
| Move the formal arg back to actual arg. | |||
| Note: | |||
| This does not reset the formal arg name back, | |||
| Only used for module init statement. | |||
| Args: | |||
| formal_arg (str): formal argument name. | |||
| """ | |||
| if isinstance(formal_arg, str): | |||
| val = self.formal_args_values.pop(formal_arg) | |||
| self.actual_args[formal_arg] = val | |||
| if isinstance(formal_arg, list): | |||
| for arg in formal_arg: | |||
| val = self.formal_args_values.pop(arg) | |||
| self.actual_args[formal_arg] = val | |||
| self.make_str() | |||
| def take_formal_args_from_args_translator(self, args_translator, escalate_sub=False): | |||
| """ | |||
| Add submodule's or node's args translator to this translator. | |||
| Args: | |||
| args_translator (ArgsTranslation): submodule's or node's args translator. | |||
| """ | |||
| if escalate_sub: | |||
| sub_args_translator = args_translator.deepcopy() | |||
| sub_args_translator.escalate_to_upper_level(self.var_name) | |||
| else: | |||
| sub_args_translator = args_translator | |||
| original_actual_args = sub_args_translator.formal_args_values | |||
| self.actual_args.update(original_actual_args) | |||
| self.make_str() | |||
| def take_formal_args_from_nodes_and_submodules(self, args_translators: list, escalate_sub=False): | |||
| """ | |||
| Take all formal args from nodes and submodules from passed in args_translators. | |||
| Args: | |||
| args_translators (ArgsTranslation): A list of ArgsTranslation instances. | |||
| escalate_sub (Bool): should escalate all formal args. Default: False | |||
| """ | |||
| for arg_t in args_translators: | |||
| self.take_formal_args_from_args_translator(arg_t, escalate_sub=escalate_sub) | |||
| class ArgsTranslationHelper: | |||
| """Define operations related to ArgsTranslation instances.""" | |||
| @staticmethod | |||
| def find_formal_args_in_modules(args_translators): | |||
| """ | |||
| Find formal args among multiple args translators. | |||
| Args: | |||
| args_translators(list[ArgsTranslation]): a list of args translator to be checked. | |||
| Returns: | |||
| list, name of args to be formal. | |||
| """ | |||
| if len(args_translators) < 2: | |||
| # only one args_translator provided, no formal args. | |||
| return None | |||
| ret = [] | |||
| base_args_t = args_translators[0] | |||
| for arg_name, arg_val in base_args_t.actual_args.items(): | |||
| for args_t in args_translators[1:]: | |||
| val = args_t.actual_args.get(arg_name) | |||
| if val is None: | |||
| raise ValueError("Unable to find the given args as the args translator is not consistent.") | |||
| if val != arg_val: # val not equal | |||
| ret.append(arg_name) | |||
| break | |||
| return ret | |||
| @staticmethod | |||
| def change_args_to_formal_for_all_translators(args_name, args_translators): | |||
| """ | |||
| Change args to formal for all translators provided. | |||
| Args: | |||
| args_name (str): The name of args to be changing. | |||
| args_translators (ArgsTranslation): The args to be changed in args translators. | |||
| """ | |||
| if isinstance(args_name, str): | |||
| args_name = [args_name] | |||
| if isinstance(args_translators, ArgsTranslation): | |||
| args_translators = [args_translators] | |||
| for arg in args_name: | |||
| for args_t in args_translators: | |||
| args_t.set_actual_args_backup() | |||
| args_t.make_actual_arg_to_formal(arg) | |||
| @@ -0,0 +1,630 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Main Generator module.""" | |||
| import copy | |||
| from collections import OrderedDict | |||
| from .scope_utils import Scope | |||
| from .node_struct import NodeStruct | |||
| from .module_struct import ModuleStruct | |||
| from .args_translator import ArgsTranslationHelper | |||
| from ..common.global_context import GlobalContext | |||
| from ..hierarchical_tree.name_mgr import GlobalVarNameMgr | |||
| from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT | |||
| class Singleton(type): | |||
| """Metaclass to make the generator to be single instance.""" | |||
| _instances = {} | |||
| def __call__(cls, *args, **kwargs): | |||
| if cls not in cls._instances: | |||
| cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) | |||
| return cls._instances[cls] | |||
| class CodeStruct: | |||
| """ | |||
| Define the Code template for each module generated in the final output. | |||
| Each module has only one CodeStruct to its pattern. | |||
| """ | |||
| GLOBAL_CONTEXT = GlobalContext() | |||
| NOT_IN_SCOPE_OPT = dict() | |||
| def __init__(self, struct, repeated_submodules=None): | |||
| """Initialize the CodeStruct.""" | |||
| self.output_order = None # output order | |||
| self.input = None # opt_var_name for prev. node | |||
| self.extra_input = list() # extra_input(s) at construct method args | |||
| self.output = None # opt_var_name for next node | |||
| self.extra_output = list() # extra_output(s) | |||
| self.extra_comment = None # comments for this code line / block. | |||
| self.code_line_list = list() # list of code line, a item is a line. | |||
| self._global_var_mgr = GlobalVarNameMgr() # var name procs within same module | |||
| self.formal_args_collections = None | |||
| if isinstance(struct, NodeStruct): | |||
| self.output_order = struct.topo_idx | |||
| if isinstance(struct, ModuleStruct): | |||
| self.output_order = struct.head_nd_struct_index | |||
| self._generate_from_module_struct(struct, repeated_submodules) | |||
| def _add_line(self, s): | |||
| """Add line of code.""" | |||
| self.code_line_list.append(s) | |||
| @property | |||
| def new_line(self): | |||
| """Return last generated line.""" | |||
| try: | |||
| return self.code_line_list[-1] | |||
| except IndexError: | |||
| return "" | |||
| @new_line.setter | |||
| def new_line(self, s): | |||
| """Make a new line.""" | |||
| self._add_line(s) | |||
| def _generate_from_module_struct(self, md_struct, repeated_submodules): | |||
| """ | |||
| Generate the code of current Module Struct, collecting data from submodules. | |||
| Args: | |||
| md_struct (ModuleStruct): The ModuleStruct which generates codes. | |||
| repeated_submodules (dict): The dict contains all submodules which use repeatedly. | |||
| Can get this dict from generator. | |||
| """ | |||
| # Define tmp var for code generation. | |||
| opt_var_name_records = dict() # now only support multiple outputs within same scope. | |||
| return_value_records = dict() # save returned values for successor nodes/modules use. | |||
| # Define Module header code line below | |||
| if md_struct.pattern_id != -1: | |||
| class_name = f"Module{md_struct.pattern_id}" | |||
| else: | |||
| class_name = "Model" | |||
| # define a class declaration | |||
| self.new_line = f"class {class_name}(nn.Cell):" | |||
| # Get all formal args from nodes | |||
| module_def_args = ['self'] | |||
| if md_struct.args_translator.actual_args: | |||
| for actual in md_struct.args_translator.actual_args.keys(): | |||
| module_def_args.append(actual) | |||
| if md_struct.args_translator.formal_args: | |||
| for formal in md_struct.args_translator.formal_args.keys(): | |||
| module_def_args.append(formal) | |||
| # Collect extra inputs and outputs | |||
| # For code line in init & construct blocks | |||
| init_lines = list() | |||
| cons_lines = list() | |||
| for (_, struct) in md_struct.get_generate_order(): | |||
| if isinstance(struct, NodeStruct): # Generate code line for Node. | |||
| code_line_init = struct.code_line_in_init() | |||
| code_line_construct = struct.code_line_in_construct(in_module_returns=return_value_records) | |||
| init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_init)}") | |||
| cons_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_construct)}") | |||
| # add extra tensor | |||
| if struct.fragment.code_setting and struct.fragment.code_setting.op_extra_tensor: | |||
| code_extra_tensor = struct.add_extra_tensor() | |||
| init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_extra_tensor)}") | |||
| # record opt_var_name for succ nodes input in same scope. | |||
| target_onnx_name = struct.graph_node_ref.successor_nodes | |||
| for name in target_onnx_name: | |||
| if opt_var_name_records.get(name): | |||
| opt_var_name_records.get(name).append(code_line_construct[0]) | |||
| else: | |||
| opt_var_name_records[name] = [code_line_construct[0]] | |||
| if struct.successor_nodes_names_external: | |||
| for ret_user in struct.successor_nodes_names_external: | |||
| if return_value_records.get(ret_user) is not None: | |||
| return_value_records[ret_user].append((struct.onnx_name, code_line_construct[0])) | |||
| else: | |||
| return_value_records[ret_user] = [(struct.onnx_name, code_line_construct[0])] | |||
| elif isinstance(struct, ModuleStruct): | |||
| # check if this instance generated CodeStruct | |||
| if self.GLOBAL_CONTEXT.code_structs.get(struct.pattern_id) is None: | |||
| CodeStruct(struct, repeated_submodules) | |||
| code_line_init = struct.code_line_in_init() | |||
| code_line_construct = struct.code_line_in_construct(inputs=struct.matched_inputs) | |||
| init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_init)}") | |||
| cons_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_construct)}") | |||
| # record opt_var_name for succ nodes input in same scope. | |||
| target_onnx_name = struct.tail_nd_struct.graph_node_ref.successor_nodes | |||
| for name in target_onnx_name: | |||
| if opt_var_name_records.get(name): | |||
| opt_var_name_records.get(name).append(code_line_construct[0]) | |||
| else: | |||
| opt_var_name_records[name] = [code_line_construct[0]] | |||
| # record submodule's local return map for following nodes / submodules use | |||
| if struct.external_successor_local_returns_map: | |||
| for ret_user, _ in struct.external_successor_local_returns_map.items(): | |||
| if return_value_records.get(ret_user) is not None: | |||
| # mulitple returns of a node may need modifiy the index. | |||
| return_value_records[ret_user].append((struct.identifier, code_line_construct[0])) | |||
| else: | |||
| return_value_records[ret_user] = [(struct.identifier, code_line_construct[0])] | |||
| else: | |||
| raise TypeError("Unable to generate code from args are not ModuleStruct or NodeStruct.") | |||
| # define header of init block | |||
| self.new_line = f"{FIRST_LEVEL_INDENT}def __init__({', '.join(module_def_args)}):" | |||
| self.new_line = f"{SECOND_LEVEL_INDENT}super({class_name}, self).__init__()" | |||
| # add init code lines to code line list. | |||
| self.code_line_list += init_lines | |||
| self.new_line = f"{NEW_LINE * 2}" | |||
| # define header of construct block | |||
| inputs = ['self'] + list(md_struct.construct_header_x.keys()) | |||
| self.new_line = f"{FIRST_LEVEL_INDENT}def construct({', '.join(inputs)}):" | |||
| # add construct code lines to code line list. | |||
| self.code_line_list += cons_lines | |||
| # define returns | |||
| returns = [] | |||
| if md_struct.external_successor_local_returns_map: | |||
| ret = list(md_struct.external_successor_local_returns_map.values()) | |||
| for r in ret: | |||
| if isinstance(r, tuple): # results return with index nth output | |||
| returns.append(r[0]) | |||
| else: | |||
| returns.append(r) | |||
| returns = list(set(returns)) | |||
| else: | |||
| returns = [code_line_construct[0]] | |||
| self.new_line = f"{SECOND_LEVEL_INDENT}return {', '.join(returns)}" | |||
| self.new_line = f"{NEW_LINE * 2}" | |||
| self.GLOBAL_CONTEXT.code_structs[md_struct.pattern_id] = self | |||
| class Generator(metaclass=Singleton): | |||
| """The generator controls all routines of code generation.""" | |||
| def __init__(self): | |||
| """Init the generator.""" | |||
| # define basic attributes | |||
| self.framework = None | |||
| # define MUST have params | |||
| self._node_struct_collections = OrderedDict() | |||
| self._module_struct_collections = OrderedDict() | |||
| self._module_depth_max = 0 | |||
| self._module_depth_min = 0 | |||
| # define intermediate var. during conversion | |||
| self._module_map = OrderedDict() | |||
| self._global_context = GlobalContext() | |||
| self._global_context.node_struct_collections = self._node_struct_collections | |||
| self._repeated_submodules = set() | |||
| def _form_bottom_submodule(self): | |||
| """Form the basic submodules, which only contains nodes.""" | |||
| # Form module map | |||
| curr_scope_path = None | |||
| nd_struct_list_in_submodule = [] | |||
| for nd_struct in self.node_structs.values(): | |||
| idx = nd_struct.topo_idx | |||
| if curr_scope_path is None: | |||
| curr_scope_path = nd_struct.scope.path | |||
| nd_struct_list_in_submodule.append((idx, nd_struct)) | |||
| elif curr_scope_path == nd_struct.scope.path: | |||
| nd_struct_list_in_submodule.append((idx, nd_struct)) | |||
| else: # curr_scope_path changed | |||
| # save this submodule | |||
| if self._module_map.get(str(curr_scope_path)) is not None: | |||
| self._module_map[str(curr_scope_path)] += nd_struct_list_in_submodule | |||
| else: | |||
| self._module_map[str(curr_scope_path)] = nd_struct_list_in_submodule | |||
| # create a new one | |||
| curr_scope_path = nd_struct.scope.path | |||
| nd_struct_list_in_submodule = [(idx, nd_struct)] | |||
| # save last submodule | |||
| if self._module_map.get(str(curr_scope_path)) is not None: | |||
| self._module_map[str(curr_scope_path)] += nd_struct_list_in_submodule | |||
| else: | |||
| self._module_map[str(curr_scope_path)] = nd_struct_list_in_submodule | |||
| # Form bottom modules' ModuleStruct | |||
| for scope_path_str, nd_struct_list in self._module_map.items(): | |||
| self._module_struct_collections[scope_path_str] = ModuleStruct(nd_struct_list) | |||
| def _list_repeated_submodules(self) -> OrderedDict: | |||
| """ | |||
| Return the repeated submodules by its depth and num. | |||
| For example, "Model/Module3_3" will return {1:(3)} | |||
| Return: | |||
| OrderedDict, a dict contains collections of repeated submodules. | |||
| """ | |||
| ret = OrderedDict() | |||
| for depth_control in range(self._module_depth_max, 0, -1): | |||
| repeated_submodules_at_this_depth = set() | |||
| for scope_path in self._module_map.keys(): | |||
| path = Scope.path_str_to_list(scope_path) | |||
| if len(path) < depth_control: | |||
| continue | |||
| else: # depth control within path length | |||
| module_num = path[depth_control - 1][0] | |||
| repeated_submodules_at_this_depth.add(module_num) | |||
| ret[depth_control] = repeated_submodules_at_this_depth | |||
| self._repeated_submodules = ret | |||
| return ret | |||
| def _compare_with_base_parameters(self, nd_struct_list): | |||
| """ | |||
| Compare the parameter to check if it should be a formal args. | |||
| Args: | |||
| nd_struct_list (list): A list of NodeStructs which contains | |||
| all same nodes in repeated submodules. | |||
| Return: | |||
| set, a set of all formal args in this node. | |||
| """ | |||
| formal_args = set() | |||
| if len(nd_struct_list) < 2: | |||
| return formal_args | |||
| (_, base_nd_struct) = nd_struct_list[0] | |||
| for (base_parameter, base_value) in base_nd_struct.fragment.actual_args.items(): # for each param | |||
| for (_, nd_struct) in nd_struct_list[1:]: | |||
| compared_value = nd_struct.fragment.actual_args.get(base_parameter) | |||
| if compared_value == base_value: | |||
| continue | |||
| else: | |||
| formal_args.add(base_parameter) | |||
| break | |||
| return formal_args | |||
| def _list_formal_parameters_in_a_module(self, module_filter_return): | |||
| """ | |||
| Find all formal args / params from nodes in a module. | |||
| Args: | |||
| module_filter_return (dict): The filtered results from the module_map_filter. | |||
| Return: | |||
| list, a list of sets or None indicates all formal args of each node in the module in order. | |||
| """ | |||
| formal_params_list = list() | |||
| transposed = [list(e) for e in zip(*module_filter_return)] | |||
| for operation in transposed: | |||
| formal_parameters = self._compare_with_base_parameters(operation) | |||
| if formal_parameters: | |||
| formal_params_list.append(formal_parameters) | |||
| else: | |||
| formal_params_list.append(None) | |||
| return formal_params_list | |||
| def _list_formal_parameters(self, repeated_submodules) -> dict: | |||
| """ | |||
| Return a list of formal parameters in each submodule. | |||
| Args: | |||
| repeated_submodules (dict): A dict which contains repeated submodules, | |||
| acquire this dict from _list_repeated_submodules() | |||
| Return: | |||
| OrderedDict, a dict with each submodule's formal args. | |||
| Example: | |||
| A return for ResNet50 could be: | |||
| {0: # submoodule 0 | |||
| [set('stride', 'in_channels', 'out_channels'), # args of the first node to be set as formal | |||
| set('num_features'), # args of the second node to be set as formal | |||
| None, # args of third node to be set as formal, which does not have | |||
| set('in_channels', 'out_channels'), | |||
| set('num_features'), | |||
| None | |||
| ]}, | |||
| {3: # submodule 3 | |||
| [...], | |||
| {5: # submodule 5 | |||
| []} # empty returns means no nodes or it's a parent module of submodules. | |||
| } | |||
| """ | |||
| formal_args_in_each_submodule = OrderedDict() | |||
| checked_module = set() | |||
| # filter module_map by submodule_num (without depth) | |||
| for _, module_nums in repeated_submodules.items(): | |||
| for module_num in module_nums: | |||
| if module_num in checked_module: # module already checked | |||
| continue | |||
| else: | |||
| checked_module.add(module_num) | |||
| map_filtered = self.module_map_filter(module_num=module_num) | |||
| formal_args_in_this_module = self._list_formal_parameters_in_a_module(map_filtered) | |||
| formal_args_in_each_submodule[module_num] = formal_args_in_this_module | |||
| return formal_args_in_each_submodule | |||
| def _add_submodule_to_parent(self): | |||
| """ | |||
| Recursively add all submodule to its parent module until Main module. | |||
| Note: | |||
| This function deepcopy the first node of the submodule, and reset its params as parent module. | |||
| """ | |||
| depth = self._module_depth_max | |||
| while depth > 0: | |||
| for (scope_path_str, md_struct) in self.module_structs.copy().items(): | |||
| if scope_path_str == '[]': | |||
| continue # is main module, skip | |||
| if md_struct.scope_depth != depth: | |||
| continue # skip all submodules not at current depth | |||
| md_struct_scope = copy.deepcopy(md_struct.identifier) | |||
| md_struct_scope.pop() | |||
| parent_scope = md_struct_scope | |||
| # 1. check if this module has parent module | |||
| parent_md_struct = self.module_structs.get(str(parent_scope)) | |||
| if parent_md_struct is not None: | |||
| # 1A. has parent, directly add md_struct to its parent ModuleStruct. | |||
| parent_md_struct.add_submodule(md_struct) | |||
| self.module_structs[str(parent_scope)] = parent_md_struct | |||
| else: | |||
| # 1B. not has parent, generate a new ModuleStruct | |||
| parent_md_struct = copy.deepcopy(md_struct) # use this submodule to create a parent module | |||
| # rewrite parent md struct | |||
| parent_md_struct.reset_as_parent() | |||
| parent_md_struct.add_submodule(md_struct) | |||
| self.module_structs[str(parent_scope)] = parent_md_struct | |||
| sub = self.module_structs.pop(scope_path_str) # remove this submodule from collections | |||
| self._global_context.add_module_struct(sub.pattern_id, sub) | |||
| depth -= 1 | |||
| def _recursive_form_module(self): | |||
| """Main routine in generator to build modules from bottom to top.""" | |||
| # 1. List repeated submodules | |||
| repeated_submodules = self._list_repeated_submodules() | |||
| # 2. List reused parameters | |||
| formal_parameters = self._list_formal_parameters(repeated_submodules) | |||
| # 3. Build base subdmodules and set in/ext params translation | |||
| for module_struct in self.module_structs.values(): | |||
| if module_struct.pattern_id == -1: # is main module | |||
| continue | |||
| formal_args = formal_parameters.get(module_struct.pattern_id) | |||
| module_struct.update_args_translation_list(formal_args) | |||
| # 4. Form parent modules | |||
| md_collection_len = len(self.module_structs.keys()) | |||
| len_changes = True | |||
| while len_changes: | |||
| self._add_submodule_to_parent() | |||
| new_len = len(self.module_structs.keys()) | |||
| if md_collection_len != new_len: | |||
| md_collection_len = new_len | |||
| else: | |||
| len_changes = False | |||
| # 5. Update all translated args from module map | |||
| self._update_all_modules_args_translator() | |||
| # 6. Update all nodes and moudles input/output | |||
| self.module_structs.get('[]').allocate_construct_header_x() | |||
| self.module_structs.get('[]').collect_returns() | |||
| def _update_all_modules_args_translator(self): | |||
| """Update all modules' args translators.""" | |||
| done_submodule = set() | |||
| for depth in range(self._module_depth_max, 0, -1): | |||
| # check modules from bottom to top | |||
| repeated_submodules = copy.deepcopy(self._repeated_submodules) | |||
| repeated_modules = repeated_submodules.get(depth) | |||
| if depth is None: | |||
| continue | |||
| for pattern_id in repeated_modules: | |||
| if pattern_id in done_submodule: | |||
| continue | |||
| # get all md_structs by same pattern | |||
| md_list = self._global_context.module_structs.get(pattern_id) | |||
| self._take_formal_args_from_updated_submodules(md_list) | |||
| args_translators = self.get_args_translator_from_module_structs_list(md_list) | |||
| formal_args_list = ArgsTranslationHelper.find_formal_args_in_modules(args_translators) | |||
| changed_args_translators = self.get_args_translator_from_module_structs_list( | |||
| md_list, exclude_root_son=True) | |||
| ArgsTranslationHelper.change_args_to_formal_for_all_translators( | |||
| formal_args_list, changed_args_translators) | |||
| done_submodule.add(pattern_id) | |||
| def _take_formal_args_from_updated_submodules(self, md_list): | |||
| """ | |||
| Take formal args from provided modules' nodes and submodules. | |||
| Args: | |||
| md_list (list): A list of ModuleStruct. | |||
| """ | |||
| if isinstance(md_list, ModuleStruct): | |||
| md_list = [md_list] | |||
| for md in md_list: | |||
| md.args_translator.take_formal_args_from_nodes_and_submodules(md.get_all_sub_translators()) | |||
| def _update_module_depth_max(self, nd_struct: NodeStruct): | |||
| """ | |||
| Update the Generator attribute module_depth_max, which is the maximum depth in the Model. | |||
| Args: | |||
| nd_struct (NodeStruct): NodeStruct to be checked its depth. | |||
| """ | |||
| depth = nd_struct.scope.depth | |||
| if isinstance(depth, int): | |||
| if depth > self._module_depth_max: | |||
| self._module_depth_max = depth | |||
| else: | |||
| raise TypeError("Unable to update global depth due to TypeError in NodeStruct.scope.depth") | |||
| def add_node(self, node_identifier, node_instance=None, node_fragment=None, mapper_dict=None): | |||
| """ | |||
| Add Node information to the generator. | |||
| Args: | |||
| node_identifier (str): The unique identifier for the node passed in. | |||
| node_instance (GraphNode): The GraphNode instance of each node. | |||
| node_fragment (NodeFragment): The NodeFragment instance of this node passed in. | |||
| mapper_dict (dict): The dict contains converted attributes from mapper. | |||
| """ | |||
| if node_identifier is None: | |||
| raise ValueError("Node Identifier should not be None.") | |||
| self._global_context.node_fragments[node_identifier] = node_fragment | |||
| args = [] | |||
| if node_instance is not None: | |||
| args.append(node_instance) | |||
| if mapper_dict is not None: | |||
| args.append(mapper_dict) | |||
| if node_fragment is not None: | |||
| args.append(node_fragment) | |||
| nd_struct = self.node_structs.get(node_identifier) | |||
| if nd_struct: # NodeStruct already exists | |||
| nd_struct.update(args) | |||
| else: # create new Node Struct | |||
| nd_struct = NodeStruct(args) | |||
| nd_struct.identifier = node_identifier | |||
| self._update_module_depth_max(nd_struct) | |||
| self.node_structs[node_identifier] = nd_struct | |||
| @property | |||
| def node_structs(self): | |||
| """Return all NodeStructs in this model.""" | |||
| return self._node_struct_collections | |||
| @property | |||
| def module_structs(self): | |||
| """Return all ModuleStructs in this model.""" | |||
| return self._module_struct_collections | |||
| def generate(self): | |||
| """ | |||
| Generate the final script file. | |||
| Returns: | |||
| list, a list of each line in script file. | |||
| """ | |||
| self._form_bottom_submodule() | |||
| self._recursive_form_module() | |||
| code = CodeStruct(self.module_structs.get('[]'), self._repeated_submodules) | |||
| return code.code_line_list | |||
| def get_node_struct(self, node_identifier): | |||
| """ | |||
| Get specific NodeStruct by node_identifier. | |||
| Args: | |||
| node_identifier (str): The node unique identifier. | |||
| Return: | |||
| NodeStruct, the node's NodeStruct. | |||
| """ | |||
| return self._node_struct_collections.get(node_identifier, None) | |||
| def get_module_struct(self, module_identifier): | |||
| """ | |||
| Get specific ModuleStruct by module_identifier. | |||
| Args: | |||
| module_identifier (str): The module unique identifier. | |||
| Return: | |||
| ModuleStruct, the node's ModuleStruct. | |||
| """ | |||
| return self._module_struct_collections.get(module_identifier, None) | |||
| def get_module_structs_by_pattern_under_same_parent_pattern(self, pattern_id, under_parent_pattern_id): | |||
| """ | |||
| Return a list of ModuleStruct by conditions of pattern and their parent parent's pattern. | |||
| Args: | |||
| pattern_id (int): The pattern id the returned ModuleSturct is. | |||
| under_parent_pattern_id (int): The pattern id the returned ModuleStruct's parent is. | |||
| Returns: | |||
| list, a list of MoudleStructs has the same pattern_id and the same parents' pattern. | |||
| """ | |||
| if not pattern_id: | |||
| raise ValueError("pattern_id is necessary to get the module struct.") | |||
| if not under_parent_pattern_id: | |||
| raise ValueError("under_parent_pattern_id is necessary to get the module struct.") | |||
| ret = [] | |||
| md_list = self._global_context.module_structs.get(pattern_id) | |||
| for md in md_list: | |||
| if md.parent_id == under_parent_pattern_id: | |||
| ret.append(md) | |||
| return ret | |||
| def get_args_translator_from_module_structs_list(self, md_list, exclude_root_son=False): | |||
| """ | |||
| Return a list of args translators which belongs to given module structs. | |||
| Args: | |||
| md_list (list): A list of ModuleStruct. | |||
| exclude_root_son (Bool): If the returned result should include args translator belongs to | |||
| modules under the Main module. | |||
| Returns: | |||
| list, a list of args translators which belongs to given module structs. | |||
| """ | |||
| ret = [] | |||
| for md in md_list: | |||
| if exclude_root_son and md.parent_id == -1: | |||
| continue | |||
| if md.args_translator is not None: | |||
| ret.append(md.args_translator) | |||
| return ret | |||
| def module_map_filter(self, depth=None, module_num=None, uid=None): | |||
| """ | |||
| Filter the module map by given conditions. | |||
| Args: | |||
| depth (int): Scope depth. | |||
| module_num (int): The submodule number. | |||
| uid (int): The unique identifier of a submodule. | |||
| Return: | |||
| list, list of NodeStruct list of each submodule. | |||
| """ | |||
| ret = list() | |||
| for scope_path, nd_struct_list in self._module_map.items(): | |||
| path = Scope.path_str_to_list(scope_path) | |||
| if not path: # skip main | |||
| continue | |||
| # if depth not equals to the indicated depth, skip | |||
| if depth is not None and len(path) != depth: | |||
| continue | |||
| scope_at_depth = path[-1] | |||
| (m_num, m_uid) = scope_at_depth | |||
| if uid is not None: | |||
| if m_num == module_num and m_uid == uid: | |||
| ret.append(nd_struct_list) | |||
| else: | |||
| if m_num == module_num: | |||
| ret.append(nd_struct_list) | |||
| return ret | |||
| @@ -0,0 +1,710 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Define a struct for module converted and save all required information here.""" | |||
| from collections import OrderedDict | |||
| from .node_struct import NodeStruct | |||
| from .scope_utils import Scope | |||
| from ..common.utils import get_dict_key_by_value | |||
| from .args_translator import ArgsTranslation | |||
| from ..common.code_fragment import ModuleFragment | |||
| from ..common.global_context import GlobalContext | |||
| from ..hierarchical_tree.name_mgr import LocalVarNameMgr | |||
| class ModuleStruct: | |||
| """ | |||
| Define a module struct which stores all info. to generate statement. | |||
| Args: | |||
| args (list): A list of node structs. | |||
| """ | |||
| GLOBAL_CONTEXT_MGR = GlobalContext() | |||
| def __init__(self, nd_struct_list): | |||
| """Init. a module by NodeStructs.""" | |||
| self.pattern_id = -1 # pattern num, -1 as Main module | |||
| self.pattern_uid = -1 # unique module id for this pattern | |||
| self.parent_id = None # parent's pattern num | |||
| self.parent_uid = None # parent's pattern module unique id | |||
| self.initialized = False | |||
| self.identifier = None | |||
| self.module_name = None | |||
| self.scope_depth = None | |||
| self.head_nd_struct = None | |||
| self.head_nd_struct_index = None | |||
| self.tail_nd_struct = None | |||
| self.tail_nd_struct_index = None | |||
| self._node_structs = list() | |||
| self._module_structs = list() | |||
| self._fragment = None | |||
| self._args_translator = None | |||
| self._setting = None | |||
| self._parent_module_struct = None | |||
| # only store original formal args name, not global | |||
| self._nodes_structs_formal_args_list = list() | |||
| # only store translated (globalized) formal args | |||
| self._nodes_structs_formal_args_translated_list = list() | |||
| # define other settings here | |||
| self._node_args_translation_list = list() | |||
| self._var_name_mgr = LocalVarNameMgr() | |||
| self.construct_header_x = OrderedDict() # key is header x, value is precursors onnx name | |||
| self.inputs_in_construct_header = OrderedDict() # key is precursors onnx name, value is x in parent construct | |||
| self.inputs_in_parent_module = OrderedDict() # key is prec_node_name, value is its closet opt_var_name | |||
| # key is node's onnx name(output provider), value is (provider_succ_name, opt_var_name) | |||
| self.outputs_collection = dict() | |||
| self.matched_inputs = list() # Matched inputs will can be directly used by code line generation | |||
| # key is ext. succ node onnx name, value is local opt_var | |||
| self.external_successor_local_returns_map = OrderedDict() | |||
| # key is node's onnx_name, value is (successor_name, opt_var_name) <- node's level | |||
| self.outputs_collection = dict() | |||
| # start initialization | |||
| if not self.initialized: | |||
| self._init_module(nd_struct_list) | |||
| else: | |||
| self._update_module(nd_struct_list) | |||
| # assign this module reference to node | |||
| for (_, nd_struct) in nd_struct_list: | |||
| nd_struct.parent_module_struct = self | |||
| def reset_as_parent(self): | |||
| """ | |||
| Reset all attributes and filled as a parent module of this module. | |||
| Note: | |||
| This function must be called only after a deepcopy of this instance! | |||
| """ | |||
| self.identifier.pop() | |||
| self.scope_depth = self.scope_depth - 1 | |||
| self._set_pattern_id() | |||
| self._find_parent_module() | |||
| self.module_name = Scope.scope_to_module_name(self.identifier) | |||
| self._node_structs = list() | |||
| self._module_structs = list() | |||
| self._fragment = None | |||
| self._args_translator = None | |||
| self.init_args_translator() | |||
| self._setting = None | |||
| self._parent_module_struct = None | |||
| self._nodes_structs_formal_args_list = list() | |||
| self._node_args_translation_list = list() | |||
| def _set_pattern_id(self): | |||
| """Set pattern id which matches the module fragment pattern.""" | |||
| if not self.initialized: | |||
| return | |||
| if self.scope_depth < 1: | |||
| self.pattern_id = -1 | |||
| self.pattern_uid = -1 | |||
| return | |||
| self.pattern_id = self.identifier[-1][0] | |||
| self.pattern_uid = self.identifier[-1][1] | |||
| def _init_module(self, nd_struct_list): | |||
| """Init this ModuleStruct by a list of Nodes.""" | |||
| (nd_topo_idx, nd_struct) = nd_struct_list[0] | |||
| self.identifier = nd_struct.scope.path | |||
| self.module_name = nd_struct.scope.to_str | |||
| self.scope_depth = nd_struct.scope.depth | |||
| self.head_nd_struct = nd_struct | |||
| self.head_nd_struct_index = nd_topo_idx | |||
| self.tail_nd_struct = nd_struct_list[-1][1] | |||
| self.tail_nd_struct_index = nd_struct_list[-1][0] | |||
| self._node_structs = nd_struct_list | |||
| self.initialized = True | |||
| self._set_pattern_id() | |||
| self._find_parent_module() | |||
| self.init_args_translator() | |||
| def _update_module(self, nd_struct_list): | |||
| """Update the ModuleStruct attributes from a list of Nodes.""" | |||
| (nd_topo_idx_head, nd_struct_head) = nd_struct_list[0] | |||
| (nd_topo_idx_tail, nd_struct_tail) = nd_struct_list[-1] | |||
| if self.identifier != nd_struct_head.scope.path: | |||
| raise ValueError("Unable to update this module struct {} due to different identifier {}".format( | |||
| self.identifier, nd_struct_head.scope.path)) | |||
| if nd_topo_idx_head < self.head_nd_struct_index: | |||
| self.head_nd_struct_index = nd_topo_idx_head | |||
| self.head_nd_struct = nd_struct_head | |||
| if nd_topo_idx_tail > self.tail_nd_struct_index: | |||
| self.tail_nd_struct_index = nd_topo_idx_tail | |||
| self.tail_nd_struct = nd_struct_tail | |||
| self._node_structs += nd_struct_list | |||
| def _find_parent_module(self): | |||
| """Set the parent's module pattern and uid.""" | |||
| if not self.initialized: | |||
| return | |||
| if self.scope_depth == 0: # is Main Module | |||
| pass | |||
| elif self.scope_depth == 1: # parent pattern is Main module | |||
| self.parent_id = -1 | |||
| self.parent_uid = -1 | |||
| else: # this is a submodule in a module | |||
| (self.parent_id, self.parent_uid) = Scope.get_parent_module_num_and_uid( | |||
| self.identifier) | |||
| def __repr__(self): | |||
| return str({ | |||
| "address": hex(id(self)), | |||
| "identifier": self.identifier, | |||
| "parent": (self.parent_id, self.parent_uid), | |||
| "name": self.module_name, | |||
| "pattern": self.pattern_id, | |||
| "scope_depth": self.scope_depth, | |||
| "nd_idx_range": "{} -> {}".format(self.head_nd_struct_index, self.tail_nd_struct_index), | |||
| "initialized": self.initialized | |||
| }) | |||
| def init_module_fragment(self): | |||
| """Init the module fragment.""" | |||
| if not self.initialized: | |||
| return | |||
| # check if fragment exists in global context | |||
| op = "Module{}".format(self.pattern_id) | |||
| if op == "Module-1": # reset as Main Model's op name | |||
| op = "Model" | |||
| frag = GlobalContext().get_module_fragment(op) | |||
| if frag is not None: # use exists fragment | |||
| self._fragment = frag | |||
| else: | |||
| frag = ModuleFragment(operation=op, | |||
| actual_args=None, | |||
| input_shape=None, | |||
| output_shape=None, | |||
| settings=None) | |||
| self._fragment = frag | |||
| # set fragment pattern | |||
| self._fragment.pattern = self._node_structs | |||
| GlobalContext().add_module_fragment(op, frag) | |||
| def init_args_translator(self): | |||
| """Initialize the Args Translator for the module.""" | |||
| var_name = "Module{}_{}".format(self.pattern_id, self.pattern_uid) | |||
| self._args_translator = ArgsTranslation(None, var_name, None) | |||
| def update_module_fragment(self): | |||
| """Update this module's fragment.""" | |||
| if self._fragment is None: | |||
| return | |||
| # update input output shape | |||
| self._fragment.input_shape = self.head_nd_struct.fragment.input_shape | |||
| self._fragment.output_shape = self.tail_nd_struct.fragment.output_shape | |||
| # update formal args | |||
| self._fragment.formal_args.update(self._args_translator.formal_args) | |||
| self._fragment.formal_args_value.update(self._args_translator.formal_args_values) | |||
| # update actual args | |||
| self._fragment.actual_args.update(self._args_translator.actual_args) | |||
| # update others.. | |||
| def add_submodule(self, md_structs): | |||
| """ | |||
| Add another module struct(s) to this ModuleStruct. | |||
| Args: | |||
| md_structs ([ModuleStruct, list]): a (list) ModuleStruct to be added in this ModuleStruct. | |||
| """ | |||
| tail_md = md_structs | |||
| if isinstance(md_structs, ModuleStruct): | |||
| md_structs.args_translator.take_formal_args_from_nodes_and_submodules(md_structs.get_all_sub_translators()) | |||
| self._module_structs.append(md_structs) | |||
| md_structs.parent_module_struct = self | |||
| elif isinstance(md_structs, list): | |||
| for md_s in md_structs: | |||
| md_s.args_translator.take_formal_args_from_nodes_and_submodules(md_s.get_all_sub_translators()) | |||
| md_s.parent_module_struct = self | |||
| self._module_structs += md_structs | |||
| tail_md = md_structs[-1] | |||
| else: | |||
| raise TypeError("ModuleStruct cannot add an unsupport Type {} to module_structs list.".format( | |||
| type(md_structs))) | |||
| # update tail node and index | |||
| if self.tail_nd_struct_index < tail_md.tail_nd_struct_index: | |||
| self.tail_nd_struct = tail_md.tail_nd_struct | |||
| self.tail_nd_struct_index = tail_md.tail_nd_struct_index | |||
| def _update_formal_args_for_all_nd_structs(self): | |||
| """ | |||
| Init nodes' args translator and find formal args. | |||
| And collect nodes' formal args. | |||
| """ | |||
| if len(self._node_args_translation_list) != len(self._node_structs): | |||
| raise ValueError( | |||
| "ModuleStruct cannot update nodes' formal args due to length inconsistent.") | |||
| for idx, (_, nd_struct) in enumerate(self._node_structs): | |||
| formal_arg_of_this_node = self._node_args_translation_list[idx] | |||
| # update var_name to ensure all node names' are unique in a module. | |||
| nd_struct.update_var_name(idx) | |||
| nd_struct.init_args_translator(formal_arg_of_this_node) | |||
| if nd_struct.args_translator is not None: | |||
| self._nodes_structs_formal_args_list.append( | |||
| nd_struct.args_translator.formal_args_values) | |||
| else: | |||
| self._nodes_structs_formal_args_list.append(None) | |||
| def update_args_translation_list(self, formal_args): | |||
| """ | |||
| Receive a list of args name to be changed to formal args, and change them. | |||
| Args: | |||
| formal_args (list[str]): a list of args name to be changed to formal args. | |||
| """ | |||
| self._node_args_translation_list = formal_args | |||
| self._update_formal_args_for_all_nd_structs() | |||
| def get_all_sub_translators(self): | |||
| """ | |||
| Return a list of args_translators of submodules / nodes affiliated to this module. | |||
| Note: | |||
| The order of returned list is followed by the actual topological order. | |||
| Returns: | |||
| list, a list of args_translators. | |||
| """ | |||
| ret = [] | |||
| for (_, struct) in self.get_generate_order(): | |||
| if struct.args_translator is not None: | |||
| ret.append(struct.args_translator) | |||
| return ret | |||
| def get_generate_order(self): | |||
| """ | |||
| Return the order of generated code by index. | |||
| Return: | |||
| list, a list of reference of node_struct or module_struct. | |||
| """ | |||
| ret = list() | |||
| if not self._module_structs: | |||
| return self._node_structs | |||
| # Generate a list of tuple (idx, module_structs) | |||
| for md_struct in self._module_structs: | |||
| ret.append((md_struct.head_nd_struct_index, md_struct)) | |||
| if self.node_structs: | |||
| ret += self.node_structs | |||
| ret.sort(key=lambda x: x[0]) | |||
| return ret | |||
| def code_line_in_init(self): | |||
| """ | |||
| Initialization line of code in module init block. | |||
| Args: | |||
| override_formal_val (dict): Indicate which args should be renamed for passing value from upper level. | |||
| """ | |||
| left = "self.{}".format(self.ms_var_name) | |||
| args_list = list() | |||
| # Load args in init statement. | |||
| if self._args_translator is not None: # from args_translator | |||
| if self._args_translator.actual_args: # load actual args | |||
| args_list += self._args_translator.actual_args_to_str_list | |||
| elif self._args_translator.actual_args_backup and self.parent_id == -1: | |||
| # For modules repeated in multiple levels, the module under main model should | |||
| # not use formal args as it is unnecessary -> load from actual args backup | |||
| args_list += self._args_translator.actual_args_backup_to_str_list | |||
| args_list += self._args_translator.formal_args_to_str_list # load from formal args | |||
| else: | |||
| args_list += self._fragment.actual_args | |||
| right = f"{self.class_name}({', '.join(args_list)})" | |||
| return (left, right) | |||
| def code_line_in_construct(self, inputs=None): | |||
| """Construct line of code in module construct block.""" | |||
| # check number of outputs this module has | |||
| opt_var_name_in_module = list(self.external_successor_local_returns_map.values()) | |||
| num_output = len(set(opt_var_name_in_module)) | |||
| if num_output == 1: # single output | |||
| left = f"{self.ms_opt_var_name}" | |||
| else: | |||
| left = [f"{self.ms_opt_var_name}_{num}" for num in range(num_output)] | |||
| if inputs is None and self.matched_inputs: | |||
| inputs = self.matched_inputs | |||
| if isinstance(inputs, str): | |||
| inputs = [inputs] | |||
| right = f"self.{self.ms_var_name}({', '.join(inputs)})" | |||
| return (left, right) | |||
| @property | |||
| def node_structs(self): | |||
| """Return all node structs in this module.""" | |||
| return self._node_structs | |||
| @property | |||
| def module_structs(self): | |||
| """Return all module structs in this module.""" | |||
| return self._module_structs | |||
| @property | |||
| def parent_module_struct(self): | |||
| """Return this module's parent module struct.""" | |||
| return self._parent_module_struct | |||
| @parent_module_struct.setter | |||
| def parent_module_struct(self, ref): | |||
| """Set this modu;e's parent module struct.""" | |||
| self._parent_module_struct = ref | |||
| @property | |||
| def args_translator(self): | |||
| """Return the args translator.""" | |||
| return self._args_translator | |||
| @property | |||
| def head_nd_struct_precursor_nodes_names(self) -> list: | |||
| """Return head node's precursor nodes names.""" | |||
| return self.head_nd_struct.precursor_nodes_names | |||
| @property | |||
| def head_nd_struct_precursor_nodes_structs(self) -> list: | |||
| """Return head node's precursor nodes structs.""" | |||
| return self.head_nd_struct.precursor_nodes_structs | |||
| @property | |||
| def tail_nd_struct_successor_nodes_names(self) -> list: | |||
| """Return tail node's successor nodes names.""" | |||
| return self.tail_nd_struct.successor_nodes_names | |||
| @property | |||
| def tail_nd_struct_successor_nodes_structs(self) -> list: | |||
| """Return tail node's successor nodes structs.""" | |||
| return self.tail_nd_struct.successor_nodes_structs | |||
| @property | |||
| def onnx_names_from_nodes(self) -> list: | |||
| """Return all nodes onnx names in this module.""" | |||
| ret = [] | |||
| for (_, node) in self.node_structs: | |||
| ret.append(node.onnx_name) | |||
| return ret | |||
| @property | |||
| def onnx_names_from_submodules(self) -> list: | |||
| """Return all nodes onnx names in submodules of this module.""" | |||
| ret = [] | |||
| for md_struct in self.module_structs: | |||
| ret += md_struct.onnx_names | |||
| return ret | |||
| @property | |||
| def onnx_names(self) -> list: | |||
| """Return all nodes' onnx names which contained by this module.""" | |||
| return self.onnx_names_from_nodes + self.onnx_names_from_submodules | |||
| @property | |||
| def external_precursor_nodes_names(self) -> list: | |||
| """Return all precursors nodes names not in this module.""" | |||
| ret = [] | |||
| for _, struct in self.get_generate_order(): | |||
| if isinstance(struct, NodeStruct): | |||
| precursor_nodes_names = struct.precursor_nodes_names | |||
| if isinstance(struct, ModuleStruct): | |||
| precursor_nodes_names = struct.external_precursor_nodes_names | |||
| for p_name in precursor_nodes_names: | |||
| if p_name in self.onnx_names: | |||
| continue | |||
| ret.append(p_name) | |||
| return ret | |||
| @property | |||
| def external_successor_nodes_names(self) -> list: | |||
| """Return all precursors nodes names not in this module.""" | |||
| ret = [] | |||
| for _, struct in self.get_generate_order(): | |||
| if isinstance(struct, NodeStruct): | |||
| successor_nodes_names = struct.successor_nodes_names | |||
| if isinstance(struct, ModuleStruct): | |||
| successor_nodes_names = struct.external_successor_nodes_names | |||
| for s_name in successor_nodes_names: | |||
| if s_name in self.onnx_names: | |||
| continue | |||
| ret.append(s_name) | |||
| return ret | |||
| @property | |||
| def class_name(self) -> str: | |||
| """Return the class name for generating code of this module.""" | |||
| if self.pattern_id == -1: | |||
| return "Model" | |||
| return "Module{}".format(self.pattern_id) | |||
| @property | |||
| def ms_var_name(self) -> str: | |||
| """Return the variable name for generated code statement of this module.""" | |||
| if self.pattern_id == -1: | |||
| return "Model" | |||
| return "Module{}_{}".format(self.pattern_id, self.pattern_uid).lower() | |||
| @property | |||
| def ms_opt_var_name(self) -> str: | |||
| """Return the variable name for generated code statement of the output of this module.""" | |||
| return "{}_opt".format(self.ms_var_name).lower() | |||
| # The following part will be resetting nodes' external inputs for supporting multi-in/out | |||
| # and should be called after generator.recursive_form_modules() | |||
| def set_inputs_in_construct_header(self, header_x, onnx_precursor_node_name): | |||
| """ | |||
| Mark the registered external inputs for code generation. | |||
| Note: | |||
| This function to be called by its parent (ModuleStruct). | |||
| Args: | |||
| header_x (str): The `x` in module construct header. | |||
| onnx_precursor_node_name (str): The original onnx node name. | |||
| """ | |||
| if self.inputs_in_construct_header.get(onnx_precursor_node_name) is not None: | |||
| raise ValueError("The input from {} has already registered. Check this Module \ | |||
| {} has duplicate inputs or not.".format(onnx_precursor_node_name, self.identifier)) | |||
| self.inputs_in_construct_header[onnx_precursor_node_name] = header_x | |||
| def allocate_construct_header_x(self, force_x=None): | |||
| """ | |||
| Allocate the x in construct header for each external input. | |||
| Args: | |||
| force_x (str): Force the arg name to customized. | |||
| """ | |||
| local_x_name = 'x' | |||
| if force_x: # name of x indicated by external | |||
| local_x_name = force_x | |||
| # set construct_header_x for current module | |||
| allocated = set() | |||
| for prec_name in self.external_precursor_nodes_names: | |||
| if prec_name in allocated: | |||
| continue | |||
| x_name_in_construct_header = self._var_name_mgr.get_name(local_x_name) | |||
| self.construct_header_x[x_name_in_construct_header] = prec_name | |||
| allocated.add(prec_name) | |||
| # Assign these inputs to nodes and submodules | |||
| for _, struct in self.get_generate_order(): | |||
| if isinstance(struct, NodeStruct): # register node's ext input | |||
| self.reset_node_external_input_to_local(struct) | |||
| self.register_node_output_to_module(struct) | |||
| if isinstance(struct, ModuleStruct): # reg module's ext input | |||
| if not struct.construct_header_x: | |||
| struct.allocate_construct_header_x() | |||
| self.reset_submodule_external_input_to_local(struct) | |||
| self.register_submodule_output_to_module(struct) | |||
| # remove parent module's ext. map if ext nodes in this module (no need return) | |||
| for user_name in self.external_successor_local_returns_map.copy().keys(): | |||
| if user_name in self.onnx_names: | |||
| self.external_successor_local_returns_map.pop(user_name) | |||
| def _match_node_inputs(self, struct): | |||
| """Match node's inputs with its precursor nodes.""" | |||
| for output_provider in struct.precursor_nodes_names: | |||
| output_list = self.outputs_collection.get(output_provider) | |||
| if output_list is None: | |||
| # not in this module, check construct header | |||
| for (self_x_name, self_output_provider) in self.construct_header_x.items(): | |||
| if self_output_provider == output_provider: | |||
| struct.matched_inputs.append(self_x_name) | |||
| continue | |||
| for output in output_list: | |||
| (provider_succ, provider_closet_opt_var) = output | |||
| if provider_closet_opt_var in struct.matched_inputs: | |||
| continue # skip repeat | |||
| if provider_succ == struct.onnx_name: | |||
| struct.matched_inputs.append(provider_closet_opt_var) | |||
| def _match_sub_modules_inputs(self): | |||
| """ | |||
| Match current module's submodules' inputs with corresponding outputs registered in current module. | |||
| Description: | |||
| The function matches these inputs by the following steps: | |||
| 1. For each submodule in the current module, take submodule's construct header | |||
| 2. Check submodule's construct header element requires an input from current module's | |||
| construct header or outputs from other submodules. | |||
| 3. If from current module's construct header, assign corresponding x to the submodule. | |||
| If from other submodules, assign required submodule output name to the submodule. | |||
| """ | |||
| if not self.outputs_collection: | |||
| return # skip first node | |||
| for (_, struct) in self.get_generate_order(): | |||
| if isinstance(struct, NodeStruct): | |||
| self._match_node_inputs(struct) | |||
| continue # skip node | |||
| sub_construct_header = struct.construct_header_x | |||
| for (_, output_provider) in sub_construct_header.items(): | |||
| # check from outputs collection | |||
| output_list = self.outputs_collection.get(output_provider) | |||
| if output_list is None: | |||
| # not in this module, need from current module construct header | |||
| for (self_x_name, self_output_provider) in self.construct_header_x.items(): | |||
| if self_output_provider == output_provider: | |||
| struct.matched_inputs.append(self_x_name) | |||
| continue | |||
| for output in output_list: | |||
| (provider_succ, provider_closet_opt_var) = output | |||
| if provider_closet_opt_var in struct.matched_inputs: | |||
| continue # skip repeat | |||
| if provider_succ in struct.onnx_names: | |||
| struct.matched_inputs.append(provider_closet_opt_var) | |||
| def _append_to_outputs_collection(self, provider_name, val): | |||
| """ | |||
| Helper function to add a nodes or submodules outputs to current module return statement. | |||
| Args: | |||
| provider_name (str): The onnx name of the output provider. | |||
| val (list[tuple]): A list of tuple which contains | |||
| the output provider's successor name and its opt_var_name. | |||
| """ | |||
| exist_output = self.outputs_collection.get(provider_name) | |||
| if isinstance(val, tuple): | |||
| val = [val] | |||
| if exist_output is None: # add new entry | |||
| exist_output = list() | |||
| exist_output += (val) | |||
| self.outputs_collection[provider_name] = exist_output | |||
| def collect_returns(self): | |||
| """ | |||
| Collect all nodes and submodules' returns in the module. | |||
| Note: | |||
| The logic is to collect the return from nodes and submodules by the order | |||
| of topological index. | |||
| For returns from a node, it will check if the return will be used externally. | |||
| If external (external means the successor a.k.a the return user has different scope with the node), | |||
| add this return to current module's outputs_collection, where | |||
| key is this node's original onnx_name, and value is a list of | |||
| tuple(successor_name, this node's opt_var_name) | |||
| For returns from a submodule, it will check if the submodule has already collected returns, | |||
| If not, do it and then continue the following procedures. | |||
| Now we will check each element in submodule's outputs_collection. Note that we DO NOT check submodule's | |||
| returns should be continued returning, but just return them. | |||
| All these returns from submodules will be changes their original nodes output (a.k.a outputs provider) | |||
| `opt_var_name` to submodules' `opt_var_name`. | |||
| Finally, we match the outputs and inputs in the current module level. | |||
| """ | |||
| for (_, struct) in self.get_generate_order(): | |||
| if isinstance(struct, NodeStruct): | |||
| outputs_list = [] | |||
| # add these successor nodes name to collection for future use | |||
| for succ in struct.successor_nodes_names: | |||
| outputs_list.append((succ, struct.ms_opt_var_name)) | |||
| if outputs_list: | |||
| self._append_to_outputs_collection(struct.onnx_name, outputs_list) | |||
| if isinstance(struct, ModuleStruct): | |||
| # Remove unnecessary returns, succ are all inside current | |||
| if not struct.outputs_collection: | |||
| struct.collect_returns() | |||
| sub_outputs_collection = struct.outputs_collection | |||
| # check each returns in sub | |||
| for provider_name, outputs_list in sub_outputs_collection.items(): | |||
| for output in outputs_list: | |||
| (succ, _) = output # (succ, provider_opt_var_name) in output | |||
| new_output = (succ, struct.ms_opt_var_name) | |||
| self._append_to_outputs_collection(provider_name, new_output) | |||
| self._match_sub_modules_inputs() | |||
| def get_returned_opt_var_name(self) -> list: | |||
| """Return a list of returned output var of this module.""" | |||
| idx = 0 | |||
| added_to_return = set() | |||
| ret = [] | |||
| for ext_successor_requested, opt_var_name_in_this_module in self.external_successor_local_returns_map.items(): | |||
| if ext_successor_requested in added_to_return: | |||
| continue | |||
| ret.append((ext_successor_requested, opt_var_name_in_this_module, idx)) | |||
| added_to_return.add(ext_successor_requested) | |||
| return ret | |||
| def reset_node_external_input_to_local(self, nd_struct): | |||
| """ | |||
| Reset node's input to module's construct args | |||
| """ | |||
| for prec_node_name in nd_struct.precursor_nodes_names_external: | |||
| if prec_node_name in self.onnx_names: # prec node in current module's. | |||
| continue | |||
| if prec_node_name in self.construct_header_x.values(): | |||
| # prec node assigned to construct header to passed in. | |||
| local_x = get_dict_key_by_value(prec_node_name, self.construct_header_x) | |||
| nd_struct.set_inputs_in_construct_header(local_x, prec_node_name) | |||
| else: # Extra precursor nodes, raise error | |||
| raise ValueError("Found external inputs of the Node but the module does not have it.") | |||
| def reset_submodule_external_input_to_local(self, md_struct): | |||
| """ | |||
| Reset submodule's external input to current module. | |||
| Args: | |||
| md_struct (ModuleStruct): The submodule in the current module. | |||
| """ | |||
| # check submodule's input | |||
| for _, submodule_precursor in md_struct.construct_header_x.items(): | |||
| if submodule_precursor in self.onnx_names: # if internal, match with local nodes/submodules return | |||
| # but do nothing here | |||
| continue | |||
| else: # if external, match with current module construct header x | |||
| if submodule_precursor in self.construct_header_x.values(): | |||
| local_x = get_dict_key_by_value(submodule_precursor, self.construct_header_x) | |||
| md_struct.set_inputs_in_construct_header(local_x, submodule_precursor) | |||
| else: # Extra precursor nodes, raise error | |||
| raise ValueError("Found external inputs of the submodule but the module does not have it.") | |||
| def register_node_output_to_module(self, nd_struct): | |||
| """Register nodes outputs to this module's return.""" | |||
| for succ_node_name in nd_struct.successor_nodes_names_external: | |||
| self.external_successor_local_returns_map[succ_node_name] = nd_struct.ms_opt_var_name | |||
| def register_submodule_output_to_module(self, md_struct): | |||
| """Register submodule outputs to this module's return.""" | |||
| submodule_returns = md_struct.get_returned_opt_var_name() | |||
| submodule_opt_var_name = md_struct.ms_opt_var_name | |||
| for (submodule_ext_succ, opt_var_name_in_this_module, ith_output) in submodule_returns: | |||
| self.external_successor_local_returns_map[submodule_ext_succ] = (submodule_opt_var_name, ith_output) | |||
| # edit external succ 's inputs in parent module | |||
| ext_node = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(submodule_ext_succ) | |||
| ext_node_parent = ext_node.parent_module_struct | |||
| while ext_node_parent != self.parent_module_struct: | |||
| ext_node_parent.inputs_in_parent_module[ext_node.onnx_name] = md_struct.ms_opt_var_name | |||
| ext_node_parent = ext_node_parent.parent_module_struct | |||
| # need find the prec_name? | |||
| for ext_node_prec, opt_var_name in ext_node.inputs_in_parent_module.copy().items(): | |||
| if isinstance(opt_var_name, str): | |||
| if opt_var_name == opt_var_name_in_this_module: | |||
| ext_node.inputs_in_parent_module[ext_node_prec] = (self.ms_opt_var_name, ith_output) | |||
| if isinstance(opt_var_name, tuple): | |||
| if opt_var_name[0] == opt_var_name_in_this_module: | |||
| ext_node.inputs_in_parent_module[ext_node_prec] = (self.ms_opt_var_name, ith_output) | |||
| @@ -0,0 +1,423 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Define the NodeStruct which stores all info. of a node.""" | |||
| from collections import OrderedDict | |||
| from .scope_utils import Scope | |||
| from .args_translator import ArgsTranslation | |||
| from ..common.code_fragment import CodeFragment | |||
| from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode | |||
| from ..third_party_graph.onnx_graph_node import OnnxGraphNode | |||
| from ..common.global_context import GlobalContext | |||
| from ..constant import InputType | |||
| class NodeStruct: | |||
| """ | |||
| Define a node struct which stores all info. to generate statement. | |||
| Args: | |||
| args (Union[PyTorchGraphNode, OnnxGraphNode, dict]): Node related obj. | |||
| Note: | |||
| You can pass as many args as possible and the Node Struct will update | |||
| by arguments order. | |||
| """ | |||
| GLOBAL_CONTEXT_MGR = GlobalContext() | |||
| def __init__(self, args): | |||
| # define attributes here | |||
| self._identifier = None | |||
| self._fragment = None | |||
| self._args_translator = None | |||
| self._parent_module_struct = None | |||
| self.topo_idx = None | |||
| self.node_type = None | |||
| self.onnx_name = None | |||
| self.onnx_op = None | |||
| self.graph_node_ref = None # Our defined GraphNode | |||
| self.scope_name = None | |||
| self.ms_var_name = None | |||
| self.ms_opt_var_name = None # ms_opt_var_name = self.ms_var_name(...) | |||
| self.ms_op = None | |||
| self.ready_to_generate = False | |||
| self.ms_params = dict() # converted params from mapper | |||
| self.ms_settings = dict() | |||
| self.ms_weights = dict() | |||
| self.ms_inputs = OrderedDict() | |||
| self.scope = None # Defined Scope class | |||
| self.inputs_in_construct_header = OrderedDict() # key is prec_node_name, value is x; For code line use | |||
| self.inputs_in_parent_module = OrderedDict() # key is prec_node_name, value is its closet opt_var_name | |||
| self.matched_inputs = list() # Matched inputs will can be directly used by code line generation | |||
| # initialize funcs. | |||
| for arg in args: | |||
| self.update(arg) | |||
| def __repr__(self): | |||
| return str({ | |||
| "address": hex(id(self)), | |||
| "idx": self.topo_idx, | |||
| "identifier": self.identifier | |||
| }) | |||
| def ori_topo_idx(self): | |||
| """Get the original topological index in the onnx graph.""" | |||
| ori_name = self.identifier.replace('$', '').split('/')[-1].replace("::", '/') | |||
| self.onnx_name = ori_name | |||
| return self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_topo_idx.get(ori_name) | |||
| def update_var_name(self, idx=None): | |||
| """ | |||
| Update the var_name of each node. | |||
| Args: | |||
| idx (int): The index of the node in this module. | |||
| """ | |||
| if idx is not None: | |||
| self.ms_var_name = self.ms_op.replace('nn.', '').replace('P.', '').lower() + '_' + str(idx) | |||
| elif self.topo_idx is not None: | |||
| self.ms_var_name = self.ms_op.replace('nn.', '').replace('P.', '').lower() + '_' + str(self.topo_idx) | |||
| else: | |||
| raise ValueError("Unable to update var name when topo_idx is None.") | |||
| self.ms_opt_var_name = self.ms_var_name + '_opt' | |||
| def _update_basics_from_gn(self, gn): | |||
| """Update basic info from GraphNode.""" | |||
| self.graph_node_ref = gn | |||
| self.scope_name = gn.scope_name | |||
| def _update_from_pytorch_gn(self, gn: PyTorchGraphNode): | |||
| """Update basic info from PyTorchGraphNode.""" | |||
| self.node_type = "PyTorchGraphNode" | |||
| self._update_basics_from_gn(gn) | |||
| def _update_from_onnx_gn(self, gn: OnnxGraphNode): | |||
| """Update basic info from OnnxGraphNode.""" | |||
| self.node_type = "OnnxGraphNode" | |||
| self._update_basics_from_gn(gn) | |||
| def _update_from_mapper(self, d): | |||
| """Update info from mapper.""" | |||
| if d.get('op_name'): | |||
| self.ms_op = d.get('op_name') | |||
| if d.get('params'): | |||
| self.ms_params = d.get('params') | |||
| if d.get('settings'): | |||
| self.ms_settings = d.get('settings') | |||
| if d.get('weights'): | |||
| self.ms_weights = d.get('weights') | |||
| def _update_from_fragment(self, frag: CodeFragment): | |||
| """Update info from CodeFragment.""" | |||
| self._fragment = frag | |||
| if frag.operation: | |||
| self.ms_op = frag.operation | |||
| idx = self.GLOBAL_CONTEXT_MGR.latest_node_struct_count | |||
| self.update_var_name(idx=idx) | |||
| def _set_scope_from_identifier(self): | |||
| """Set the Node scope from identifier.""" | |||
| parsed_scope = Scope.parse_scope_from_node_identifier(self.identifier) | |||
| self.scope = Scope(parsed_scope) | |||
| def init_args_translator(self, translated_args: list): | |||
| """ | |||
| Initialize the ArgsTranslator for each Node. | |||
| Args: | |||
| translated_args (list): The list of args should be translated to formal args. | |||
| """ | |||
| if not self._fragment: | |||
| raise ValueError("Initialize argument translator failed.") | |||
| if self._fragment.actual_args and translated_args: | |||
| self._args_translator = ArgsTranslation(self._fragment.actual_args, self.ms_var_name, translated_args) | |||
| def check_if_generate_ready(self): | |||
| """Check if the NodeStruct is able to generate code.""" | |||
| # check essential params exists | |||
| if all([self.identifier, | |||
| self.node_type, | |||
| self.scope_name, | |||
| self.ms_var_name, | |||
| self.ms_opt_var_name, | |||
| self.ms_op]): | |||
| self.ready_to_generate = True | |||
| def update(self, arg, force_ready=False): | |||
| """ | |||
| Pass Node info. to generator NodeStruct. | |||
| Args: | |||
| arg (Union[PyTorchGraphNode, OnnxGraphNode, dict]): Node related obj. | |||
| force_ready (bool): Force this NodeStruct is ready to generate. | |||
| """ | |||
| if isinstance(arg, PyTorchGraphNode): | |||
| self._update_from_pytorch_gn(arg) | |||
| elif isinstance(arg, OnnxGraphNode): | |||
| self._update_from_onnx_gn(arg) | |||
| elif isinstance(arg, (dict, OrderedDict)): | |||
| self._update_from_mapper(arg) | |||
| elif isinstance(arg, CodeFragment): | |||
| self._update_from_fragment(arg) | |||
| else: | |||
| raise TypeError("NodeStruct received an unsupported initializing argument.") | |||
| if force_ready: | |||
| self.ready_to_generate = True | |||
| else: | |||
| self.check_if_generate_ready() | |||
| @property | |||
| def identifier(self): | |||
| """Return the identifier of the node.""" | |||
| return self._identifier | |||
| @identifier.setter | |||
| def identifier(self, s): | |||
| """ | |||
| Set the Node identifier, and update the scope. | |||
| Args: | |||
| s (str): The node identifier string. | |||
| """ | |||
| self._identifier = s | |||
| self._set_scope_from_identifier() | |||
| self.topo_idx = self.ori_topo_idx() | |||
| self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map[self.onnx_name] = self | |||
| @property | |||
| def fragment(self): | |||
| """Return the fragment of the node.""" | |||
| return self._fragment | |||
| @fragment.setter | |||
| def fragment(self, frag): | |||
| """ | |||
| Set the Node fragment. | |||
| Args: | |||
| s (NodeFragment): The node identifier string. | |||
| """ | |||
| self._fragment = frag | |||
| @property | |||
| def graph_node(self): | |||
| """Return the GraphNode reference.""" | |||
| return self.graph_node_ref | |||
| @graph_node.setter | |||
| def graph_node(self, graphnode): | |||
| """Set the GraphNode reference.""" | |||
| self.graph_node_ref = graphnode | |||
| @property | |||
| def onnx_node(self): | |||
| """Return the original onnx node reference.""" | |||
| return self.GLOBAL_CONTEXT_MGR.onnx_nodes_collection.get(self.onnx_name) | |||
| @property | |||
| def args_translator(self): | |||
| """Return the args translator of this Node.""" | |||
| return self._args_translator | |||
| @property | |||
| def precursor_nodes_names(self) -> list: | |||
| """Return the names of precursor nodes.""" | |||
| return self.graph_node_ref.precursor_nodes | |||
| @property | |||
| def precursor_nodes_structs(self) -> list: | |||
| """Return the node struct instances of precursor nodes.""" | |||
| ret = [] | |||
| precursor_nodes_names = self.precursor_nodes_names | |||
| for pre_node_name in precursor_nodes_names: | |||
| nd_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(pre_node_name) | |||
| ret.append(nd_struct) | |||
| return ret | |||
| @property | |||
| def successor_nodes_names(self) -> list: | |||
| """Return the names of successor nodes.""" | |||
| return self.graph_node_ref.successor_nodes | |||
| @property | |||
| def successor_nodes_structs(self) -> list: | |||
| """Return the node struct instances of successor nodes.""" | |||
| ret = [] | |||
| for pre_node_name in self.successor_nodes_names: | |||
| nd_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(pre_node_name) | |||
| ret.append(nd_struct) | |||
| return ret | |||
| @property | |||
| def parent_module_struct(self): | |||
| """Return the parent struct of this node.""" | |||
| return self._parent_module_struct | |||
| @parent_module_struct.setter | |||
| def parent_module_struct(self, ref): | |||
| self._parent_module_struct = ref | |||
| # Code Generation funcs below | |||
| def code_line_in_init(self): | |||
| """Initialization line of code in module init block.""" | |||
| left = "self.{}".format(self.ms_var_name) | |||
| args_list = list() | |||
| if self._args_translator is not None: | |||
| args_list += self._args_translator.actual_args_to_str_list | |||
| args_list += self._args_translator.formal_args_to_str_list | |||
| else: | |||
| actual_args_str = ArgsTranslation.dict_data_to_args_str_list(self._fragment.actual_args) | |||
| args_list += actual_args_str | |||
| right = f"{self.ms_op}({', '.join(args_list)})" | |||
| return left, right | |||
| def _get_correct_in_module_returns(self, prec_node, in_module_return): | |||
| """ | |||
| Find the correct precursor node name in return statement of its parent module. | |||
| Args: | |||
| prec_node (str): The onnx name of the precursor node given. | |||
| in_module_return (list[tuple]): The list of outputs which contains parent module identifier | |||
| and module opt_var_name. | |||
| Return: | |||
| str, correct opt_var_name to be passed in current node. | |||
| """ | |||
| found_return = False | |||
| for ret in in_module_return: | |||
| (md_identifier, input_name_to_use) = ret | |||
| p_node_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(prec_node) | |||
| # recursive check the p node parent | |||
| parent = p_node_struct | |||
| while not found_return: | |||
| parent = parent.parent_module_struct | |||
| if parent is None: | |||
| break | |||
| if parent.identifier == md_identifier: | |||
| return input_name_to_use | |||
| return None | |||
| def code_line_in_construct(self, inputs=None, in_module_returns=None): | |||
| """Construct line of code in module construct block. """ | |||
| left = self.ms_opt_var_name | |||
| if inputs is None: | |||
| inputs = [] | |||
| for idx, prec_node in enumerate(self.precursor_nodes_names): | |||
| if self.inputs_in_construct_header.get(prec_node): | |||
| inputs.append(self.inputs_in_construct_header.get(prec_node)) | |||
| elif self._check_target_node_internal(prec_node): | |||
| inputs.append(self.precursor_nodes_structs[idx].ms_opt_var_name) | |||
| elif self.inputs_in_parent_module.get(prec_node): | |||
| inputs.append(self.inputs_in_parent_module.get(prec_node)) | |||
| elif in_module_returns and in_module_returns.get(self.onnx_name) \ | |||
| and (not self._check_target_node_internal(prec_node)): | |||
| inputs.append(self._get_correct_in_module_returns(prec_node, in_module_returns.get(self.onnx_name))) | |||
| else: | |||
| inputs.append("unk_{}_{}".format(idx, prec_node)) | |||
| if self.matched_inputs: | |||
| inputs = self.matched_inputs | |||
| # Check original onnx node's input to ensure double inputs are not ignored | |||
| original_inputs = self.GLOBAL_CONTEXT_MGR.onnx_node_inputs.get(self.onnx_name) | |||
| new_inputs = [] | |||
| for idx, prec_node in enumerate(self.precursor_nodes_names): | |||
| occurence = original_inputs.count(prec_node) | |||
| for _ in range(occurence): | |||
| new_inputs.append(inputs[idx]) | |||
| inputs = new_inputs | |||
| if isinstance(inputs, str): | |||
| inputs = [inputs] | |||
| if self._fragment.code_setting and self._fragment.code_setting.op_ipt_type == InputType.LIST.value: | |||
| inputs = [str(tuple(inputs)).replace("\'", "")] | |||
| if self._fragment.code_setting and self._fragment.code_setting.op_extra_input: | |||
| for _, val in self._fragment.code_setting.op_extra_input.items(): | |||
| inputs.append(str(val)) | |||
| if self._fragment.code_setting and self._fragment.code_setting.op_extra_tensor: | |||
| inputs.append(f"self.{self.ms_var_name}_w") | |||
| right = f"self.{self.ms_var_name}({', '.join(inputs)})" | |||
| return left, right | |||
| def add_extra_tensor(self): | |||
| """ Add extra tensor.""" | |||
| left = "self.{}_w".format(self.ms_var_name) | |||
| shape = self._fragment.code_setting.op_extra_tensor.shape | |||
| right = f"Tensor(np.random.uniform(0, 1, {shape}), mindspore.float32)" | |||
| return left, right | |||
| # The following functions are specified for multiple in/out support. | |||
| # and should be called only after generator._recursive_form_modules() | |||
| def set_inputs_in_construct_header(self, header_x, onnx_precursor_node_name): | |||
| """ | |||
| Mark the registered external inputs for code generation. | |||
| Note: | |||
| This function to be called by its parent (ModuleStruct). | |||
| Args: | |||
| header_x (str): The `x` in module construct header. | |||
| onnx_precursor_node_name (str): The original onnx node name. | |||
| """ | |||
| if self.inputs_in_construct_header.get(onnx_precursor_node_name) is not None: | |||
| raise ValueError("The input from {} has already registered. Check this node \ | |||
| {} has duplicate inputs or not.".format(onnx_precursor_node_name, self.identifier)) | |||
| self.inputs_in_construct_header[onnx_precursor_node_name] = header_x | |||
| def _check_target_node_internal(self, name: str) -> bool: | |||
| """ | |||
| Check given node under the same scope. | |||
| Args: | |||
| name (str): Can accept both node identifier or original onnx node name. | |||
| """ | |||
| target_nd_struct = self.GLOBAL_CONTEXT_MGR.node_struct_collections.get(name) \ | |||
| or self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(name) | |||
| if target_nd_struct is None and self.topo_idx == 0: # First node always has external input | |||
| return False | |||
| if target_nd_struct is None: | |||
| raise ValueError("Unable to find the NodeStruct of given target node {}.".format(name)) | |||
| return target_nd_struct.scope.path == self.scope.path | |||
| @property | |||
| def has_successor_node_external(self) -> bool: | |||
| """Check if any successor_node is in external module.""" | |||
| for name in self.successor_nodes_names: | |||
| if not self._check_target_node_internal(name): | |||
| return False | |||
| return True | |||
| @property | |||
| def precursor_nodes_names_external(self) -> list: | |||
| """Return a list of external precursor nodes names.""" | |||
| return [name for name in self.precursor_nodes_names \ | |||
| if not self._check_target_node_internal(name)] | |||
| @property | |||
| def successor_nodes_names_external(self) -> list: | |||
| """Return a list of external successor nodes names.""" | |||
| return [name for name in self.successor_nodes_names \ | |||
| if not self._check_target_node_internal(name)] | |||
| @@ -0,0 +1,157 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Define a scope class processing all operations related to scope and scope name.""" | |||
| import re | |||
| class Scope(): | |||
| """Define scope related operations.""" | |||
| def __init__(self, scope_str): | |||
| scopes = scope_str.split('/') | |||
| self.module_path = list() | |||
| self.scope_list = scopes[:-1] | |||
| self.head = self.scope_list[0] | |||
| self.tail = self.scope_list[-1] | |||
| self.initialization() | |||
| def initialization(self): | |||
| """Init scope class.""" | |||
| self._update_module_path_from_scope_list() | |||
| def _update_module_path_from_scope_list(self): | |||
| """Update the module scope path from a list of scope.""" | |||
| self.module_path = list() | |||
| for scope in self.scope_list: | |||
| if scope == 'Model': | |||
| continue | |||
| if 'Module' in scope: | |||
| regex = r"Module(?P<num>\d+)_(?P<curr_level_unique_id>\d+)" | |||
| match = re.match(regex, scope) | |||
| if match: | |||
| module_num = match.group('num') | |||
| uid = match.group('curr_level_unique_id') | |||
| self.module_path.append((int(module_num), int(uid))) | |||
| @property | |||
| def path(self): | |||
| """Return module scope path.""" | |||
| return self.module_path | |||
| def set_path(self, ind, path_tuple: tuple): | |||
| """ | |||
| Set the module scope path. | |||
| Args: | |||
| ind (int): The index of the scope path to be set. | |||
| path_tuple ((int, int)): The tuple of the scope path. | |||
| """ | |||
| self.module_path[ind] = path_tuple | |||
| @property | |||
| def to_str(self): | |||
| """Return the full module scope as the string format.""" | |||
| full_str_list = ["Model"] | |||
| for (num, uid) in self.module_path: | |||
| local = "Module{}_{}".format(num, uid) | |||
| full_str_list.append(local) | |||
| return "/".join(full_str_list) | |||
| @property | |||
| def depth(self): | |||
| """Return the depth of the scope path.""" | |||
| return len(self.path) | |||
| @staticmethod | |||
| def scope_to_module_name(path): | |||
| """ | |||
| Helper function to convert any scope path string to the full module scope. | |||
| Args: | |||
| path (str): path string like "[(5, 0), (3, 0)]" | |||
| Returns: | |||
| str, the full module scope with format like "Model/Module5_0/Module3_0/" | |||
| """ | |||
| scope_str_list = ["Model"] | |||
| if isinstance(path, str): | |||
| path = Scope.path_str_to_list(path) | |||
| if isinstance(path, list): | |||
| for (num, uid) in path: | |||
| local = "Module{}_{}".format(num, uid) | |||
| scope_str_list.append(local) | |||
| return "/".join(scope_str_list) | |||
| @staticmethod | |||
| def parse_scope_from_node_identifier(node_identifier: str): | |||
| """ | |||
| Helper function to parse the scope string from node identifier. | |||
| Args: | |||
| node_identifier (str): The string of the node identifier. | |||
| Returns: | |||
| str, parsed scope string from node identifier. | |||
| """ | |||
| regex = r"(?P<scope>Model/.*)\$\S+\$" | |||
| match = re.match(regex, node_identifier) | |||
| if not match: | |||
| return None | |||
| return match.group('scope') | |||
| @staticmethod | |||
| def path_str_to_list(scope_path_str: str): | |||
| """ | |||
| Helper function to convert the scope path string back to list. | |||
| Args: | |||
| scope_path_str (str): The scope path string like "[(5, 0), (3, 0)]". | |||
| Returns: | |||
| list, a list of the scope path like [(5, 0), (3, 0)]. | |||
| """ | |||
| ret = [] | |||
| tmp = scope_path_str.strip('[').strip(']') | |||
| regex = r"\((?P<num>\d+), (?P<uid>\d+)\)" | |||
| s_all = re.findall(regex, tmp) | |||
| for (num, uid) in s_all: | |||
| ret.append((int(num), int(uid))) | |||
| return ret | |||
| @staticmethod | |||
| def get_parent_module_num_and_uid(path): | |||
| """ | |||
| Helper function to return its parent's scope tuple. | |||
| Args: | |||
| path (Union[str, list]): Module scope path string. e.g. "[(5, 0), (3, 0)]" | |||
| Returns: | |||
| tuple, parent's scope level. e.g. [(5, 0)] | |||
| """ | |||
| if isinstance(path, str): | |||
| path = Scope.path_str_to_list(path) | |||
| if isinstance(path, list): | |||
| if len(path) == 1: # modules under the main module, (-1, -1) means main module. | |||
| return (-1, -1) | |||
| if len(path) > 1: # modules under another non-main module. Return parent's scope. | |||
| parent = path[-2] | |||
| return parent | |||
| return None | |||
| @@ -106,3 +106,51 @@ class GlobalVarNameMgr: | |||
| global_var_namespace.add(new_name) | |||
| return new_name | |||
| class LocalVarNameMgr: | |||
| """Local variable name mgr.""" | |||
| def __init__(self): | |||
| self.local_op_namespace = dict() | |||
| self.local_var_namespace = set() | |||
| @staticmethod | |||
| def _get_name(name): | |||
| """Deal with op name.""" | |||
| if "::" in name: | |||
| return name.split("::")[1] | |||
| return name | |||
| def get_name(self, op_type): | |||
| """ | |||
| Get module/variable name. | |||
| If the module already existed, then add a suffix to it. | |||
| conv1 onnx::conv | |||
| Args: | |||
| op_type (str): Operator type in onnx. | |||
| Returns: | |||
| str, module name. | |||
| """ | |||
| def _gen(t): | |||
| t = t.lower() | |||
| if t not in self.local_op_namespace: | |||
| self.local_op_namespace[t] = START_IDX | |||
| suffix = "" | |||
| else: | |||
| self.local_op_namespace[t] += 1 | |||
| suffix = f"{self.local_op_namespace[t] - 1}" | |||
| return f"{self._get_name(t)}{suffix}" | |||
| new_name = _gen(op_type) | |||
| while new_name in self.local_var_namespace: | |||
| new_name = _gen(op_type) | |||
| self.local_var_namespace.add(new_name) | |||
| return new_name | |||
| @@ -151,7 +151,7 @@ class OnnxGraph(Graph): | |||
| input_shape (tuple): Input shape. | |||
| """ | |||
| input_node = InputNode(input_shape) | |||
| input_node_name = "{}InputNode" | |||
| input_node_name = self._raw_input_nodes.replace(":0", "") | |||
| for node_name, node in self._nodes_collection.items(): | |||
| if node_name in self._input_nodes: | |||
| ipt_nd_name = input_node_name.format(input_node.scope_name) | |||
| @@ -23,6 +23,7 @@ import numpy as np | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| from ..common.utils import fetch_output_from_onnx_model | |||
| from ..common.global_context import GlobalContext | |||
| from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ | |||
| ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL | |||
| @@ -110,6 +111,7 @@ class OnnxTensor: | |||
| self.to_nodes = [] | |||
| def to_array(self): | |||
| """Convert the tensor value from binary to np array.""" | |||
| onnx = import_module("onnx") | |||
| # Convert binary data to np.array | |||
| if not isinstance(self.raw_tensor, (np.ndarray, list, tuple, int, float)): | |||
| @@ -264,7 +266,7 @@ class OnnxDataLoader: | |||
| self.output_nodes = output_nodes if isinstance(output_nodes, list) else [output_nodes] | |||
| # args for init | |||
| self._is_infer_shape = infer_shape | |||
| self._global_context = GlobalContext() | |||
| # params parsed in init | |||
| self.inferred_model = None | |||
| @@ -375,12 +377,19 @@ class OnnxDataLoader: | |||
| def _parse_nodes(self): | |||
| """Parse each onnx nodes in the model.""" | |||
| for node in self.nodes: | |||
| nodes_topo_idx = [] | |||
| for idx, node in enumerate(self.nodes): | |||
| n = OnnxNode(node) | |||
| self._nodes_dict[n.name] = n | |||
| nodes_topo_idx.append((idx, n.name)) | |||
| if len(node.output) > 1: | |||
| raise ModelNotSupport(msg=f"{node.name} has multi-outputs which is not supported now.") | |||
| self.output_name_to_node_name[node.output[0]] = node.name | |||
| self._global_context.onnx_node_name_to_topo_idx[n.name] = idx | |||
| node_inputs = [i.replace(":0", "") for i in node.input] | |||
| self._global_context.onnx_node_inputs[n.name] = node_inputs | |||
| self._global_context.onnx_nodes_collection = self._nodes_dict | |||
| self._global_context.onnx_nodes_topo_index = nodes_topo_idx | |||
| def _parse_tensors(self): | |||
| """Parse each onnx tensors in the model.""" | |||
| @@ -388,6 +397,7 @@ class OnnxDataLoader: | |||
| for tensor in tensors: | |||
| t = OnnxTensor(tensor) | |||
| self.tensors_dict[t.name] = t | |||
| self._global_context.onnx_tensors_collection = self.tensors_dict | |||
| def _parse_node_output_shape(self): | |||
| """ | |||