Browse Source

Fix bugs in scope name generation, generate names recursively.

tags/v1.1.0
liuchongming 5 years ago
parent
commit
9775235cde
2 changed files with 17 additions and 14 deletions
  1. +1
    -1
      mindinsight/mindconverter/graph_based_converter/constant.py
  2. +16
    -13
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py

+ 1
- 1
mindinsight/mindconverter/graph_based_converter/constant.py View File

@@ -27,7 +27,7 @@ FIRST_LEVEL_INDENT = BLANK_SYM * 4
SECOND_LEVEL_INDENT = BLANK_SYM * 8
NEW_LINE = "\n"

MINI_FREQUENCEY = 4
MINI_FREQUENCY = 4

ONNX_TYPE_INT = 2
ONNX_TYPE_INTS = 7


+ 16
- 13
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py View File

@@ -17,13 +17,10 @@ from queue import PriorityQueue
from typing import Dict, List

from .common import context, DagGraph, gen_hash_key
from ..constant import MINI_FREQUENCEY
from ..constant import MINI_FREQUENCY
from ..third_party_graph.onnx_utils import BaseNode
from .search_path import SearchPath, Pattern, generate_pattern

# Hold module name of current graph.
module_name_mgr = dict()


def _is_satisfied(path):
"""
@@ -81,7 +78,7 @@ def _search(init_pattern: Dict[str, Pattern], init_topo_order: List[BaseNode],
break

for _, cur_pattern in cur_path.new_pattern.items():
if cur_pattern.count < MINI_FREQUENCEY:
if cur_pattern.count < MINI_FREQUENCY:
available_path.append(cur_path)
break
key = "/".join([cur_pattern.pattern, gen_hash_key(cur_topo_order)])
@@ -130,6 +127,8 @@ def _retrieve_scope_name(found_path):
Args:
found_path: Found path.
"""
module_name_mgr = dict()

module_dict = dict()
for module_path in found_path.recursion_path:
key, val = _retrieve_operators(module_path, module_dict)
@@ -142,29 +141,31 @@ def _retrieve_scope_name(found_path):
for node in found_path.topo_order_aft_repl:
if module_dict.get(node.op_type):
topo_order_with_scope_name += [f"Model/{item}" for item in _scope_name_deduplication(
node.op_type, module_dict[node.op_type])]
node.op_type, module_dict[node.op_type], module_name_mgr)]
else:
topo_order_with_scope_name.append(f"Model/{node.op_type}")
return topo_order_with_scope_name


def _scope_name_deduplication(key, scope_names) -> list:
def _scope_name_deduplication(key, scope_names, memo) -> list:
"""
Scope name deduplication.

Args:
key (str): Module name.
scope_names (list): Scope names.
memo (dict): Memo to record module name.

Returns:
list, renamed scope name.
"""
result = []
if key not in module_name_mgr:
module_name_mgr[key] = 0
if key not in memo:
memo[key] = 0
for item in scope_names:
item = item.replace(key, f"{key}_{module_name_mgr.get(key)}")
item = item.replace(key, f"{key}_{memo.get(key)}")
result.append(item)
module_name_mgr[key] += 1
memo[key] += 1
return result


@@ -179,12 +180,14 @@ def _retrieve_operators(module_path, module_dict):
Returns:
str: module_name, operators in module.
"""
global module_name_mgr
added_module = dict()
node_in_pattern = module_path.pattern.pattern.split('->')
node_list = []
for node in node_in_pattern:
if module_dict.get(node):
node_list += module_dict[node]
node_list += _scope_name_deduplication(node,
module_dict[node],
added_module)
else:
node_list.append(node)
key = module_path.pattern.module_name


Loading…
Cancel
Save