diff --git a/mindinsight/mindconverter/graph_based_converter/constant.py b/mindinsight/mindconverter/graph_based_converter/constant.py index cc7ee4bc..0f6dfe85 100644 --- a/mindinsight/mindconverter/graph_based_converter/constant.py +++ b/mindinsight/mindconverter/graph_based_converter/constant.py @@ -27,7 +27,7 @@ FIRST_LEVEL_INDENT = BLANK_SYM * 4 SECOND_LEVEL_INDENT = BLANK_SYM * 8 NEW_LINE = "\n" -MINI_FREQUENCEY = 4 +MINI_FREQUENCY = 4 ONNX_TYPE_INT = 2 ONNX_TYPE_INTS = 7 diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py index 8eb6e256..1e935e7f 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py @@ -17,13 +17,10 @@ from queue import PriorityQueue from typing import Dict, List from .common import context, DagGraph, gen_hash_key -from ..constant import MINI_FREQUENCEY +from ..constant import MINI_FREQUENCY from ..third_party_graph.onnx_utils import BaseNode from .search_path import SearchPath, Pattern, generate_pattern -# Hold module name of current graph. -module_name_mgr = dict() - def _is_satisfied(path): """ @@ -81,7 +78,7 @@ def _search(init_pattern: Dict[str, Pattern], init_topo_order: List[BaseNode], break for _, cur_pattern in cur_path.new_pattern.items(): - if cur_pattern.count < MINI_FREQUENCEY: + if cur_pattern.count < MINI_FREQUENCY: available_path.append(cur_path) break key = "/".join([cur_pattern.pattern, gen_hash_key(cur_topo_order)]) @@ -130,6 +127,8 @@ def _retrieve_scope_name(found_path): Args: found_path: Found path. """ + module_name_mgr = dict() + module_dict = dict() for module_path in found_path.recursion_path: key, val = _retrieve_operators(module_path, module_dict) @@ -142,29 +141,31 @@ def _retrieve_scope_name(found_path): for node in found_path.topo_order_aft_repl: if module_dict.get(node.op_type): topo_order_with_scope_name += [f"Model/{item}" for item in _scope_name_deduplication( - node.op_type, module_dict[node.op_type])] + node.op_type, module_dict[node.op_type], module_name_mgr)] else: topo_order_with_scope_name.append(f"Model/{node.op_type}") return topo_order_with_scope_name -def _scope_name_deduplication(key, scope_names) -> list: +def _scope_name_deduplication(key, scope_names, memo) -> list: """ Scope name deduplication. Args: + key (str): Module name. scope_names (list): Scope names. + memo (dict): Memo to record module name. Returns: list, renamed scope name. """ result = [] - if key not in module_name_mgr: - module_name_mgr[key] = 0 + if key not in memo: + memo[key] = 0 for item in scope_names: - item = item.replace(key, f"{key}_{module_name_mgr.get(key)}") + item = item.replace(key, f"{key}_{memo.get(key)}") result.append(item) - module_name_mgr[key] += 1 + memo[key] += 1 return result @@ -179,12 +180,14 @@ def _retrieve_operators(module_path, module_dict): Returns: str: module_name, operators in module. """ - global module_name_mgr + added_module = dict() node_in_pattern = module_path.pattern.pattern.split('->') node_list = [] for node in node_in_pattern: if module_dict.get(node): - node_list += module_dict[node] + node_list += _scope_name_deduplication(node, + module_dict[node], + added_module) else: node_list.append(node) key = module_path.pattern.module_name