| @@ -27,8 +27,6 @@ FIRST_LEVEL_INDENT = BLANK_SYM * 4 | |||
| SECOND_LEVEL_INDENT = BLANK_SYM * 8 | |||
| NEW_LINE = "\n" | |||
| MINI_FREQUENCY = 4 | |||
| ONNX_TYPE_INT = 2 | |||
| ONNX_TYPE_INTS = 7 | |||
| ONNX_TYPE_FLOAT = 1 | |||
| @@ -22,6 +22,10 @@ def register_pattern(ptn_name, in_degree, out_degree): | |||
| """ | |||
| Register pattern to MindConverter. | |||
| Notes: | |||
| The `out_degree` of pattern refers to the out-edge number in original graph, | |||
| not the output number of the pattern. | |||
| Args: | |||
| in_degree: In degree of pattern. | |||
| out_degree: Out degree of pattern. | |||
| @@ -64,4 +68,12 @@ def _conv_bn_conv_bn_relu(): | |||
| return ["Conv", "BatchNormalization", "Conv", "BatchNormalization", "Relu"] | |||
| @register_pattern("ConvBnReLUx2+ConvBn+Add+Relu", 1, 2) | |||
| def _convbnrelux3_convbn_add_relu(): | |||
| """Add pattern.""" | |||
| return ["Conv", "BatchNormalization", "Relu", | |||
| "Conv", "BatchNormalization", "Relu", | |||
| "Conv", "BatchNormalization", "Add", "Relu"] | |||
| __all__ = ["BUILT_IN_PATTERN", "register_pattern"] | |||
| @@ -21,6 +21,10 @@ from typing import List | |||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode | |||
| MAX_OUT_DEGREE = 1 | |||
| MINI_FREQUENCY = 4 | |||
| MAX_ITERATION_DEPTH = 4 | |||
| SATISFIED_SCORE = 0.55 | |||
| ACCEPTABLE_RESULT_COUNT = 16 | |||
| class CmpRelation: | |||
| @@ -33,9 +37,14 @@ class CmpRelation: | |||
| GREATER = 1 | |||
| def gen_hash_key(sequence: List[BaseNode], separator="->"): | |||
| def gen_hash_key(sequence: List[BaseNode], separator="-", without_module: bool = False): | |||
| """Generate hash key.""" | |||
| seq = [item.op_type for item in sequence] | |||
| seq = [] | |||
| for item in sequence: | |||
| if without_module and "module" in item.op_type.lower(): | |||
| seq.append("_M_") | |||
| continue | |||
| seq.append(item.op_type) | |||
| return separator.join(seq) | |||
| @@ -71,6 +80,7 @@ class AlgorithmContext: | |||
| visited = set() | |||
| beam_width = 5 | |||
| total_len = 0 | |||
| MIN_FREQUENCY = 1 | |||
| node_collection = None | |||
| precursor_table = {} | |||
| successor_table = {} | |||
| @@ -120,7 +130,21 @@ class AlgorithmContext: | |||
| reverse=True) | |||
| if len(pattern_arr) > self.beam_width: | |||
| pattern_arr = pattern_arr[:self.beam_width] | |||
| return OrderedDict(pattern_arr) | |||
| res = OrderedDict() | |||
| for i, (key, ptn) in enumerate(pattern_arr): | |||
| if ptn.count <= self.MIN_FREQUENCY: | |||
| continue | |||
| skip = False | |||
| for j, (_, candidate) in enumerate(pattern_arr): | |||
| if i == j: | |||
| continue | |||
| if candidate.ptn_length >= ptn.ptn_length and ptn.ptn_items == candidate.ptn_items[:ptn.ptn_length]: | |||
| skip = True | |||
| break | |||
| if skip: | |||
| continue | |||
| res[key] = ptn | |||
| return res | |||
| context = AlgorithmContext() | |||
| @@ -128,4 +152,8 @@ context = AlgorithmContext() | |||
| __all__ = ["context", | |||
| "gen_hash_key", | |||
| "DagGraph", | |||
| "MAX_OUT_DEGREE"] | |||
| "MAX_OUT_DEGREE", | |||
| "MAX_ITERATION_DEPTH", | |||
| "SATISFIED_SCORE", | |||
| "MINI_FREQUENCY", | |||
| "ACCEPTABLE_RESULT_COUNT"] | |||
| @@ -0,0 +1,73 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Introduce some standard pattern name into MindConverter.""" | |||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern | |||
| PLACEHOLDER = "PLC" | |||
| BUILT_IN_MODULE_NAME = dict() | |||
| def register_module_name(md_name: str, in_degree: int, out_degree: int): | |||
| """ | |||
| Register module name to MindConverter. | |||
| Args: | |||
| md_name (str): Module name. | |||
| in_degree (int): In degree of pattern. | |||
| out_degree (int): Out degree of pattern. | |||
| """ | |||
| def _reg(pattern): | |||
| result = pattern() | |||
| if not result: | |||
| return | |||
| BUILT_IN_MODULE_NAME[Pattern("->".join(result), len(result), | |||
| in_degree, out_degree, | |||
| ptn_items=result)] = md_name | |||
| return _reg | |||
| @register_module_name("Bottleneck", 1, 2) | |||
| def _resnet_block_0(): | |||
| """Add ResNet feature extraction block pattern.""" | |||
| return ["Conv", "BatchNormalization", "Relu", | |||
| "Conv", "BatchNormalization", "Relu", | |||
| "Conv", "BatchNormalization", "Add", "Relu"] | |||
| @register_module_name("Bottleneck", 1, 2) | |||
| def _resnet_block_1(): | |||
| """Add ResNet feature extraction block pattern.""" | |||
| return [PLACEHOLDER, PLACEHOLDER, "Conv", "BatchNormalization", "Add", "Relu"] | |||
| @register_module_name("Bottleneck", 1, 2) | |||
| def _resnet_block_2(): | |||
| """Add ResNet feature extraction block pattern.""" | |||
| return [PLACEHOLDER, PLACEHOLDER, PLACEHOLDER, "Add", "Relu"] | |||
| @register_module_name("BasicConvBlock", 1, 1) | |||
| def _basic_conv_block_0(): | |||
| """Add basic conv block.""" | |||
| return ["Conv", "BatchNormalization", "Relu"] | |||
| @register_module_name("ConvBN", 1, 1) | |||
| def _conv_bn(): | |||
| """Add basic conv block.""" | |||
| return ["Conv", "BatchNormalization"] | |||
| @@ -29,6 +29,8 @@ class Pattern: | |||
| self.ptn_items = pattern.split("->") if ptn_items is None else ptn_items | |||
| self.in_degree = in_degree | |||
| self.out_degree = out_degree | |||
| self.head = self.ptn_items[0] | |||
| self.tail = self.ptn_items[-1] | |||
| def insert(self, idx, seq_len): | |||
| """ | |||
| @@ -53,3 +55,7 @@ class Pattern: | |||
| return f"Ptn: {self.pattern}[" \ | |||
| f"{scope_name_mapping.get(self.pattern, 'Not init')}], " \ | |||
| f"count={self.count}" | |||
| def __hash__(self): | |||
| """Make Pattern hashable.""" | |||
| return hash(f"{self.pattern}_{self.in_degree}_{self.out_degree}") | |||
| @@ -18,7 +18,7 @@ from typing import List | |||
| import numpy as np | |||
| MIN_PATTERN_LEN = 3 | |||
| MATCHED_THRESHOLD = .75 | |||
| MATCHED_THRESHOLD = .8 | |||
| COMPLETELY_MATCHED = 1. | |||
| @@ -72,11 +72,11 @@ def pattern_fuzzy_matching(query: List[str], target: List[str]): | |||
| target (list): Target pattern. | |||
| Returns: | |||
| bool, true or false. | |||
| Tuple[bool, float], true or false and matching score. | |||
| """ | |||
| edit_count = _levenshtein_distance(query, target) | |||
| target_len = float(len(target)) | |||
| score = (target_len - edit_count) / target_len | |||
| if target_len <= MIN_PATTERN_LEN: | |||
| return score == COMPLETELY_MATCHED | |||
| return score >= MATCHED_THRESHOLD | |||
| return score == COMPLETELY_MATCHED, score | |||
| return score >= MATCHED_THRESHOLD, score | |||
| @@ -18,11 +18,14 @@ import uuid | |||
| from typing import Dict, List, Callable, Union | |||
| from collections import OrderedDict | |||
| from .common import context, gen_hash_key, DagGraph, MAX_OUT_DEGREE | |||
| from .known_module_name import BUILT_IN_MODULE_NAME | |||
| from .pattern import Pattern, scope_name_mapping | |||
| from .built_in_pattern import BUILT_IN_PATTERN | |||
| from .pattern_fuzzy_matching import pattern_fuzzy_matching | |||
| from ..third_party_graph.onnx_utils import OnnxNode, BaseNode | |||
| module_name_to_src = {} | |||
| used_module_name = dict() | |||
| global_idx = 0 | |||
| @@ -82,8 +85,31 @@ def _is_valid_pattern(pattern, dag): | |||
| return True | |||
| def generate_module_name(): | |||
| """Generate module name.""" | |||
| def generate_module_name(pattern): | |||
| """ | |||
| Generate module name. | |||
| Args: | |||
| pattern (Pattern): To be replaced pattern. | |||
| """ | |||
| matched_result = [] | |||
| for ptn, module_name in BUILT_IN_MODULE_NAME.items(): | |||
| if pattern.in_degree == ptn.in_degree and pattern.out_degree == ptn.out_degree and \ | |||
| ptn.head == pattern.head and ptn.tail == pattern.tail: | |||
| is_matched, score = pattern_fuzzy_matching(pattern.ptn_items, ptn.ptn_items) | |||
| if is_matched: | |||
| matched_result.append((module_name, score)) | |||
| if matched_result: | |||
| module_name = (matched_result if len(matched_result) == 1 else | |||
| sorted(matched_result, key=lambda x: x[1], reverse=True))[0][0] | |||
| if pattern.pattern not in used_module_name: | |||
| used_module_name[pattern.pattern] = 1 | |||
| else: | |||
| module_name = f"{module_name}{used_module_name[pattern.pattern]}" | |||
| used_module_name[pattern.pattern] += 1 | |||
| return module_name | |||
| global global_idx | |||
| name = f"Module{global_idx}" | |||
| global_idx += 1 | |||
| @@ -126,7 +152,7 @@ def _find_idx(sequence: List[BaseNode], target: str, equal_func: Callable, | |||
| """ | |||
| not_found = -1 | |||
| if not sequence: | |||
| msg = f"Empty sequence is not supported." | |||
| msg = "Empty sequence is not supported." | |||
| raise ValueError(msg) | |||
| end_idx = end_idx if end_idx else len(sequence) | |||
| @@ -263,6 +289,7 @@ def find_built_in_pattern(topo_order: List[BaseNode], dag: DagGraph) -> Dict[str | |||
| cur_idx, total_len = 0, len(topo_order) | |||
| for k in BUILT_IN_PATTERN: | |||
| ptn_len = BUILT_IN_PATTERN[k].ptn_length | |||
| cur_idx = 0 | |||
| while cur_idx < total_len: | |||
| matched = True | |||
| init_pattern = OrderedDict() | |||
| @@ -393,6 +420,10 @@ class SearchPath: | |||
| ) | |||
| self.new_pattern = context.found_pattern[self.hash_of_aft_repl] | |||
| self._created_modules = { | |||
| path.pattern.module_name: path.pattern for path in self.recursion_path | |||
| } | |||
| self._created_modules[self.pattern.module_name] = self.pattern | |||
| self.heuristic_v = self._heuristic_val() | |||
| self.actual_v = self._actual_val() | |||
| @@ -405,7 +436,7 @@ class SearchPath: | |||
| to recover the sequence. | |||
| """ | |||
| if self.pattern.pattern not in scope_name_mapping: | |||
| module_name = generate_module_name() | |||
| module_name = generate_module_name(self.pattern) | |||
| scope_name_mapping[self.pattern.pattern] = module_name | |||
| module_name_to_src[module_name] = self.pattern.pattern | |||
| else: | |||
| @@ -542,14 +573,26 @@ class SearchPath: | |||
| def evaluate_score(self): | |||
| """Evaluate path score.""" | |||
| return self.actual_v + self.heuristic_v | |||
| return .7 * self.actual_v + .3 * self.heuristic_v | |||
| def _cal_merged_module_length(self, ptn): | |||
| """Calculate module length.""" | |||
| ptn_len = 0 | |||
| for item in ptn.ptn_items: | |||
| if item in self._created_modules: | |||
| ptn_len += self._cal_merged_module_length(self._created_modules[item]) | |||
| continue | |||
| ptn_len += 1 | |||
| return ptn_len | |||
| def _heuristic_val(self): | |||
| """Calculate heuristic score of the path.""" | |||
| res = [] | |||
| for _, ptn in enumerate(self.new_pattern.items()): | |||
| res.append(ptn[1].count * ptn[1].ptn_length / len(self.topo_order_aft_repl)) | |||
| return sum(res) / len(res) | |||
| for ptn in self.new_pattern.items(): | |||
| res.append(ptn[1].count * self._cal_merged_module_length(ptn[1]) / context.get_sequence_length()) | |||
| if not res: | |||
| return 1.0 | |||
| return max(res) | |||
| def _actual_val(self): | |||
| """Calculate ground-truth score of the path.""" | |||
| @@ -580,6 +623,7 @@ class SearchPath: | |||
| chain.append(sub_module) | |||
| return "->".join(chain) | |||
| repr_str = f"{self.pattern.pattern}[{self.pattern.module_name}], H: {self.heuristic_v}, G: {self.actual_v}" | |||
| repr_str = f"{self.pattern.pattern}[{self.pattern.module_name}], " \ | |||
| f"H: {self.heuristic_v}, G: {self.actual_v}, E: {self.evaluate_score()}" | |||
| return repr_str | |||
| @@ -16,8 +16,8 @@ | |||
| from queue import PriorityQueue | |||
| from typing import Dict, List | |||
| from .common import context, DagGraph, gen_hash_key | |||
| from ..constant import MINI_FREQUENCY | |||
| from .common import context, DagGraph, gen_hash_key, ACCEPTABLE_RESULT_COUNT | |||
| from .common import MINI_FREQUENCY, MAX_ITERATION_DEPTH, SATISFIED_SCORE | |||
| from ..third_party_graph.onnx_utils import BaseNode | |||
| from .search_path import SearchPath, Pattern, generate_pattern, find_built_in_pattern | |||
| @@ -32,10 +32,13 @@ def _is_satisfied(path): | |||
| Returns: | |||
| bool, True or False. | |||
| """ | |||
| if len(path.recursion_path) == 2: | |||
| if len(path.recursion_path) > MAX_ITERATION_DEPTH: | |||
| return True | |||
| flag = [cur_pattern.count for _, cur_pattern in path.new_pattern.items()] | |||
| return float(sum(flag)) / len(flag) == 1 or path.actual_v >= 0.80 | |||
| if not path.new_pattern or max([p.count for _, p in path.new_pattern.items()]) < MINI_FREQUENCY: | |||
| return True | |||
| if path.evaluate_score() > SATISFIED_SCORE: | |||
| return True | |||
| return False | |||
| def _search(init_pattern: Dict[str, Pattern], init_topo_order: List[BaseNode], | |||
| @@ -74,14 +77,16 @@ def _search(init_pattern: Dict[str, Pattern], init_topo_order: List[BaseNode], | |||
| available_path.append(cur_path) | |||
| continue | |||
| if len(available_path) >= 8: | |||
| if len(available_path) >= ACCEPTABLE_RESULT_COUNT: | |||
| break | |||
| for _, cur_pattern in cur_path.new_pattern.items(): | |||
| if cur_pattern.count < MINI_FREQUENCY: | |||
| available_path.append(cur_path) | |||
| break | |||
| key = "/".join([cur_pattern.pattern, gen_hash_key(cur_topo_order)]) | |||
| continue | |||
| key = "/".join([f"{cur_pattern.pattern}[{cur_pattern.in_degree},{cur_pattern.out_degree}]", | |||
| gen_hash_key(cur_topo_order, without_module=True)]) | |||
| if key in context.visited: | |||
| continue | |||
| # c. create new SearchPath. | |||
| new_path = SearchPath(pattern=cur_pattern, sequence=cur_topo_order, prev_path=cur_path, | |||
| sub_graph_size=sub_graph_size) | |||
| @@ -107,18 +112,17 @@ def _sub_graph_matching(init_dag, beam_width=5, sub_graph_size=4): | |||
| context.set_beam_width(beam_width) | |||
| def _get_top_1(available_path: list): | |||
| if len(available_path) <= 1: | |||
| return available_path | |||
| if not available_path: | |||
| return None | |||
| available_path = sorted(available_path, key=lambda x: x.actual_v, reverse=True) | |||
| return available_path[0] if available_path else None | |||
| return available_path[0] | |||
| topo_order = [node for _, (_, node) in enumerate(context.node_collection.items())] | |||
| context.set_sequence_length(len(topo_order)) | |||
| built_in_pattern = find_built_in_pattern(topo_order, init_dag) | |||
| pattern = generate_pattern(topo_order, dag=init_dag, sub_graph_size=sub_graph_size) | |||
| pattern.update(built_in_pattern) | |||
| found_path = _search(pattern, topo_order, init_graph=init_dag, | |||
| sub_graph_size=sub_graph_size) | |||
| found_path = _search(pattern, topo_order, init_graph=init_dag, sub_graph_size=2) | |||
| return _get_top_1(found_path) | |||
| @@ -183,7 +187,7 @@ def _retrieve_operators(module_path, module_dict): | |||
| str: module_name, operators in module. | |||
| """ | |||
| added_module = dict() | |||
| node_in_pattern = module_path.pattern.pattern.split('->') | |||
| node_in_pattern = module_path.pattern.ptn_items | |||
| node_list = [] | |||
| for node in node_in_pattern: | |||
| if module_dict.get(node): | |||
| @@ -192,9 +196,8 @@ def _retrieve_operators(module_path, module_dict): | |||
| added_module) | |||
| else: | |||
| node_list.append(node) | |||
| key = module_path.pattern.module_name | |||
| val = [f"{key}/{node}" for node in node_list] | |||
| return key, val | |||
| val = [f"{module_path.pattern.module_name}/{node}" for node in node_list] | |||
| return module_path.pattern.module_name, val | |||
| def _build_connection(loader): | |||
| @@ -216,6 +219,19 @@ def _build_connection(loader): | |||
| return dag | |||
| def flatten_graph(graph): | |||
| """ | |||
| Flatten graph into a sequence. | |||
| Args: | |||
| graph (DagGraph): DagGraph instance. | |||
| Returns: | |||
| list[str], corresponding scope name. | |||
| """ | |||
| return [f"Model/{node.op_type}" for _, node in graph.node_collection.items()] | |||
| def generate_scope_name(data_loader): | |||
| """ | |||
| Generate scope name according to computation graph. | |||
| @@ -227,6 +243,13 @@ def generate_scope_name(data_loader): | |||
| list[str], generated scope name. | |||
| """ | |||
| init_dag = _build_connection(data_loader) | |||
| result = _sub_graph_matching(init_dag, beam_width=5, sub_graph_size=6) | |||
| topo_order_with_scope_name_list = _retrieve_scope_name(result) | |||
| try: | |||
| result = _sub_graph_matching(init_dag, beam_width=5, sub_graph_size=6) | |||
| topo_order_with_scope_name_list = _retrieve_scope_name(result) if result else flatten_graph(init_dag) | |||
| if len(topo_order_with_scope_name_list) != len(data_loader.nodes_dict): | |||
| topo_order_with_scope_name_list = flatten_graph(init_dag) | |||
| except (ValueError, IndexError, AttributeError, KeyError) as _: | |||
| topo_order_with_scope_name_list = flatten_graph(init_dag) | |||
| return topo_order_with_scope_name_list | |||