Browse Source

Enhance the effect of scope name generation.

tags/v1.1.0
liuchongming 5 years ago
parent
commit
f49082b494
8 changed files with 223 additions and 39 deletions
  1. +0
    -2
      mindinsight/mindconverter/graph_based_converter/constant.py
  2. +12
    -0
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py
  3. +32
    -4
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py
  4. +73
    -0
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/known_module_name.py
  5. +6
    -0
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py
  6. +4
    -4
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern_fuzzy_matching.py
  7. +53
    -9
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py
  8. +43
    -20
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py

+ 0
- 2
mindinsight/mindconverter/graph_based_converter/constant.py View File

@@ -27,8 +27,6 @@ FIRST_LEVEL_INDENT = BLANK_SYM * 4
SECOND_LEVEL_INDENT = BLANK_SYM * 8
NEW_LINE = "\n"

MINI_FREQUENCY = 4

ONNX_TYPE_INT = 2
ONNX_TYPE_INTS = 7
ONNX_TYPE_FLOAT = 1


+ 12
- 0
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py View File

@@ -22,6 +22,10 @@ def register_pattern(ptn_name, in_degree, out_degree):
"""
Register pattern to MindConverter.

Notes:
The `out_degree` of pattern refers to the out-edge number in original graph,
not the output number of the pattern.

Args:
out_degree: Out degree of pattern.
in_degree: In degree of pattern.
@@ -64,4 +68,12 @@ def _conv_bn_conv_bn_relu():
return ["Conv", "BatchNormalization", "Conv", "BatchNormalization", "Relu"]


@register_pattern("ConvBnReLUx2+ConvBn+Add+Relu", 1, 2)
def _convbnrelux3_convbn_add_relu():
    """Residual add pattern: two Conv-BN-Relu units followed by Conv-BN-Add-Relu."""
    unit = ["Conv", "BatchNormalization", "Relu"]
    return unit * 2 + ["Conv", "BatchNormalization", "Add", "Relu"]


__all__ = ["BUILT_IN_PATTERN", "register_pattern"]

+ 32
- 4
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py View File

@@ -21,6 +21,10 @@ from typing import List
from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode

MAX_OUT_DEGREE = 1
MINI_FREQUENCY = 4
MAX_ITERATION_DEPTH = 4
SATISFIED_SCORE = 0.55
ACCEPTABLE_RESULT_COUNT = 16


class CmpRelation:
@@ -33,9 +37,14 @@ class CmpRelation:
GREATER = 1


def gen_hash_key(sequence: List[BaseNode], separator="->"):
def gen_hash_key(sequence: List[BaseNode], separator="-", without_module: bool = False):
"""Generate hash key."""
seq = [item.op_type for item in sequence]
seq = []
for item in sequence:
if without_module and "module" in item.op_type.lower():
seq.append("_M_")
continue
seq.append(item.op_type)
return separator.join(seq)


@@ -71,6 +80,7 @@ class AlgorithmContext:
visited = set()
beam_width = 5
total_len = 0
MIN_FREQUENCY = 1
node_collection = None
precursor_table = {}
successor_table = {}
@@ -120,7 +130,21 @@ class AlgorithmContext:
reverse=True)
if len(pattern_arr) > self.beam_width:
pattern_arr = pattern_arr[:self.beam_width]
return OrderedDict(pattern_arr)
res = OrderedDict()
for i, (key, ptn) in enumerate(pattern_arr):
if ptn.count <= self.MIN_FREQUENCY:
continue
skip = False
for j, (_, candidate) in enumerate(pattern_arr):
if i == j:
continue
if candidate.ptn_length >= ptn.ptn_length and ptn.ptn_items == candidate.ptn_items[:ptn.ptn_length]:
skip = True
break
if skip:
continue
res[key] = ptn
return res


context = AlgorithmContext()
@@ -128,4 +152,8 @@ context = AlgorithmContext()
__all__ = ["context",
"gen_hash_key",
"DagGraph",
"MAX_OUT_DEGREE"]
"MAX_OUT_DEGREE",
"MAX_ITERATION_DEPTH",
"SATISFIED_SCORE",
"MINI_FREQUENCY",
"ACCEPTABLE_RESULT_COUNT"]

+ 73
- 0
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/known_module_name.py View File

@@ -0,0 +1,73 @@
# Copyright 2020 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Introduce some standard pattern name into MindConverter."""
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern

PLACEHOLDER = "PLC"
BUILT_IN_MODULE_NAME = dict()


def register_module_name(md_name: str, in_degree: int, out_degree: int):
    """
    Register a known module name to MindConverter.

    The decorated function returns an op-type sequence; that sequence is
    stored in `BUILT_IN_MODULE_NAME` keyed by a `Pattern`, mapped to the
    human-readable module name.

    Args:
        md_name (str): Module name to use when a searched pattern matches.
        in_degree (int): In degree of pattern.
        out_degree (int): Out degree of pattern.

    Returns:
        callable, decorator that registers the pattern and returns the
        decorated function unchanged.
    """

    def _reg(pattern):
        result = pattern()
        # Empty sequences register nothing.
        if result:
            BUILT_IN_MODULE_NAME[Pattern("->".join(result), len(result),
                                         in_degree, out_degree,
                                         ptn_items=result)] = md_name
        # Fix: the original returned None here, which silently rebound each
        # decorated function name to None at module level.
        return pattern

    return _reg


@register_module_name("Bottleneck", 1, 2)
def _resnet_block_0():
    """ResNet bottleneck pattern: two Conv-BN-Relu units, then Conv-BN-Add-Relu."""
    unit = ["Conv", "BatchNormalization", "Relu"]
    return unit * 2 + ["Conv", "BatchNormalization", "Add", "Relu"]


@register_module_name("Bottleneck", 1, 2)
def _resnet_block_1():
    """ResNet bottleneck pattern with two leading placeholder slots."""
    return [PLACEHOLDER] * 2 + ["Conv", "BatchNormalization", "Add", "Relu"]


@register_module_name("Bottleneck", 1, 2)
def _resnet_block_2():
    """ResNet bottleneck pattern with three leading placeholder slots."""
    return [PLACEHOLDER] * 3 + ["Add", "Relu"]


@register_module_name("BasicConvBlock", 1, 1)
def _basic_conv_block_0():
    """Pattern of a single Conv-BN-Relu unit."""
    ops = ["Conv", "BatchNormalization"]
    ops.append("Relu")
    return ops


@register_module_name("ConvBN", 1, 1)
def _conv_bn():
    """Pattern of a Conv directly followed by BatchNormalization."""
    ops = ["Conv"]
    ops.append("BatchNormalization")
    return ops

+ 6
- 0
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py View File

@@ -29,6 +29,8 @@ class Pattern:
self.ptn_items = pattern.split("->") if ptn_items is None else ptn_items
self.in_degree = in_degree
self.out_degree = out_degree
self.head = self.ptn_items[0]
self.tail = self.ptn_items[-1]

def insert(self, idx, seq_len):
"""
@@ -53,3 +55,7 @@ class Pattern:
return f"Ptn: {self.pattern}[" \
f"{scope_name_mapping.get(self.pattern, 'Not init')}], " \
f"count={self.count}"

def __hash__(self):
    """Hash on the identity triple (pattern, in_degree, out_degree).

    NOTE(review): a matching `__eq__` over the same triple is not visible
    here — confirm equality and hashing agree for Pattern instances.
    """
    key = "_".join((self.pattern, str(self.in_degree), str(self.out_degree)))
    return hash(key)

+ 4
- 4
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern_fuzzy_matching.py View File

@@ -18,7 +18,7 @@ from typing import List
import numpy as np

MIN_PATTERN_LEN = 3
MATCHED_THRESHOLD = .75
MATCHED_THRESHOLD = .8
COMPLETELY_MATCHED = 1.


@@ -72,11 +72,11 @@ def pattern_fuzzy_matching(query: List[str], target: List[str]):
target (list): Target pattern.

Returns:
bool, true or false.
Tuple[bool, float], true or false and matching score.
"""
edit_count = _levenshtein_distance(query, target)
target_len = float(len(target))
score = (target_len - edit_count) / target_len
if target_len <= MIN_PATTERN_LEN:
return score == COMPLETELY_MATCHED
return score >= MATCHED_THRESHOLD
return score == COMPLETELY_MATCHED, score
return score >= MATCHED_THRESHOLD, score

+ 53
- 9
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py View File

@@ -18,11 +18,14 @@ import uuid
from typing import Dict, List, Callable, Union
from collections import OrderedDict
from .common import context, gen_hash_key, DagGraph, MAX_OUT_DEGREE
from .known_module_name import BUILT_IN_MODULE_NAME
from .pattern import Pattern, scope_name_mapping
from .built_in_pattern import BUILT_IN_PATTERN
from .pattern_fuzzy_matching import pattern_fuzzy_matching
from ..third_party_graph.onnx_utils import OnnxNode, BaseNode

module_name_to_src = {}
used_module_name = dict()
global_idx = 0


@@ -82,8 +85,31 @@ def _is_valid_pattern(pattern, dag):
return True


def generate_module_name():
"""Generate module name."""
def generate_module_name(pattern):
"""
Generate module name.

Args:
pattern (Pattern): To be replaced pattern.

"""
matched_result = []
for ptn, module_name in BUILT_IN_MODULE_NAME.items():
if pattern.in_degree == ptn.in_degree and pattern.out_degree == ptn.out_degree and \
ptn.head == pattern.head and ptn.tail == pattern.tail:
is_matched, score = pattern_fuzzy_matching(pattern.ptn_items, ptn.ptn_items)
if is_matched:
matched_result.append((module_name, score))
if matched_result:
module_name = (matched_result if len(matched_result) == 1 else
sorted(matched_result, key=lambda x: x[1], reverse=True))[0][0]
if pattern.pattern not in used_module_name:
used_module_name[pattern.pattern] = 1
else:
module_name = f"{module_name}{used_module_name[pattern.pattern]}"
used_module_name[pattern.pattern] += 1
return module_name

global global_idx
name = f"Module{global_idx}"
global_idx += 1
@@ -126,7 +152,7 @@ def _find_idx(sequence: List[BaseNode], target: str, equal_func: Callable,
"""
not_found = -1
if not sequence:
msg = f"Empty sequence is not supported."
msg = "Empty sequence is not supported."
raise ValueError(msg)

end_idx = end_idx if end_idx else len(sequence)
@@ -263,6 +289,7 @@ def find_built_in_pattern(topo_order: List[BaseNode], dag: DagGraph) -> Dict[str
cur_idx, total_len = 0, len(topo_order)
for k in BUILT_IN_PATTERN:
ptn_len = BUILT_IN_PATTERN[k].ptn_length
cur_idx = 0
while cur_idx < total_len:
matched = True
init_pattern = OrderedDict()
@@ -393,6 +420,10 @@ class SearchPath:
)

self.new_pattern = context.found_pattern[self.hash_of_aft_repl]
self._created_modules = {
path.pattern.module_name: path.pattern for path in self.recursion_path
}
self._created_modules[self.pattern.module_name] = self.pattern
self.heuristic_v = self._heuristic_val()
self.actual_v = self._actual_val()

@@ -405,7 +436,7 @@ class SearchPath:
to recover the sequence.
"""
if self.pattern.pattern not in scope_name_mapping:
module_name = generate_module_name()
module_name = generate_module_name(self.pattern)
scope_name_mapping[self.pattern.pattern] = module_name
module_name_to_src[module_name] = self.pattern.pattern
else:
@@ -542,14 +573,26 @@ class SearchPath:

def evaluate_score(self):
"""Evaluate path score."""
return self.actual_v + self.heuristic_v
return .7 * self.actual_v + .3 * self.heuristic_v

def _cal_merged_module_length(self, ptn):
"""Calculate module length."""
ptn_len = 0
for item in ptn.ptn_items:
if item in self._created_modules:
ptn_len += self._cal_merged_module_length(self._created_modules[item])
continue
ptn_len += 1
return ptn_len

def _heuristic_val(self):
"""Calculate heuristic score of the path."""
res = []
for _, ptn in enumerate(self.new_pattern.items()):
res.append(ptn[1].count * ptn[1].ptn_length / len(self.topo_order_aft_repl))
return sum(res) / len(res)
for ptn in self.new_pattern.items():
res.append(ptn[1].count * self._cal_merged_module_length(ptn[1]) / context.get_sequence_length())
if not res:
return 1.0
return max(res)

def _actual_val(self):
"""Calculate ground-truth score of the path."""
@@ -580,6 +623,7 @@ class SearchPath:
chain.append(sub_module)
return "->".join(chain)

repr_str = f"{self.pattern.pattern}[{self.pattern.module_name}], H: {self.heuristic_v}, G: {self.actual_v}"
repr_str = f"{self.pattern.pattern}[{self.pattern.module_name}], " \
f"H: {self.heuristic_v}, G: {self.actual_v}, E: {self.evaluate_score()}"

return repr_str

+ 43
- 20
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py View File

@@ -16,8 +16,8 @@
from queue import PriorityQueue
from typing import Dict, List

from .common import context, DagGraph, gen_hash_key
from ..constant import MINI_FREQUENCY
from .common import context, DagGraph, gen_hash_key, ACCEPTABLE_RESULT_COUNT
from .common import MINI_FREQUENCY, MAX_ITERATION_DEPTH, SATISFIED_SCORE
from ..third_party_graph.onnx_utils import BaseNode
from .search_path import SearchPath, Pattern, generate_pattern, find_built_in_pattern

@@ -32,10 +32,13 @@ def _is_satisfied(path):
Returns:
bool, True or False.
"""
if len(path.recursion_path) == 2:
if len(path.recursion_path) > MAX_ITERATION_DEPTH:
return True
flag = [cur_pattern.count for _, cur_pattern in path.new_pattern.items()]
return float(sum(flag)) / len(flag) == 1 or path.actual_v >= 0.80
if not path.new_pattern or max([p.count for _, p in path.new_pattern.items()]) < MINI_FREQUENCY:
return True
if path.evaluate_score() > SATISFIED_SCORE:
return True
return False


def _search(init_pattern: Dict[str, Pattern], init_topo_order: List[BaseNode],
@@ -74,14 +77,16 @@ def _search(init_pattern: Dict[str, Pattern], init_topo_order: List[BaseNode],
available_path.append(cur_path)
continue

if len(available_path) >= 8:
if len(available_path) >= ACCEPTABLE_RESULT_COUNT:
break

for _, cur_pattern in cur_path.new_pattern.items():
if cur_pattern.count < MINI_FREQUENCY:
available_path.append(cur_path)
break
key = "/".join([cur_pattern.pattern, gen_hash_key(cur_topo_order)])
continue
key = "/".join([f"{cur_pattern.pattern}[{cur_pattern.in_degree},{cur_pattern.out_degree}]",
gen_hash_key(cur_topo_order, without_module=True)])
if key in context.visited:
continue
# c. create new SearchPath.
new_path = SearchPath(pattern=cur_pattern, sequence=cur_topo_order, prev_path=cur_path,
sub_graph_size=sub_graph_size)
@@ -107,18 +112,17 @@ def _sub_graph_matching(init_dag, beam_width=5, sub_graph_size=4):
context.set_beam_width(beam_width)

def _get_top_1(available_path: list):
if len(available_path) <= 1:
return available_path
if not available_path:
return None
available_path = sorted(available_path, key=lambda x: x.actual_v, reverse=True)
return available_path[0] if available_path else None
return available_path[0]

topo_order = [node for _, (_, node) in enumerate(context.node_collection.items())]
context.set_sequence_length(len(topo_order))
built_in_pattern = find_built_in_pattern(topo_order, init_dag)
pattern = generate_pattern(topo_order, dag=init_dag, sub_graph_size=sub_graph_size)
pattern.update(built_in_pattern)
found_path = _search(pattern, topo_order, init_graph=init_dag,
sub_graph_size=sub_graph_size)
found_path = _search(pattern, topo_order, init_graph=init_dag, sub_graph_size=2)
return _get_top_1(found_path)


@@ -183,7 +187,7 @@ def _retrieve_operators(module_path, module_dict):
str: module_name, operators in module.
"""
added_module = dict()
node_in_pattern = module_path.pattern.pattern.split('->')
node_in_pattern = module_path.pattern.ptn_items
node_list = []
for node in node_in_pattern:
if module_dict.get(node):
@@ -192,9 +196,8 @@ def _retrieve_operators(module_path, module_dict):
added_module)
else:
node_list.append(node)
key = module_path.pattern.module_name
val = [f"{key}/{node}" for node in node_list]
return key, val
val = [f"{module_path.pattern.module_name}/{node}" for node in node_list]
return module_path.pattern.module_name, val


def _build_connection(loader):
@@ -216,6 +219,19 @@ def _build_connection(loader):
return dag


def flatten_graph(graph):
    """
    Flatten graph into a sequence of scope names.

    Fallback used when sub-graph matching fails or yields an inconsistent
    result: every node is placed directly under the root "Model" scope.

    Args:
        graph (DagGraph): DagGraph instance.

    Returns:
        list[str], scope name ("Model/<op_type>") for each node, in
        node-collection order.
    """
    # Only node values are needed; iterate .values() instead of
    # discarding keys from .items().
    return [f"Model/{node.op_type}" for node in graph.node_collection.values()]


def generate_scope_name(data_loader):
"""
Generate scope name according to computation graph.
@@ -227,6 +243,13 @@ def generate_scope_name(data_loader):
list[str], generated scope name.
"""
init_dag = _build_connection(data_loader)
result = _sub_graph_matching(init_dag, beam_width=5, sub_graph_size=6)
topo_order_with_scope_name_list = _retrieve_scope_name(result)
try:
result = _sub_graph_matching(init_dag, beam_width=5, sub_graph_size=6)
topo_order_with_scope_name_list = _retrieve_scope_name(result) if result else flatten_graph(init_dag)

if len(topo_order_with_scope_name_list) != len(data_loader.nodes_dict):
topo_order_with_scope_name_list = flatten_graph(init_dag)

except (ValueError, IndexError, AttributeError, KeyError) as _:
topo_order_with_scope_name_list = flatten_graph(init_dag)
return topo_order_with_scope_name_list

Loading…
Cancel
Save