From a9080fc14a81a76047ddb667bb5a8a7bc7b4e87e Mon Sep 17 00:00:00 2001 From: liuchongming Date: Sun, 18 Oct 2020 21:35:15 +0800 Subject: [PATCH] Implement scope name generation --- .../graph_based_converter/constant.py | 2 + .../sub_graph_searcher/__init__.py | 3 + .../sub_graph_searcher/search_path.py | 574 ++++++++++++++++++ .../sub_graph_searcher/searcher.py | 227 +++++++ .../third_party_graph/onnx_utils.py | 19 +- 5 files changed, 816 insertions(+), 9 deletions(-) create mode 100644 mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py create mode 100644 mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py diff --git a/mindinsight/mindconverter/graph_based_converter/constant.py b/mindinsight/mindconverter/graph_based_converter/constant.py index 1258ed28..cc7ee4bc 100644 --- a/mindinsight/mindconverter/graph_based_converter/constant.py +++ b/mindinsight/mindconverter/graph_based_converter/constant.py @@ -27,6 +27,8 @@ FIRST_LEVEL_INDENT = BLANK_SYM * 4 SECOND_LEVEL_INDENT = BLANK_SYM * 8 NEW_LINE = "\n" +MINI_FREQUENCEY = 4 + ONNX_TYPE_INT = 2 ONNX_TYPE_INTS = 7 ONNX_TYPE_FLOAT = 1 diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/__init__.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/__init__.py index 6f34a1ba..ff88c266 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/__init__.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/__init__.py @@ -13,3 +13,6 @@ # limitations under the License. # ============================================================================== """Searcher of scope name.""" +from .searcher import generate_scope_name + +__all__ = ["generate_scope_name"] diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py new file mode 100644 index 00000000..30963991 --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py @@ -0,0 +1,574 @@ +# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Declare search path related.""" +import copy +import uuid +from typing import Dict, List, Callable, Union +from collections import OrderedDict +from .common import context, gen_hash_key, DagGraph +from ..third_party_graph.onnx_utils import OnnxNode, BaseNode + +scope_name_mapping = {} +module_name_to_src = {} +global_idx = 0 + + +class OptimizeRules: + """Define optimize rules.""" + CAN_NOT_BE_HEAD = {"Relu", "Add"} + HAS_MULTI_IPTS = {"Add", "Concat"} + ACTIVATION = {"Relu", "Tanh"} + + +def _is_connected(parent, child, dag): + """ + Whether two node are connected. + + Args: + parent (BaseNode): Parent node. + child (BaseNode): Child node. + dag (DagGraph): Graph instance. + + Returns: + bool, True or False. + """ + return parent.name in dag.precursor_table.get(child.name) + + +def _is_activation(node_type): + """ + Whether a node is activation function. + + Args: + node_type (str): Node type. + + Returns: + bool, True or False. + """ + return node_type in OptimizeRules.ACTIVATION + + +def _is_valid_pattern(pattern, dag): + """ + Whether a pattern is valid. + + Args: + pattern (dict): Pattern dict. + dag (DagGraph): Dag instance. + + Returns: + bool, True or False. + """ + if not pattern: + return False + first_op = dag.node_collection[list(pattern.keys())[0]].op_type + if len(pattern) == 1: + return False + if first_op in OptimizeRules.CAN_NOT_BE_HEAD: + return False + return True + + +def generate_module_name(): + """Generate module name.""" + global global_idx + name = f"Module{global_idx}" + global_idx += 1 + return name + + +def random_name(module_name): + """Generate node name.""" + return f"{module_name}_{str(uuid.uuid4()).split('-')[0]}" + + +class MergedONNXNode(BaseNode): + """Define merged onnx node.""" + + def __init__(self, name, module_name, ori_nodes): + super(MergedONNXNode, self).__init__(name, module_name) + self.nodes = ori_nodes + + def get_name(self): + return self.name + + def get_op(self): + return self.op_type + + +class Pattern: + """Define Pattern object.""" + + def __init__(self, pattern, pattern_length, in_degree, out_degree): + self.pattern = pattern + self.count = 0 + self.start_index = [] + self.end_index = [] + self.module_name = None + self.ptn_length = pattern_length + self.ptn_items = pattern.split("->") + self.in_degree = in_degree + self.out_degree = out_degree + + def insert(self, idx, seq_len): + """ + Insert a new position. + + Args: + idx (int): Start index. + seq_len (int): Pattern length. + """ + if idx in self.start_index: + return + self.start_index.append(idx) + self.end_index.append(idx + seq_len) + self.count += 1 + + def __str__(self): + """Override `str()` method.""" + return self.__repr__() + + def __repr__(self): + """Override `repr()` method.""" + return f"Ptn: {self.pattern}[" \ + f"{scope_name_mapping.get(self.pattern, 'Not init')}], " \ + f"count={self.count}" + + +def _find_idx(sequence: List[BaseNode], target: str, equal_func: Callable, + start_idx: int = 0, end_idx=None) -> int: + """ + Find matched result according to `equal_func` in [`start_idx`, `end_idx`). + + Args: + sequence (list): Raw topo sequence. + target (str): Target node name. + equal_func (Callable): Function to judge whether matched. + start_idx (int): Start index. + end_idx (int): End index. + + Returns: + int, index. + """ + not_found = -1 + if not sequence: + msg = f"Empty sequence is not supported." + raise ValueError(msg) + + end_idx = end_idx if end_idx else len(sequence) + for i in range(start_idx, end_idx): + if equal_func(sequence[i], target): + return i + return not_found + + +def _match(x: OnnxNode, y: str): + """ + Match func. + + Args: + x (OnnxNode): Node instance. + y (int): To be compared value. + """ + return x.name == y + + +def _get_pattern_degree(sequence: Union[OrderedDict, dict, list], + dag: DagGraph): + """ + Get degree of the pattern. + + Args: + sequence (Union[OrderedDict, dict, list]): Pattern to calculate. + dag (DagGraph): Graph instance. + + Returns: + tuple[int, int], in degree and out degree. + """ + in_node = set() + out_node = set() + node_in_seq = set() + items = sequence if isinstance(sequence, list) else sequence.keys() + for _, item in enumerate(items): + item = item.name if not isinstance(item, str) else item + for ipt in dag.precursor_table[item]: + in_node.add(ipt) + for opt in dag.successor_table[item]: + out_node.add(opt) + node_in_seq.add(item) + in_degree = len(in_node - node_in_seq) + out_degree = len(out_node - node_in_seq) + return in_degree, out_degree + + +def _find_pattern_tail(sequence: List[BaseNode], pattern: Dict[str, str], tail_idx: int, dag: DagGraph): + """ + Supply tail of the pattern sequence. + + Args: + sequence (list): Raw sequence. + pattern (dict[str, str]): Pattern to be supplied. + tail_idx (int): The position where pattern ends. + dag (DagGraph): Graph object. + + Returns: + int, tail index in the sequence. + """ + tail_append_idx = -1 + pattern_len = len(pattern) + for j, node_name in enumerate(pattern): + if len(dag.successor_table[node_name]) <= 1: + continue + if j == pattern_len - 1: + # If last node of the pattern has multi-successors, + # then ignore it. + continue + for nd_name in dag.successor_table[node_name]: + if nd_name not in pattern: + fd_idx = _find_idx(sequence=sequence, target=nd_name, + equal_func=_match, start_idx=tail_idx) + tail_append_idx = max(fd_idx, tail_append_idx) + return tail_append_idx + + +def _supply_sequence(sequence: List[BaseNode], pattern: Dict[str, str], offset: int, dag: DagGraph): + """ + Supply sequence from front to end. + + Args: + sequence (list): Raw sequence. + pattern (dict[str, str]): Pattern to be supplied. + offset (int): The position where pattern ends. + dag (DagGraph): Graph object. + + Returns: + tuple[dict, tuple[int, int]], found pattern and corresponding position. + """ + found_sequence = pattern + tail_idx = offset + ori_seq_len = len(found_sequence) + while True: + tail_idx = _find_pattern_tail(sequence=sequence, pattern=found_sequence, + tail_idx=tail_idx, dag=dag) + if tail_idx == -1: + break + for j in range(offset + 1, tail_idx + 1): + # If tail_append_idx==-1, this loop will not be executed. + node_obj = dag.node_collection[sequence[j].name] + found_sequence[node_obj.name] = node_obj.op_type + + if offset + len(found_sequence) - ori_seq_len + 1 >= len(sequence): + return found_sequence, (offset - ori_seq_len + 1, + offset + len(found_sequence) - ori_seq_len) + + # If the next node after `found_sequence` is an activation and + # has only one edge from `found_sequence`, then link it + # to `found_sequence`. + last_node = sequence[offset + len(found_sequence) - ori_seq_len] + next_node = sequence[offset + len(found_sequence) - ori_seq_len + 1] + if _is_activation(next_node.op_type) and _is_connected(last_node, next_node, dag): + found_sequence[next_node.name] = next_node.op_type + + return found_sequence, (offset - ori_seq_len + 1, + offset + len(found_sequence) - ori_seq_len) + + +def generate_pattern(topo_order: List[BaseNode], dag: DagGraph, + sub_graph_size: int = 2) -> Dict[str, Pattern]: + """ + Use self-adaptive sliding window to found sub-graph. + + Args: + dag (DagGraph): Graph object. + topo_order (list): Topo sequence. + sub_graph_size (int): Mini sub-graph size. + + Returns: + dict[str, Pattern], found pattern. + """ + pattern = {} + cur_idx, total_len = 0, len(topo_order) + while cur_idx < total_len: + if cur_idx < sub_graph_size - 1: + cur_idx += 1 + continue + cur_node = topo_order[cur_idx] + init_pattern = OrderedDict() + prev_node = None + jump_step = 0 + for j in range(sub_graph_size - 1, 0, -1): + node_obj = dag.node_collection.get(topo_order[cur_idx - j].name) + # If current node is not child of `prev_node`, + # then break it. The topo order got from ONNX has a + # good feature, nodes belonging to one scope would be together. + # Thus, we can do linear scan on topo order. + if j != sub_graph_size - 1 and prev_node not in dag.precursor_table.get(topo_order[cur_idx - j].name): + jump_step = j + 1 + break + init_pattern[node_obj.name] = node_obj.op_type + prev_node = topo_order[cur_idx - j].name + + if jump_step == 0: + init_pattern[cur_node.name] = cur_node.op_type + + if not _is_valid_pattern(init_pattern, dag): + # in OptimizeRules.CAN_NOT_BE_HEAD: + # If pattern starts with "ReLU", then pass it. + cur_idx += 1 + continue + + found_sequence, _ = _supply_sequence(sequence=topo_order, + pattern=init_pattern, + offset=cur_idx - jump_step, + dag=dag) + + in_degree, out_degree = _get_pattern_degree(found_sequence, dag) + ptn = '->'.join(found_sequence.values()) + ptn_key = f"{ptn}[{in_degree}, {out_degree}]" + if ptn_key not in pattern: + pattern[ptn_key] = Pattern(ptn, len(found_sequence), + in_degree=in_degree, + out_degree=out_degree) + + pattern[ptn_key].insert(cur_idx - sub_graph_size + 1, len(found_sequence)) + cur_idx = cur_idx + 1 + + return pattern + + +class SearchPath: + """ + Use SearchPath to store the search path. + + Args: + pattern (Pattern): Pattern instance to be matched. + sequence (list): A list of nodes in topological order. + prev_path (SearchPath): Previous search path instance. + graph (DagGraph): Graph instance. + sub_graph_size (int): Mini sub-graph size to search. + + """ + + def __init__(self, pattern, sequence: List[BaseNode], prev_path=None, + graph=None, sub_graph_size: int = 2): + self.pattern = pattern + self.graph = copy.copy(prev_path.graph) if prev_path is not None \ + else copy.copy(graph) + self.recursion_path = prev_path.recursion_path[:] \ + if prev_path is not None else list() + if prev_path is not None: + self.recursion_path.append(prev_path) + + self.topo_order_bef_repl = sequence + self.topo_order_aft_repl, self.inverted_index = self._create_new_order() + self.node_collection = dict() + self.hash_of_aft_repl = gen_hash_key(self.topo_order_aft_repl) + if self.hash_of_aft_repl not in context.found_pattern: + context.found_pattern[self.hash_of_aft_repl] = context.sort_with_beam( + generate_pattern(self.topo_order_aft_repl, dag=self.graph, sub_graph_size=sub_graph_size) + ) + + self.new_pattern = context.found_pattern[self.hash_of_aft_repl] + self.heuristic_v = self._heuristic_val() + self.actual_v = self._actual_val() + + def _create_new_order(self): + """ + Replace sequence with pattern. + + Returns: + tuple[list, dict], topo sequence and inverted index + to recover the sequence. + """ + global scope_name_mapping + if self.pattern.pattern not in scope_name_mapping: + module_name = generate_module_name() + scope_name_mapping[self.pattern.pattern] = module_name + module_name_to_src[module_name] = self.pattern.pattern + else: + module_name = scope_name_mapping[self.pattern.pattern] + self.pattern.module_name = module_name + topo_order, inverted_index = self.replace_sub_graph_completely(self.pattern, self.topo_order_bef_repl) + return topo_order, inverted_index + + def replace_sub_graph_completely(self, pattern: Pattern, + original_topo_order: List[BaseNode]): + """ + Replace sequence with pattern. + + Match pattern from scratch. + + Notes: + Bugs here, replace the sub-graph in sequence will have multi-path. + However, we use greedy-strategy, replace the pattern that appear at front, + and only keep one path. + + Args: + pattern (Pattern): Pattern to be used. + original_topo_order (list): Sequence. + + Returns: + tuple[list, dict], topo sequence and inverted index + to recover the sequence. + """ + inverted_index = {} + topo_order = [] + path_length = 0 + index = 0 + pattern_len = pattern.ptn_length + ori_seq_len = len(original_topo_order) + while index < ori_seq_len: + if original_topo_order[index].op_type != pattern.ptn_items[0] or \ + ori_seq_len - index < pattern_len: + topo_order.append(original_topo_order[index]) + index += 1 + path_length += 1 + continue + + visited_node, j = [], 0 + matched = True + for j in range(pattern_len): + visited_node.append(original_topo_order[index + j]) + if original_topo_order[index + j].op_type != pattern.ptn_items[j]: + topo_order.extend(visited_node) + index += j + 1 + path_length += j + 1 + matched = False + break + + if not matched: + continue + + in_degree, out_degree = _get_pattern_degree(visited_node, self.graph) + if in_degree != pattern.in_degree or out_degree != pattern.out_degree: + topo_order.extend(visited_node) + index += j + 1 + path_length += j + 1 + continue + + inverted_index[path_length] = [j + index for j in range(pattern_len)] + new_node = MergedONNXNode(name=random_name(pattern.module_name), + module_name=pattern.module_name, + ori_nodes=visited_node[:]) + self._reconnect(new_node) + self.graph.node_collection[new_node.name] = new_node + topo_order.append(new_node) + path_length += 1 + index += pattern_len + + return topo_order, inverted_index + + def _reconnect(self, merged_node): + """ + Re-connect merged_node with its precursor and successor nodes. + + Args: + merged_node (MergedONNXNode): Merged node. + """ + in_node, out_node = [], [] + node_in_seq = [item.name for item in merged_node.nodes] + for _, item in enumerate(merged_node.nodes): + item = item.name if not isinstance(item, str) else item + for ipt in self.graph.precursor_table[item]: + if ipt not in node_in_seq: + in_node.append(ipt) + for opt in self.graph.successor_table[item]: + if opt not in node_in_seq: + out_node.append(opt) + self.graph.precursor_table[merged_node.name] = in_node + self._relink_precursor(merged_node, in_node, node_in_seq) + self._relink_successor(merged_node, out_node, node_in_seq) + + def _relink_precursor(self, merged_node, in_node, node_in_seq): + """ + Relink node to precursor. + + Args: + merged_node (MergedONNXNode): Merged node instance. + in_node (list): In nodes list. + node_in_seq (list): Node name in sequence. + """ + # Add current node to precursor table. + self.graph.precursor_table[merged_node.name] = in_node + # Link the precursor to current node. + for p_nd in in_node: + scsr_nodes = self.graph.successor_table[p_nd].copy() + for i, nd_name in enumerate(scsr_nodes): + if nd_name in node_in_seq: + scsr_nodes[i] = merged_node.name + self.graph.successor_table[p_nd] = scsr_nodes + + def _relink_successor(self, merged_node, out_node, node_in_seq): + """ + Relink node to successor. + + Args: + merged_node (MergedONNXNode): Merged node. + out_node (list): Out nodes. + node_in_seq (list): Node name in sequence. + """ + # Add current node to successor table. + self.graph.successor_table[merged_node.name] = out_node + # Link successor to current node. + for s_nd in out_node: + p_nodes = self.graph.precursor_table[s_nd].copy() + for i, nd_name in enumerate(p_nodes): + if nd_name in node_in_seq: + p_nodes[i] = merged_node.name + self.graph.precursor_table[s_nd] = p_nodes + + def evaluate_score(self): + """Evaluate path score.""" + return self.actual_v + self.heuristic_v + + def _heuristic_val(self): + """Calculate heuristic score of the path.""" + res = [] + for _, ptn in enumerate(self.new_pattern.items()): + res.append(ptn[1].count * ptn[1].ptn_length / len(self.topo_order_aft_repl)) + return sum(res) / len(res) + + def _actual_val(self): + """Calculate ground-truth score of the path.""" + return (context.get_sequence_length() - len(self.topo_order_aft_repl)) / context.get_sequence_length() + + def __lt__(self, other): + """Override `<` operator.""" + return self.evaluate_score() > other.evaluate_score() + + def __eq__(self, other): + """Override `==` operator.""" + return self.evaluate_score() == other.evaluate_score() + + def __str__(self): + """Override `str()` method.""" + return self.__repr__() + + def __repr__(self): + """Override `repr()` method.""" + + def _dfs(module_name): + chain = [] + src = module_name_to_src[module_name] + for sub_module in src.split("->"): + if sub_module in module_name_to_src: + chain.append(_dfs(sub_module)) + else: + chain.append(sub_module) + return "->".join(chain) + + repr_str = f"{self.pattern.pattern}[{self.pattern.module_name}], H: {self.heuristic_v}, G: {self.actual_v}" + + return repr_str diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py new file mode 100644 index 00000000..8eb6e256 --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py @@ -0,0 +1,227 @@ +# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Definition of search entry.""" +from queue import PriorityQueue +from typing import Dict, List + +from .common import context, DagGraph, gen_hash_key +from ..constant import MINI_FREQUENCEY +from ..third_party_graph.onnx_utils import BaseNode +from .search_path import SearchPath, Pattern, generate_pattern + +# Hold module name of current graph. +module_name_mgr = dict() + + +def _is_satisfied(path): + """ + Whether current path is satisfied. + + Args: + path (SearchPath): A SearchPath instance. + + Returns: + bool, True or False. + """ + if len(path.recursion_path) == 2: + return True + flag = [cur_pattern.count for _, cur_pattern in path.new_pattern.items()] + return float(sum(flag)) / len(flag) == 1 or path.actual_v >= 0.80 + + +def _search(init_pattern: Dict[str, Pattern], init_topo_order: List[BaseNode], + init_graph, sub_graph_size: int = 2) -> List[SearchPath]: + """ + Search base on merged graph, until all frequency is 1. + + Args: + init_pattern (dict): Init pattern to be replaced. + init_topo_order (list): Init topo sequence. + init_graph (DagGraph): Graph instance. + sub_graph_size (int): Min sub-graph size. + + Returns: + list, available path. + """ + # 1. Sort the pattern by frequency. + sorted_pattern = context.sort_with_beam(init_pattern) + # 2. Put pattern into queue. + queue = PriorityQueue() + for _, pattern_inst in sorted_pattern.items(): + queue.put( + SearchPath(pattern=pattern_inst, sequence=init_topo_order, + graph=init_graph, + sub_graph_size=sub_graph_size), + block=False + ) + + available_path = [] + while not queue.empty(): + # a. replace pattern in current topo order. + cur_path = queue.get(block=False) + cur_topo_order = cur_path.topo_order_aft_repl + # b. generate new pattern based on replaced topo order. + if _is_satisfied(cur_path): + available_path.append(cur_path) + continue + + if len(available_path) >= 8: + break + + for _, cur_pattern in cur_path.new_pattern.items(): + if cur_pattern.count < MINI_FREQUENCEY: + available_path.append(cur_path) + break + key = "/".join([cur_pattern.pattern, gen_hash_key(cur_topo_order)]) + # c. create new SearchPath. + new_path = SearchPath(pattern=cur_pattern, sequence=cur_topo_order, prev_path=cur_path, + sub_graph_size=sub_graph_size) + context.visited.add(key) + # d. put it into heap to sort. + queue.put(new_path, block=False) + + return available_path + + +def _sub_graph_matching(init_dag, beam_width=5, sub_graph_size=4): + """ + Sub-graph matching. + + Args: + init_dag (DagGraph): Graph instance. + beam_width (int): Beam width used to prune search path. + sub_graph_size (int): Mini sub-graph size to find. + + Returns: + SearchPath, found path. + """ + context.set_beam_width(beam_width) + + def _get_top_1(available_path: list): + if len(available_path) <= 1: + return available_path + available_path = sorted(available_path, key=lambda x: x.actual_v, reverse=True) + return available_path[0] if available_path else None + + topo_order = [node for _, (_, node) in enumerate(context.node_collection.items())] + context.set_sequence_length(len(topo_order)) + pattern = generate_pattern(topo_order, dag=init_dag, sub_graph_size=sub_graph_size) + found_path = _search(pattern, topo_order, init_graph=init_dag, + sub_graph_size=sub_graph_size) + return _get_top_1(found_path) + + +def _retrieve_scope_name(found_path): + """ + Retrieve scope name. + + Args: + found_path: Found path. + """ + module_dict = dict() + for module_path in found_path.recursion_path: + key, val = _retrieve_operators(module_path, module_dict) + module_dict[key] = val + if found_path.pattern: + key, val = _retrieve_operators(found_path, module_dict) + module_dict[key] = val + + topo_order_with_scope_name = [] + for node in found_path.topo_order_aft_repl: + if module_dict.get(node.op_type): + topo_order_with_scope_name += [f"Model/{item}" for item in _scope_name_deduplication( + node.op_type, module_dict[node.op_type])] + else: + topo_order_with_scope_name.append(f"Model/{node.op_type}") + return topo_order_with_scope_name + + +def _scope_name_deduplication(key, scope_names) -> list: + """ + Scope name deduplication. + + Args: + scope_names (list): Scope names. + + Returns: + list, renamed scope name. + """ + result = [] + if key not in module_name_mgr: + module_name_mgr[key] = 0 + for item in scope_names: + item = item.replace(key, f"{key}_{module_name_mgr.get(key)}") + result.append(item) + module_name_mgr[key] += 1 + return result + + +def _retrieve_operators(module_path, module_dict): + """ + Retrieve operators from path. + + Args: + module_path(SearchPath): module path. + module_dict(dict): module dictionary. + + Returns: + str: module_name, operators in module. + """ + global module_name_mgr + node_in_pattern = module_path.pattern.pattern.split('->') + node_list = [] + for node in node_in_pattern: + if module_dict.get(node): + node_list += module_dict[node] + else: + node_list.append(node) + key = module_path.pattern.module_name + val = [f"{key}/{node}" for node in node_list] + return key, val + + +def _build_connection(loader): + """ + Build dag graph. + + Args: + loader (OnnxDataLoader): Dataloader. + """ + context.set_init_node_collection(loader.nodes_dict) + # Output name is not same with node name + for node_name, node in loader.nodes_dict.items(): + context.precursor_table[node_name] = list(node.get_precursor_dict().keys()) + context.successor_table[node_name] = list(node.get_successor_dict().keys()) + + dag = DagGraph(nodes=context.node_collection.copy(), + precursor=context.precursor_table.copy(), + successor=context.successor_table.copy()) + return dag + + +def generate_scope_name(data_loader): + """ + Generate scope name according to computation graph. + + Args: + data_loader (OnnxDataLoader): Data loader instance. + + Returns: + list[str], generated scope name. + """ + init_dag = _build_connection(data_loader) + result = _sub_graph_matching(init_dag, beam_width=5, sub_graph_size=6) + topo_order_with_scope_name_list = _retrieve_scope_name(result) + return topo_order_with_scope_name_list diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py index daee5791..bf01ee0e 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Define ONNX related opertions.""" +"""Define ONNX related operations.""" import re import abc from collections import OrderedDict from mindinsight.mindconverter.common.log import logger as log -from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING,\ +from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT @@ -27,7 +27,7 @@ def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=None """ Convert Tensorflow model to ONNX model. - Note: shape overide is supported by tf2onnx but we + Note: shape override is supported by tf2onnx but we have not supported yet. Args: @@ -87,6 +87,7 @@ class OnnxTensor: raw_tensor (onnx.TensorProto): onnx.TensorProto instance. """ import onnx + def __init__(self, raw_tensor): self.raw_tensor = raw_tensor self.name = raw_tensor.name @@ -279,8 +280,9 @@ class OnnxDataLoader: self.inferred_model = onnx.shape_inference.infer_shapes(self.model) def _parse_value_info(self): # no input node & output node - """Parse onnx defined value_info class attribtues""" + """Parse onnx defined value_info class attributes.""" import onnx + def _parse_value_info_re(i): """ Parse the value_info by regular expression @@ -297,7 +299,7 @@ class OnnxDataLoader: i_type = group_match.group('type') i_dim_str = group_match.group('dim_str') - return (i_name, i_type, i_dim_str) + return i_name, i_type, i_dim_str if not self.inferred_model: return @@ -333,18 +335,17 @@ class OnnxDataLoader: This function has a prerequisite of the shape inference. """ for (node_name, (_, shape_str)) in self.value_info_dict.items(): - l = [] + lst = [] # split shape by 'x' shape_list = shape_str.split('x') # replace unknown shape by '-1' for s in shape_list: if 'unk' in s: s = '-1' - # convert str to int s = int(s) - l.append(s) - self.node_output_shape_dict[node_name] = l + lst.append(s) + self.node_output_shape_dict[node_name] = lst def get_node(self, node_name): """Get the OnnxNode instance by node name."""