Browse Source

Implement scope name generation

tags/v1.1.0
liuchongming 5 years ago
parent
commit
a9080fc14a
5 changed files with 816 additions and 9 deletions
  1. +2
    -0
      mindinsight/mindconverter/graph_based_converter/constant.py
  2. +3
    -0
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/__init__.py
  3. +574
    -0
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py
  4. +227
    -0
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py
  5. +10
    -9
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py

+ 2
- 0
mindinsight/mindconverter/graph_based_converter/constant.py View File

@@ -27,6 +27,8 @@ FIRST_LEVEL_INDENT = BLANK_SYM * 4
SECOND_LEVEL_INDENT = BLANK_SYM * 8
NEW_LINE = "\n"

MINI_FREQUENCEY = 4

ONNX_TYPE_INT = 2
ONNX_TYPE_INTS = 7
ONNX_TYPE_FLOAT = 1


+ 3
- 0
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/__init__.py View File

@@ -13,3 +13,6 @@
# limitations under the License.
# ==============================================================================
"""Searcher of scope name."""
from .searcher import generate_scope_name

__all__ = ["generate_scope_name"]

+ 574
- 0
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py View File

@@ -0,0 +1,574 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Declare search path related."""
import copy
import uuid
from typing import Dict, List, Callable, Union
from collections import OrderedDict
from .common import context, gen_hash_key, DagGraph
from ..third_party_graph.onnx_utils import OnnxNode, BaseNode

scope_name_mapping = {}
module_name_to_src = {}
global_idx = 0


class OptimizeRules:
    """Operator-type rules used to constrain pattern search."""
    # Operator types that must not be the first node of a pattern.
    CAN_NOT_BE_HEAD = {"Relu", "Add"}
    # Operator types that take multiple inputs.
    HAS_MULTI_IPTS = {"Add", "Concat"}
    # Operator types treated as activation functions.
    ACTIVATION = {"Relu", "Tanh"}


def _is_connected(parent, child, dag):
"""
Whether two node are connected.

Args:
parent (BaseNode): Parent node.
child (BaseNode): Child node.
dag (DagGraph): Graph instance.

Returns:
bool, True or False.
"""
return parent.name in dag.precursor_table.get(child.name)


def _is_activation(node_type):
    """
    Check whether the given operator type is an activation function.

    Args:
        node_type (str): Node operator type.

    Returns:
        bool, True or False.
    """
    activations = OptimizeRules.ACTIVATION
    return node_type in activations


def _is_valid_pattern(pattern, dag):
    """
    Whether a pattern is valid.

    A valid pattern is non-empty, contains more than one node, and does
    not start with an operator listed in `OptimizeRules.CAN_NOT_BE_HEAD`.

    Args:
        pattern (dict): Pattern dict.
        dag (DagGraph): Dag instance.

    Returns:
        bool, True or False.
    """
    if not pattern:
        return False
    head_name = next(iter(pattern))
    head_op = dag.node_collection[head_name].op_type
    if len(pattern) < 2:
        return False
    return head_op not in OptimizeRules.CAN_NOT_BE_HEAD


def generate_module_name():
    """Generate a unique module name such as ``Module0``, ``Module1``, ..."""
    global global_idx
    idx, global_idx = global_idx, global_idx + 1
    return f"Module{idx}"


def random_name(module_name):
    """Create a unique node name by appending an 8-hex-char random suffix."""
    suffix = uuid.uuid4().hex[:8]
    return f"{module_name}_{suffix}"


class MergedONNXNode(BaseNode):
    """
    Merged ONNX node standing for a matched sub-graph.

    Args:
        name (str): Generated unique node name.
        module_name (str): Module (pattern) name this node represents.
        ori_nodes (list): Original nodes folded into this merged node.
    """

    def __init__(self, name, module_name, ori_nodes):
        super(MergedONNXNode, self).__init__(name, module_name)
        # Keep the original nodes so the merged sequence can be recovered.
        self.nodes = ori_nodes

    def get_name(self):
        """Return the node name."""
        return self.name

    def get_op(self):
        # NOTE(review): relies on `op_type` being set by BaseNode
        # (presumably from `module_name`) — confirm in onnx_utils.
        return self.op_type


class Pattern:
    """
    A found sub-graph pattern and its occurrences in the sequence.

    Args:
        pattern (str): Operator types joined by "->".
        pattern_length (int): Number of operators in the pattern.
        in_degree (int): Inputs coming from outside the pattern.
        out_degree (int): Outputs going outside the pattern.
    """

    def __init__(self, pattern, pattern_length, in_degree, out_degree):
        self.pattern = pattern
        self.ptn_length = pattern_length
        self.ptn_items = pattern.split("->")
        self.in_degree = in_degree
        self.out_degree = out_degree
        # Occurrence bookkeeping, filled by `insert`.
        self.count = 0
        self.start_index = []
        self.end_index = []
        self.module_name = None

    def insert(self, idx, seq_len):
        """
        Record a new occurrence of the pattern.

        Args:
            idx (int): Start index of the occurrence.
            seq_len (int): Pattern length.
        """
        if idx not in self.start_index:
            self.start_index.append(idx)
            self.end_index.append(idx + seq_len)
            self.count += 1

    def __str__(self):
        """Override `str()` method."""
        return repr(self)

    def __repr__(self):
        """Override `repr()` method."""
        module = scope_name_mapping.get(self.pattern, 'Not init')
        return f"Ptn: {self.pattern}[{module}], count={self.count}"


def _find_idx(sequence: List[BaseNode], target: str, equal_func: Callable,
start_idx: int = 0, end_idx=None) -> int:
"""
Find matched result according to `equal_func` in [`start_idx`, `end_idx`).

Args:
sequence (list): Raw topo sequence.
target (str): Target node name.
equal_func (Callable): Function to judge whether matched.
start_idx (int): Start index.
end_idx (int): End index.

Returns:
int, index.
"""
not_found = -1
if not sequence:
msg = f"Empty sequence is not supported."
raise ValueError(msg)

end_idx = end_idx if end_idx else len(sequence)
for i in range(start_idx, end_idx):
if equal_func(sequence[i], target):
return i
return not_found


def _match(x: OnnxNode, y: str):
"""
Match func.

Args:
x (OnnxNode): Node instance.
y (int): To be compared value.
"""
return x.name == y


def _get_pattern_degree(sequence: Union[OrderedDict, dict, list],
dag: DagGraph):
"""
Get degree of the pattern.

Args:
sequence (Union[OrderedDict, dict, list]): Pattern to calculate.
dag (DagGraph): Graph instance.

Returns:
tuple[int, int], in degree and out degree.
"""
in_node = set()
out_node = set()
node_in_seq = set()
items = sequence if isinstance(sequence, list) else sequence.keys()
for _, item in enumerate(items):
item = item.name if not isinstance(item, str) else item
for ipt in dag.precursor_table[item]:
in_node.add(ipt)
for opt in dag.successor_table[item]:
out_node.add(opt)
node_in_seq.add(item)
in_degree = len(in_node - node_in_seq)
out_degree = len(out_node - node_in_seq)
return in_degree, out_degree


def _find_pattern_tail(sequence: List[BaseNode], pattern: Dict[str, str], tail_idx: int, dag: DagGraph):
    """
    Supply tail of the pattern sequence.

    Scan the pattern's nodes; any node (except the last) with a successor
    outside the pattern pushes the tail to that successor's position.

    Args:
        sequence (list): Raw sequence.
        pattern (dict[str, str]): Pattern to be supplied.
        tail_idx (int): The position where pattern ends.
        dag (DagGraph): Graph object.

    Returns:
        int, tail index in the sequence, or -1 when nothing to append.
    """
    last_pos = len(pattern) - 1
    tail_append_idx = -1
    for pos, node_name in enumerate(pattern):
        successors = dag.successor_table[node_name]
        if len(successors) <= 1:
            continue
        # A multi-successor node at the very end of the pattern is ignored.
        if pos == last_pos:
            continue
        for successor_name in successors:
            if successor_name in pattern:
                continue
            found = _find_idx(sequence=sequence, target=successor_name,
                              equal_func=_match, start_idx=tail_idx)
            tail_append_idx = max(tail_append_idx, found)
    return tail_append_idx


def _supply_sequence(sequence: List[BaseNode], pattern: Dict[str, str], offset: int, dag: DagGraph):
    """
    Supply sequence from front to end.

    Repeatedly extends the pattern with nodes up to the farthest external
    successor found by `_find_pattern_tail`, then optionally absorbs one
    trailing activation node.

    Args:
        sequence (list): Raw sequence.
        pattern (dict[str, str]): Pattern to be supplied.
        offset (int): The position where pattern ends.
        dag (DagGraph): Graph object.

    Returns:
        tuple[dict, tuple[int, int]], found pattern and corresponding position.
    """
    # NOTE: `found_sequence` aliases `pattern`; the caller's dict is
    # extended in place.
    found_sequence = pattern
    tail_idx = offset
    ori_seq_len = len(found_sequence)
    while True:
        tail_idx = _find_pattern_tail(sequence=sequence, pattern=found_sequence,
                                      tail_idx=tail_idx, dag=dag)
        if tail_idx == -1:
            break
        for j in range(offset + 1, tail_idx + 1):
            # If tail_append_idx==-1, this loop will not be executed.
            node_obj = dag.node_collection[sequence[j].name]
            found_sequence[node_obj.name] = node_obj.op_type

    # Pattern already reaches the end of the sequence; no trailing
    # activation can be absorbed.
    if offset + len(found_sequence) - ori_seq_len + 1 >= len(sequence):
        return found_sequence, (offset - ori_seq_len + 1,
                                offset + len(found_sequence) - ori_seq_len)

    # If the next node after `found_sequence` is an activation and
    # has only one edge from `found_sequence`, then link it
    # to `found_sequence`.
    last_node = sequence[offset + len(found_sequence) - ori_seq_len]
    next_node = sequence[offset + len(found_sequence) - ori_seq_len + 1]
    if _is_activation(next_node.op_type) and _is_connected(last_node, next_node, dag):
        found_sequence[next_node.name] = next_node.op_type

    return found_sequence, (offset - ori_seq_len + 1,
                            offset + len(found_sequence) - ori_seq_len)


def generate_pattern(topo_order: List[BaseNode], dag: DagGraph,
                     sub_graph_size: int = 2) -> Dict[str, Pattern]:
    """
    Use self-adaptive sliding window to find sub-graphs.

    Args:
        dag (DagGraph): Graph object.
        topo_order (list): Topo sequence.
        sub_graph_size (int): Mini sub-graph size.

    Returns:
        dict[str, Pattern], found pattern.
    """
    pattern = {}
    cur_idx, total_len = 0, len(topo_order)
    while cur_idx < total_len:
        # Need a full window of `sub_graph_size` nodes before scanning.
        if cur_idx < sub_graph_size - 1:
            cur_idx += 1
            continue
        cur_node = topo_order[cur_idx]
        init_pattern = OrderedDict()
        prev_node = None
        jump_step = 0
        # Collect the window ending just before `cur_idx`, walking forward.
        for j in range(sub_graph_size - 1, 0, -1):
            node_obj = dag.node_collection.get(topo_order[cur_idx - j].name)
            # If current node is not child of `prev_node`,
            # then break it. The topo order got from ONNX has a
            # good feature, nodes belonging to one scope would be together.
            # Thus, we can do linear scan on topo order.
            if j != sub_graph_size - 1 and prev_node not in dag.precursor_table.get(topo_order[cur_idx - j].name):
                jump_step = j + 1
                break
            init_pattern[node_obj.name] = node_obj.op_type
            prev_node = topo_order[cur_idx - j].name

        if jump_step == 0:
            # Whole window is chained; include the current node as well.
            init_pattern[cur_node.name] = cur_node.op_type

        if not _is_valid_pattern(init_pattern, dag):
            # in OptimizeRules.CAN_NOT_BE_HEAD:
            # If pattern starts with "ReLU", then pass it.
            cur_idx += 1
            continue

        # Grow the window to cover external successors and a trailing
        # activation, if any (may mutate `init_pattern` in place).
        found_sequence, _ = _supply_sequence(sequence=topo_order,
                                             pattern=init_pattern,
                                             offset=cur_idx - jump_step,
                                             dag=dag)

        in_degree, out_degree = _get_pattern_degree(found_sequence, dag)
        # Same op chain with different connectivity is a different pattern.
        ptn = '->'.join(found_sequence.values())
        ptn_key = f"{ptn}[{in_degree}, {out_degree}]"
        if ptn_key not in pattern:
            pattern[ptn_key] = Pattern(ptn, len(found_sequence),
                                       in_degree=in_degree,
                                       out_degree=out_degree)

        pattern[ptn_key].insert(cur_idx - sub_graph_size + 1, len(found_sequence))
        cur_idx = cur_idx + 1

    return pattern


class SearchPath:
    """
    Use SearchPath to store the search path.

    A SearchPath folds one `pattern` into the given topological order,
    records the lineage of previous paths, and scores the result so paths
    can be ranked in a priority queue.

    Args:
        pattern (Pattern): Pattern instance to be matched.
        sequence (list): A list of nodes in topological order.
        prev_path (SearchPath): Previous search path instance.
        graph (DagGraph): Graph instance.
        sub_graph_size (int): Mini sub-graph size to search.

    """

    def __init__(self, pattern, sequence: List[BaseNode], prev_path=None,
                 graph=None, sub_graph_size: int = 2):
        self.pattern = pattern
        # Copy the graph so edge-table rewrites on this path do not affect
        # the parent path. NOTE(review): copy.copy is shallow, so the
        # precursor/successor dicts may still be shared with the parent
        # unless DagGraph defines __copy__ — confirm.
        self.graph = copy.copy(prev_path.graph) if prev_path is not None \
            else copy.copy(graph)
        # Lineage of paths that led to this one.
        self.recursion_path = prev_path.recursion_path[:] \
            if prev_path is not None else list()
        if prev_path is not None:
            self.recursion_path.append(prev_path)

        self.topo_order_bef_repl = sequence
        # Fold `pattern` into the sequence and remember how to recover it.
        self.topo_order_aft_repl, self.inverted_index = self._create_new_order()
        self.node_collection = dict()
        self.hash_of_aft_repl = gen_hash_key(self.topo_order_aft_repl)
        # Cache pattern generation per hashed sequence to avoid rescanning
        # identical topological orders across paths.
        if self.hash_of_aft_repl not in context.found_pattern:
            context.found_pattern[self.hash_of_aft_repl] = context.sort_with_beam(
                generate_pattern(self.topo_order_aft_repl, dag=self.graph, sub_graph_size=sub_graph_size)
            )

        self.new_pattern = context.found_pattern[self.hash_of_aft_repl]
        self.heuristic_v = self._heuristic_val()
        self.actual_v = self._actual_val()

    def _create_new_order(self):
        """
        Replace sequence with pattern.

        Returns:
            tuple[list, dict], topo sequence and inverted index
            to recover the sequence.
        """
        global scope_name_mapping
        # Assign (or reuse) a module name for this pattern.
        if self.pattern.pattern not in scope_name_mapping:
            module_name = generate_module_name()
            scope_name_mapping[self.pattern.pattern] = module_name
            module_name_to_src[module_name] = self.pattern.pattern
        else:
            module_name = scope_name_mapping[self.pattern.pattern]
        self.pattern.module_name = module_name
        topo_order, inverted_index = self.replace_sub_graph_completely(self.pattern, self.topo_order_bef_repl)
        return topo_order, inverted_index

    def replace_sub_graph_completely(self, pattern: Pattern,
                                     original_topo_order: List[BaseNode]):
        """
        Replace sequence with pattern.

        Match pattern from scratch.

        Notes:
            Bugs here, replace the sub-graph in sequence will have multi-path.
            However, we use greedy-strategy, replace the pattern that appear at front,
            and only keep one path.

        Args:
            pattern (Pattern): Pattern to be used.
            original_topo_order (list): Sequence.

        Returns:
            tuple[list, dict], topo sequence and inverted index
            to recover the sequence.
        """
        inverted_index = {}
        topo_order = []
        path_length = 0
        index = 0
        pattern_len = pattern.ptn_length
        ori_seq_len = len(original_topo_order)
        while index < ori_seq_len:
            # Fast reject: head op differs or not enough nodes left.
            if original_topo_order[index].op_type != pattern.ptn_items[0] or \
                    ori_seq_len - index < pattern_len:
                topo_order.append(original_topo_order[index])
                index += 1
                path_length += 1
                continue

            visited_node, j = [], 0
            matched = True
            for j in range(pattern_len):
                visited_node.append(original_topo_order[index + j])
                if original_topo_order[index + j].op_type != pattern.ptn_items[j]:
                    # Mismatch: keep the scanned nodes unchanged and move on.
                    topo_order.extend(visited_node)
                    index += j + 1
                    path_length += j + 1
                    matched = False
                    break

            if not matched:
                continue

            # Op types match; connectivity (degree) must match too,
            # otherwise this occurrence is wired differently.
            in_degree, out_degree = _get_pattern_degree(visited_node, self.graph)
            if in_degree != pattern.in_degree or out_degree != pattern.out_degree:
                topo_order.extend(visited_node)
                index += j + 1
                path_length += j + 1
                continue

            # Record which original positions this merged node covers.
            inverted_index[path_length] = [j + index for j in range(pattern_len)]
            new_node = MergedONNXNode(name=random_name(pattern.module_name),
                                      module_name=pattern.module_name,
                                      ori_nodes=visited_node[:])
            self._reconnect(new_node)
            self.graph.node_collection[new_node.name] = new_node
            topo_order.append(new_node)
            path_length += 1
            index += pattern_len

        return topo_order, inverted_index

    def _reconnect(self, merged_node):
        """
        Re-connect merged_node with its precursor and successor nodes.

        Args:
            merged_node (MergedONNXNode): Merged node.
        """
        in_node, out_node = [], []
        node_in_seq = [item.name for item in merged_node.nodes]
        # Collect edges that cross the merged sub-graph boundary.
        for _, item in enumerate(merged_node.nodes):
            item = item.name if not isinstance(item, str) else item
            for ipt in self.graph.precursor_table[item]:
                if ipt not in node_in_seq:
                    in_node.append(ipt)
            for opt in self.graph.successor_table[item]:
                if opt not in node_in_seq:
                    out_node.append(opt)
        self.graph.precursor_table[merged_node.name] = in_node
        self._relink_precursor(merged_node, in_node, node_in_seq)
        self._relink_successor(merged_node, out_node, node_in_seq)

    def _relink_precursor(self, merged_node, in_node, node_in_seq):
        """
        Relink node to precursor.

        Args:
            merged_node (MergedONNXNode): Merged node instance.
            in_node (list): In nodes list.
            node_in_seq (list): Node name in sequence.
        """
        # Add current node to precursor table.
        self.graph.precursor_table[merged_node.name] = in_node
        # Link the precursor to current node.
        for p_nd in in_node:
            scsr_nodes = self.graph.successor_table[p_nd].copy()
            for i, nd_name in enumerate(scsr_nodes):
                if nd_name in node_in_seq:
                    scsr_nodes[i] = merged_node.name
            self.graph.successor_table[p_nd] = scsr_nodes

    def _relink_successor(self, merged_node, out_node, node_in_seq):
        """
        Relink node to successor.

        Args:
            merged_node (MergedONNXNode): Merged node.
            out_node (list): Out nodes.
            node_in_seq (list): Node name in sequence.
        """
        # Add current node to successor table.
        self.graph.successor_table[merged_node.name] = out_node
        # Link successor to current node.
        for s_nd in out_node:
            p_nodes = self.graph.precursor_table[s_nd].copy()
            for i, nd_name in enumerate(p_nodes):
                if nd_name in node_in_seq:
                    p_nodes[i] = merged_node.name
            self.graph.precursor_table[s_nd] = p_nodes

    def evaluate_score(self):
        """Evaluate path score: achieved compression plus heuristic estimate."""
        return self.actual_v + self.heuristic_v

    def _heuristic_val(self):
        """Calculate heuristic score of the path."""
        res = []
        # Average expected coverage of the newly found patterns.
        for _, ptn in enumerate(self.new_pattern.items()):
            res.append(ptn[1].count * ptn[1].ptn_length / len(self.topo_order_aft_repl))
        # NOTE(review): raises ZeroDivisionError when `new_pattern` is
        # empty — confirm callers guarantee at least one pattern.
        return sum(res) / len(res)

    def _actual_val(self):
        """Calculate ground-truth score: fraction of the sequence folded so far."""
        return (context.get_sequence_length() - len(self.topo_order_aft_repl)) / context.get_sequence_length()

    def __lt__(self, other):
        """Override `<` operator (inverted so PriorityQueue pops the highest score first)."""
        return self.evaluate_score() > other.evaluate_score()

    def __eq__(self, other):
        """Override `==` operator."""
        return self.evaluate_score() == other.evaluate_score()

    def __str__(self):
        """Override `str()` method."""
        return self.__repr__()

    def __repr__(self):
        """Override `repr()` method."""

        # NOTE(review): `_dfs` is defined but never called — dead code,
        # presumably intended to expand nested module names in the repr.
        def _dfs(module_name):
            chain = []
            src = module_name_to_src[module_name]
            for sub_module in src.split("->"):
                if sub_module in module_name_to_src:
                    chain.append(_dfs(sub_module))
                else:
                    chain.append(sub_module)
            return "->".join(chain)

        repr_str = f"{self.pattern.pattern}[{self.pattern.module_name}], H: {self.heuristic_v}, G: {self.actual_v}"

        return repr_str

+ 227
- 0
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py View File

@@ -0,0 +1,227 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Definition of search entry."""
from queue import PriorityQueue
from typing import Dict, List

from .common import context, DagGraph, gen_hash_key
from ..constant import MINI_FREQUENCEY
from ..third_party_graph.onnx_utils import BaseNode
from .search_path import SearchPath, Pattern, generate_pattern

# Hold module name of current graph.
module_name_mgr = dict()


def _is_satisfied(path):
"""
Whether current path is satisfied.

Args:
path (SearchPath): A SearchPath instance.

Returns:
bool, True or False.
"""
if len(path.recursion_path) == 2:
return True
flag = [cur_pattern.count for _, cur_pattern in path.new_pattern.items()]
return float(sum(flag)) / len(flag) == 1 or path.actual_v >= 0.80


def _search(init_pattern: Dict[str, Pattern], init_topo_order: List[BaseNode],
            init_graph, sub_graph_size: int = 2) -> List[SearchPath]:
    """
    Search base on merged graph, until all frequency is 1.

    Args:
        init_pattern (dict): Init pattern to be replaced.
        init_topo_order (list): Init topo sequence.
        init_graph (DagGraph): Graph instance.
        sub_graph_size (int): Min sub-graph size.

    Returns:
        list, available path.
    """
    # 1. Sort the pattern by frequency.
    sorted_pattern = context.sort_with_beam(init_pattern)
    # 2. Put pattern into queue.
    queue = PriorityQueue()
    for _, pattern_inst in sorted_pattern.items():
        queue.put(
            SearchPath(pattern=pattern_inst, sequence=init_topo_order,
                       graph=init_graph,
                       sub_graph_size=sub_graph_size),
            block=False
        )

    available_path = []
    while not queue.empty():
        # a. replace pattern in current topo order.
        cur_path = queue.get(block=False)
        cur_topo_order = cur_path.topo_order_aft_repl
        # b. generate new pattern based on replaced topo order.
        if _is_satisfied(cur_path):
            available_path.append(cur_path)
            continue

        # Cap the number of candidate paths to bound the search.
        if len(available_path) >= 8:
            break

        for _, cur_pattern in cur_path.new_pattern.items():
            # A pattern below the frequency threshold ends this path.
            if cur_pattern.count < MINI_FREQUENCEY:
                available_path.append(cur_path)
                break
            key = "/".join([cur_pattern.pattern, gen_hash_key(cur_topo_order)])
            # c. create new SearchPath.
            new_path = SearchPath(pattern=cur_pattern, sequence=cur_topo_order, prev_path=cur_path,
                                  sub_graph_size=sub_graph_size)
            # NOTE(review): `key` is added to `context.visited` but never
            # checked, so duplicate states are not actually pruned —
            # confirm whether `if key in context.visited: continue`
            # was intended before creating `new_path`.
            context.visited.add(key)
            # d. put it into heap to sort.
            queue.put(new_path, block=False)

    return available_path


def _sub_graph_matching(init_dag, beam_width=5, sub_graph_size=4):
    """
    Sub-graph matching.

    Args:
        init_dag (DagGraph): Graph instance.
        beam_width (int): Beam width used to prune search path.
        sub_graph_size (int): Mini sub-graph size to find.

    Returns:
        SearchPath, found path, or None when no path was found.
    """
    context.set_beam_width(beam_width)

    def _get_top_1(available_path: list):
        # Return the single best path (highest actual_v) or None.
        # The original returned the *list* itself when it held <= 1
        # element, which broke downstream attribute access on the result.
        if not available_path:
            return None
        return max(available_path, key=lambda x: x.actual_v)

    topo_order = [node for _, (_, node) in enumerate(context.node_collection.items())]
    context.set_sequence_length(len(topo_order))
    pattern = generate_pattern(topo_order, dag=init_dag, sub_graph_size=sub_graph_size)
    found_path = _search(pattern, topo_order, init_graph=init_dag,
                         sub_graph_size=sub_graph_size)
    return _get_top_1(found_path)


def _retrieve_scope_name(found_path):
    """
    Retrieve scope names for every node in the final topo order.

    Args:
        found_path (SearchPath): Found path.

    Returns:
        list, scope names in topological order.
    """
    # Expand every module along the recursion path into its operator list.
    module_dict = dict()
    for module_path in found_path.recursion_path:
        name, operators = _retrieve_operators(module_path, module_dict)
        module_dict[name] = operators
    if found_path.pattern:
        name, operators = _retrieve_operators(found_path, module_dict)
        module_dict[name] = operators

    topo_order_with_scope_name = []
    for node in found_path.topo_order_aft_repl:
        scoped = module_dict.get(node.op_type)
        if scoped:
            renamed = _scope_name_deduplication(node.op_type, scoped)
            topo_order_with_scope_name.extend(f"Model/{item}" for item in renamed)
        else:
            topo_order_with_scope_name.append(f"Model/{node.op_type}")
    return topo_order_with_scope_name


def _scope_name_deduplication(key, scope_names) -> list:
    """
    Scope name deduplication.

    Rename `key` inside each scope name with a per-module sequence number,
    advancing the module's counter once per call.

    Args:
        key (str): Module name to deduplicate.
        scope_names (list): Scope names.

    Returns:
        list, renamed scope name.
    """
    seq_no = module_name_mgr.setdefault(key, 0)
    renamed = [name.replace(key, f"{key}_{seq_no}") for name in scope_names]
    module_name_mgr[key] += 1
    return renamed


def _retrieve_operators(module_path, module_dict):
"""
Retrieve operators from path.

Args:
module_path(SearchPath): module path.
module_dict(dict): module dictionary.

Returns:
str: module_name, operators in module.
"""
global module_name_mgr
node_in_pattern = module_path.pattern.pattern.split('->')
node_list = []
for node in node_in_pattern:
if module_dict.get(node):
node_list += module_dict[node]
else:
node_list.append(node)
key = module_path.pattern.module_name
val = [f"{key}/{node}" for node in node_list]
return key, val


def _build_connection(loader):
    """
    Build dag graph from the data loader.

    Args:
        loader (OnnxDataLoader): Dataloader.

    Returns:
        DagGraph, graph with node collection and edge tables.
    """
    context.set_init_node_collection(loader.nodes_dict)
    # Output name is not same with node name.
    for name, node in loader.nodes_dict.items():
        context.precursor_table[name] = list(node.get_precursor_dict())
        context.successor_table[name] = list(node.get_successor_dict())

    return DagGraph(nodes=context.node_collection.copy(),
                    precursor=context.precursor_table.copy(),
                    successor=context.successor_table.copy())


def generate_scope_name(data_loader):
    """
    Generate scope name according to computation graph.

    Args:
        data_loader (OnnxDataLoader): Data loader instance.

    Returns:
        list[str], generated scope name.
    """
    dag = _build_connection(data_loader)
    best_path = _sub_graph_matching(dag, beam_width=5, sub_graph_size=6)
    return _retrieve_scope_name(best_path)

+ 10
- 9
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py View File

@@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define ONNX related opertions."""
"""Define ONNX related operations."""
import re
import abc
from collections import OrderedDict

from mindinsight.mindconverter.common.log import logger as log

from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING,\
from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \
ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT


@@ -27,7 +27,7 @@ def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=None
"""
Convert Tensorflow model to ONNX model.

Note: shape overide is supported by tf2onnx but we
Note: shape override is supported by tf2onnx but we
have not supported yet.

Args:
@@ -87,6 +87,7 @@ class OnnxTensor:
raw_tensor (onnx.TensorProto): onnx.TensorProto instance.
"""
import onnx

def __init__(self, raw_tensor):
self.raw_tensor = raw_tensor
self.name = raw_tensor.name
@@ -279,8 +280,9 @@ class OnnxDataLoader:
self.inferred_model = onnx.shape_inference.infer_shapes(self.model)

def _parse_value_info(self): # no input node & output node
"""Parse onnx defined value_info class attribtues"""
"""Parse onnx defined value_info class attributes."""
import onnx

def _parse_value_info_re(i):
"""
Parse the value_info by regular expression
@@ -297,7 +299,7 @@ class OnnxDataLoader:
i_type = group_match.group('type')
i_dim_str = group_match.group('dim_str')

return (i_name, i_type, i_dim_str)
return i_name, i_type, i_dim_str

if not self.inferred_model:
return
@@ -333,18 +335,17 @@ class OnnxDataLoader:
This function has a prerequisite of the shape inference.
"""
for (node_name, (_, shape_str)) in self.value_info_dict.items():
l = []
lst = []
# split shape by 'x'
shape_list = shape_str.split('x')
# replace unknown shape by '-1'
for s in shape_list:
if 'unk' in s:
s = '-1'

# convert str to int
s = int(s)
l.append(s)
self.node_output_shape_dict[node_name] = l
lst.append(s)
self.node_output_shape_dict[node_name] = lst

def get_node(self, node_name):
"""Get the OnnxNode instance by node name."""


Loading…
Cancel
Save