Browse Source

Fix bugs in conversion of cv model.

tags/v1.2.0-rc1
liuchongming 4 years ago
parent
commit
8531be1e68
10 changed files with 41 additions and 23 deletions
  1. +1
    -1
      mindinsight/mindconverter/graph_based_converter/common/utils.py
  2. +6
    -5
      mindinsight/mindconverter/graph_based_converter/framework.py
  3. +7
    -4
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py
  4. +5
    -2
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py
  5. +3
    -2
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py
  6. +7
    -3
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py
  7. +4
    -2
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py
  8. +2
    -1
      mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py
  9. +4
    -2
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py
  10. +2
    -1
      mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py

+ 1
- 1
mindinsight/mindconverter/graph_based_converter/common/utils.py View File

@@ -81,7 +81,7 @@ def build_feed_dict(onnx_model, input_nodes: dict):
for node in onnx_model.graph.input for node in onnx_model.graph.input
} }
feed_dict = { feed_dict = {
name: np.random.rand(*shape).astype(input_nodes_types[name.split(":")[0]])
name: np.random.rand(*shape).astype(input_nodes_types[name])
for name, shape in input_nodes.items() for name, shape in input_nodes.items()
} }
return feed_dict return feed_dict


+ 6
- 5
mindinsight/mindconverter/graph_based_converter/framework.py View File

@@ -83,7 +83,7 @@ def torch_installation_validation(func):
type, inner function. type, inner function.
""" """


def _f(graph_path: str, sample_shape: tuple, input_nodes: str, output_nodes: str,
def _f(graph_path: str, input_nodes: dict, output_nodes: List[str],
output_folder: str, report_folder: str = None): output_folder: str, report_folder: str = None):
# Check whether pytorch is installed. # Check whether pytorch is installed.
error_info = None error_info = None
@@ -119,7 +119,7 @@ def torch_installation_validation(func):
_print_error(error) _print_error(error)
sys.exit(0) sys.exit(0)


func(graph_path=graph_path, sample_shape=sample_shape,
func(graph_path=graph_path,
input_nodes=input_nodes, output_nodes=output_nodes, input_nodes=input_nodes, output_nodes=output_nodes,
output_folder=output_folder, report_folder=report_folder) output_folder=output_folder, report_folder=report_folder)


@@ -265,11 +265,12 @@ def main_graph_base_converter(file_config):
if not file_config.get("shape"): if not file_config.get("shape"):
raise ParamMissingError("Param missing, `--shape` is required when using graph mode.") raise ParamMissingError("Param missing, `--shape` is required when using graph mode.")


if graph_path.endswith("pth") and not file_config['input_nodes'] and \
file_config.get("shape") and len(file_config.get("shape")) == 1:
if graph_path.endswith("pth") and not file_config.get("input_nodes", []) and \
file_config.get("shape") and len(file_config.get("shape", ())) == 1:
file_config['input_nodes'] = ["input.1"] file_config['input_nodes'] = ["input.1"]


if len(file_config['shape']) != len(file_config['input_nodes']) != len(set(file_config['input_nodes'])):
if len(file_config['shape']) != len(file_config.get("input_nodes", [])) != len(
set(file_config.get("input_nodes", []))):
raise BadParamError("`--shape` and `--input_nodes` must have the same length, " raise BadParamError("`--shape` and `--input_nodes` must have the same length, "
"and no redundant node in `--input_nodes`.") "and no redundant node in `--input_nodes`.")




+ 7
- 4
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py View File

@@ -17,12 +17,13 @@
__all__ = ["context", __all__ = ["context",
"gen_hash_key", "gen_hash_key",
"DagGraph", "DagGraph",
"MAX_OUT_DEGREE",
"MAX_DEGREE",
"cal_matching_score", "cal_matching_score",
"ACCEPTABLE_RESULT_COUNT", "ACCEPTABLE_RESULT_COUNT",
"MINI_FREQUENCY", "MINI_FREQUENCY",
"SATISFIED_SCORE", "SATISFIED_SCORE",
"MAX_ITERATION_DEPTH"]
"MAX_ITERATION_DEPTH_OF_MULTI_IPT",
"MAX_ITERATION_DEPTH_OF_SINGLE_IPT"]


import math import math
import copy import copy
@@ -32,9 +33,10 @@ from typing import List


from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode


MAX_OUT_DEGREE = 1
MAX_DEGREE = 1
MINI_FREQUENCY = 0.07 MINI_FREQUENCY = 0.07
MAX_ITERATION_DEPTH = 16
MAX_ITERATION_DEPTH_OF_MULTI_IPT = 16
MAX_ITERATION_DEPTH_OF_SINGLE_IPT = 8
SATISFIED_SCORE = 0.74 SATISFIED_SCORE = 0.74
ACCEPTABLE_RESULT_COUNT = 32 ACCEPTABLE_RESULT_COUNT = 32
PTN_COVERAGE_THRESHOLD = 0.65 PTN_COVERAGE_THRESHOLD = 0.65
@@ -127,6 +129,7 @@ class AlgorithmContext:
precursor_table = {} precursor_table = {}
successor_table = {} successor_table = {}
outputs_table = {} outputs_table = {}
has_multi_inputs = False


def set_init_node_collection(self, nd_col): def set_init_node_collection(self, nd_col):
"""Init node_collection.""" """Init node_collection."""


+ 5
- 2
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py View File

@@ -21,7 +21,6 @@ class Pattern:


def __init__(self, pattern, pattern_length, in_degree, out_degree, ptn_items: list = None): def __init__(self, pattern, pattern_length, in_degree, out_degree, ptn_items: list = None):
self.pattern = pattern self.pattern = pattern
self.count = 0
self.start_index = [] self.start_index = []
self.end_index = [] self.end_index = []
self.module_name = None self.module_name = None
@@ -37,6 +36,11 @@ class Pattern:
self.additional_score = 0 self.additional_score = 0
self.known_module_name = None self.known_module_name = None


@property
def count(self):
"""Count of the pattern."""
return len(self.start_index)

def insert(self, idx, seq_len): def insert(self, idx, seq_len):
""" """
Insert a new position. Insert a new position.
@@ -49,7 +53,6 @@ class Pattern:
return return
self.start_index.append(idx) self.start_index.append(idx)
self.end_index.append(idx + seq_len) self.end_index.append(idx + seq_len)
self.count += 1


def __str__(self): def __str__(self):
"""Override `str()` method.""" """Override `str()` method."""


+ 3
- 2
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py View File

@@ -21,7 +21,7 @@ from typing import Dict, List, Callable, Union
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import BUILT_IN_PATTERN, \ from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import BUILT_IN_PATTERN, \
is_built_in_pattern is_built_in_pattern
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, gen_hash_key, DagGraph, \ from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, gen_hash_key, DagGraph, \
MAX_OUT_DEGREE, cal_matching_score
MAX_DEGREE, cal_matching_score
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.known_module_name import BUILT_IN_MODULE_NAME from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.known_module_name import BUILT_IN_MODULE_NAME
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern, scope_name_mapping from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern, scope_name_mapping
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern_fuzzy_matching import \ from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern_fuzzy_matching import \
@@ -390,7 +390,7 @@ def generate_pattern(topo_order: List[BaseNode], dag: DagGraph,
dag=dag) dag=dag)


in_degree, out_degree, _, _ = _get_pattern_degree(found_sequence, dag) in_degree, out_degree, _, _ = _get_pattern_degree(found_sequence, dag)
if out_degree > MAX_OUT_DEGREE:
if out_degree > MAX_DEGREE or (not context.has_multi_inputs and in_degree > MAX_DEGREE):
cur_idx += 1 cur_idx += 1
continue continue


@@ -419,6 +419,7 @@ def _post_process_overlap(patterns) -> Dict:
patterns[name].start_index.pop(idx) patterns[name].start_index.pop(idx)
patterns[name].end_index.pop(idx) patterns[name].end_index.pop(idx)
continue continue
prev_end = patterns[name].end_index[idx]
idx += 1 idx += 1
return patterns return patterns




+ 7
- 3
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py View File

@@ -17,9 +17,9 @@ from queue import PriorityQueue
from typing import Dict, List from typing import Dict, List


from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, DagGraph, gen_hash_key, \ from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, DagGraph, gen_hash_key, \
ACCEPTABLE_RESULT_COUNT
ACCEPTABLE_RESULT_COUNT, MAX_ITERATION_DEPTH_OF_SINGLE_IPT
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import MINI_FREQUENCY, \ from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import MINI_FREQUENCY, \
MAX_ITERATION_DEPTH, SATISFIED_SCORE
MAX_ITERATION_DEPTH_OF_MULTI_IPT, SATISFIED_SCORE
from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.search_path import SearchPath, Pattern, \ from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.search_path import SearchPath, Pattern, \
@@ -37,7 +37,9 @@ def _is_satisfied(path):
Returns: Returns:
bool, True or False. bool, True or False.
""" """
if len(path.recursion_path) > MAX_ITERATION_DEPTH:
recursion_depth = MAX_ITERATION_DEPTH_OF_MULTI_IPT if context.has_multi_inputs \
else MAX_ITERATION_DEPTH_OF_SINGLE_IPT
if len(path.recursion_path) > recursion_depth:
return True return True
candidate_eval = any([is_pattern_satisfied(p, path) for p in path.new_pattern.values()]) candidate_eval = any([is_pattern_satisfied(p, path) for p in path.new_pattern.values()])
if not path.new_pattern or not candidate_eval: if not path.new_pattern or not candidate_eval:
@@ -262,6 +264,8 @@ def _build_connection(loader):
context.successor_table[node_name] = list(node.get_successor_dict().keys()) context.successor_table[node_name] = list(node.get_successor_dict().keys())
context.outputs_table[node_name] = node.output_name_list context.outputs_table[node_name] = node.output_name_list


# Record the model inputs count, use it to control the search algorithm.
context.has_multi_inputs = len(loader.input_nodes) > 1
dag = DagGraph(nodes=context.node_collection.copy(), dag = DagGraph(nodes=context.node_collection.copy(),
precursor=context.precursor_table.copy(), precursor=context.precursor_table.copy(),
successor=context.successor_table.copy()) successor=context.successor_table.copy())


+ 4
- 2
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py View File

@@ -202,6 +202,8 @@ class OnnxGraph(Graph):
else: else:
onnx_model = PyTorchGraphParser.parse(graph_path, **kwargs) onnx_model = PyTorchGraphParser.parse(graph_path, **kwargs)
onnx_inputs = [onnx_input.name for onnx_input in onnx_model.graph.input] onnx_inputs = [onnx_input.name for onnx_input in onnx_model.graph.input]
if input_nodes not in onnx_inputs:
raise ModelNotSupportError(f"input nodes({input_nodes}) is not in model inputs ({onnx_inputs}).")
for ipt in input_nodes:
if ipt not in onnx_inputs:
raise ModelNotSupportError(f"input nodes({input_nodes}) is not "
f"in model inputs ({onnx_inputs}).")
return onnx_model return onnx_model

+ 2
- 1
mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py View File

@@ -93,7 +93,8 @@ class OnnxSimplify:
self._constant_nodes = copy.deepcopy(const_nodes) self._constant_nodes = copy.deepcopy(const_nodes)


@ModelNotSupportError.check_except( @ModelNotSupportError.check_except(
"Error occurs in loading model, please check your model or runtime environment integrity."
"Error occurs when loading model with given params, please check `--shape`, "
"`--input_nodes`, `--output_nodes`, `--model_file` or runtime environment integrity."
) )
def _onnx_infer(self, infer_inputs_shape): def _onnx_infer(self, infer_inputs_shape):
""" """


+ 4
- 2
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py View File

@@ -27,7 +27,8 @@ class PyTorchGraphParser(GraphParser):


@classmethod @classmethod
@ModelNotSupportError.check_except( @ModelNotSupportError.check_except(
"Error occurs in loading model, please check your model or runtime environment integrity."
"Error occurs when loading model with given params, please check `--shape`, "
"`--input_nodes`, `--output_nodes`, `--model_file` or runtime environment integrity."
) )
def parse(cls, model_path: str, **kwargs): def parse(cls, model_path: str, **kwargs):
""" """
@@ -47,8 +48,9 @@ class PyTorchGraphParser(GraphParser):
raise error raise error


try: try:
sample_shape = list(kwargs.get("input_nodes").values())[0]
onnx_model_sim = cls._convert_pytorch_graph_to_onnx( onnx_model_sim = cls._convert_pytorch_graph_to_onnx(
model_path, kwargs['sample_shape'], opset_version=11)
model_path, sample_shape, opset_version=11)
return onnx_model_sim return onnx_model_sim


except ModuleNotFoundError: except ModuleNotFoundError:


+ 2
- 1
mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py View File

@@ -27,7 +27,8 @@ class TFGraphParser(GraphParser):


@classmethod @classmethod
@ModelNotSupportError.check_except( @ModelNotSupportError.check_except(
"Error occurs in loading model, please check your model or runtime environment integrity."
"Error occurs when loading model with given params, please check `--shape`, "
"`--input_nodes`, `--output_nodes`, `--model_file` or runtime environment integrity."
) )
def parse(cls, model_path: str, **kwargs): def parse(cls, model_path: str, **kwargs):
""" """


Loading…
Cancel
Save