|
|
|
@@ -17,13 +17,10 @@ from queue import PriorityQueue |
|
|
|
from typing import Dict, List |
|
|
|
|
|
|
|
from .common import context, DagGraph, gen_hash_key |
|
|
|
from ..constant import MINI_FREQUENCEY |
|
|
|
from ..constant import MINI_FREQUENCY |
|
|
|
from ..third_party_graph.onnx_utils import BaseNode |
|
|
|
from .search_path import SearchPath, Pattern, generate_pattern |
|
|
|
|
|
|
|
# Hold module name of current graph. |
|
|
|
module_name_mgr = dict() |
|
|
|
|
|
|
|
|
|
|
|
def _is_satisfied(path): |
|
|
|
""" |
|
|
|
@@ -81,7 +78,7 @@ def _search(init_pattern: Dict[str, Pattern], init_topo_order: List[BaseNode], |
|
|
|
break |
|
|
|
|
|
|
|
for _, cur_pattern in cur_path.new_pattern.items(): |
|
|
|
if cur_pattern.count < MINI_FREQUENCEY: |
|
|
|
if cur_pattern.count < MINI_FREQUENCY: |
|
|
|
available_path.append(cur_path) |
|
|
|
break |
|
|
|
key = "/".join([cur_pattern.pattern, gen_hash_key(cur_topo_order)]) |
|
|
|
@@ -130,6 +127,8 @@ def _retrieve_scope_name(found_path): |
|
|
|
Args: |
|
|
|
found_path: Found path. |
|
|
|
""" |
|
|
|
module_name_mgr = dict() |
|
|
|
|
|
|
|
module_dict = dict() |
|
|
|
for module_path in found_path.recursion_path: |
|
|
|
key, val = _retrieve_operators(module_path, module_dict) |
|
|
|
@@ -142,29 +141,31 @@ def _retrieve_scope_name(found_path): |
|
|
|
for node in found_path.topo_order_aft_repl: |
|
|
|
if module_dict.get(node.op_type): |
|
|
|
topo_order_with_scope_name += [f"Model/{item}" for item in _scope_name_deduplication( |
|
|
|
node.op_type, module_dict[node.op_type])] |
|
|
|
node.op_type, module_dict[node.op_type], module_name_mgr)] |
|
|
|
else: |
|
|
|
topo_order_with_scope_name.append(f"Model/{node.op_type}") |
|
|
|
return topo_order_with_scope_name |
|
|
|
|
|
|
|
|
|
|
|
def _scope_name_deduplication(key, scope_names) -> list: |
|
|
|
def _scope_name_deduplication(key, scope_names, memo) -> list: |
|
|
|
""" |
|
|
|
Scope name deduplication. |
|
|
|
|
|
|
|
Args: |
|
|
|
key (str): Module name. |
|
|
|
scope_names (list): Scope names. |
|
|
|
memo (dict): Memo to record module name. |
|
|
|
|
|
|
|
Returns: |
|
|
|
list, renamed scope name. |
|
|
|
""" |
|
|
|
result = [] |
|
|
|
if key not in module_name_mgr: |
|
|
|
module_name_mgr[key] = 0 |
|
|
|
if key not in memo: |
|
|
|
memo[key] = 0 |
|
|
|
for item in scope_names: |
|
|
|
item = item.replace(key, f"{key}_{module_name_mgr.get(key)}") |
|
|
|
item = item.replace(key, f"{key}_{memo.get(key)}") |
|
|
|
result.append(item) |
|
|
|
module_name_mgr[key] += 1 |
|
|
|
memo[key] += 1 |
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
@@ -179,12 +180,14 @@ def _retrieve_operators(module_path, module_dict): |
|
|
|
Returns: |
|
|
|
str: module_name, operators in module. |
|
|
|
""" |
|
|
|
global module_name_mgr |
|
|
|
added_module = dict() |
|
|
|
node_in_pattern = module_path.pattern.pattern.split('->') |
|
|
|
node_list = [] |
|
|
|
for node in node_in_pattern: |
|
|
|
if module_dict.get(node): |
|
|
|
node_list += module_dict[node] |
|
|
|
node_list += _scope_name_deduplication(node, |
|
|
|
module_dict[node], |
|
|
|
added_module) |
|
|
|
else: |
|
|
|
node_list.append(node) |
|
|
|
key = module_path.pattern.module_name |
|
|
|
|