| @@ -81,7 +81,7 @@ def build_feed_dict(onnx_model, input_nodes: dict): | |||||
| for node in onnx_model.graph.input | for node in onnx_model.graph.input | ||||
| } | } | ||||
| feed_dict = { | feed_dict = { | ||||
| name: np.random.rand(*shape).astype(input_nodes_types[name.split(":")[0]]) | |||||
| name: np.random.rand(*shape).astype(input_nodes_types[name]) | |||||
| for name, shape in input_nodes.items() | for name, shape in input_nodes.items() | ||||
| } | } | ||||
| return feed_dict | return feed_dict | ||||
| @@ -83,7 +83,7 @@ def torch_installation_validation(func): | |||||
| type, inner function. | type, inner function. | ||||
| """ | """ | ||||
| def _f(graph_path: str, sample_shape: tuple, input_nodes: str, output_nodes: str, | |||||
| def _f(graph_path: str, input_nodes: dict, output_nodes: List[str], | |||||
| output_folder: str, report_folder: str = None): | output_folder: str, report_folder: str = None): | ||||
| # Check whether pytorch is installed. | # Check whether pytorch is installed. | ||||
| error_info = None | error_info = None | ||||
| @@ -119,7 +119,7 @@ def torch_installation_validation(func): | |||||
| _print_error(error) | _print_error(error) | ||||
| sys.exit(0) | sys.exit(0) | ||||
| func(graph_path=graph_path, sample_shape=sample_shape, | |||||
| func(graph_path=graph_path, | |||||
| input_nodes=input_nodes, output_nodes=output_nodes, | input_nodes=input_nodes, output_nodes=output_nodes, | ||||
| output_folder=output_folder, report_folder=report_folder) | output_folder=output_folder, report_folder=report_folder) | ||||
| @@ -265,11 +265,12 @@ def main_graph_base_converter(file_config): | |||||
| if not file_config.get("shape"): | if not file_config.get("shape"): | ||||
| raise ParamMissingError("Param missing, `--shape` is required when using graph mode.") | raise ParamMissingError("Param missing, `--shape` is required when using graph mode.") | ||||
| if graph_path.endswith("pth") and not file_config['input_nodes'] and \ | |||||
| file_config.get("shape") and len(file_config.get("shape")) == 1: | |||||
| if graph_path.endswith("pth") and not file_config.get("input_nodes", []) and \ | |||||
| file_config.get("shape") and len(file_config.get("shape", ())) == 1: | |||||
| file_config['input_nodes'] = ["input.1"] | file_config['input_nodes'] = ["input.1"] | ||||
| if len(file_config['shape']) != len(file_config['input_nodes']) != len(set(file_config['input_nodes'])): | |||||
| if len(file_config['shape']) != len(file_config.get("input_nodes", [])) != len( | |||||
| set(file_config.get("input_nodes", []))): | |||||
| raise BadParamError("`--shape` and `--input_nodes` must have the same length, " | raise BadParamError("`--shape` and `--input_nodes` must have the same length, " | ||||
| "and no redundant node in `--input_nodes`.") | "and no redundant node in `--input_nodes`.") | ||||
| @@ -17,12 +17,13 @@ | |||||
| __all__ = ["context", | __all__ = ["context", | ||||
| "gen_hash_key", | "gen_hash_key", | ||||
| "DagGraph", | "DagGraph", | ||||
| "MAX_OUT_DEGREE", | |||||
| "MAX_DEGREE", | |||||
| "cal_matching_score", | "cal_matching_score", | ||||
| "ACCEPTABLE_RESULT_COUNT", | "ACCEPTABLE_RESULT_COUNT", | ||||
| "MINI_FREQUENCY", | "MINI_FREQUENCY", | ||||
| "SATISFIED_SCORE", | "SATISFIED_SCORE", | ||||
| "MAX_ITERATION_DEPTH"] | |||||
| "MAX_ITERATION_DEPTH_OF_MULTI_IPT", | |||||
| "MAX_ITERATION_DEPTH_OF_SINGLE_IPT"] | |||||
| import math | import math | ||||
| import copy | import copy | ||||
| @@ -32,9 +33,10 @@ from typing import List | |||||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode | from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode | ||||
| MAX_OUT_DEGREE = 1 | |||||
| MAX_DEGREE = 1 | |||||
| MINI_FREQUENCY = 0.07 | MINI_FREQUENCY = 0.07 | ||||
| MAX_ITERATION_DEPTH = 16 | |||||
| MAX_ITERATION_DEPTH_OF_MULTI_IPT = 16 | |||||
| MAX_ITERATION_DEPTH_OF_SINGLE_IPT = 8 | |||||
| SATISFIED_SCORE = 0.74 | SATISFIED_SCORE = 0.74 | ||||
| ACCEPTABLE_RESULT_COUNT = 32 | ACCEPTABLE_RESULT_COUNT = 32 | ||||
| PTN_COVERAGE_THRESHOLD = 0.65 | PTN_COVERAGE_THRESHOLD = 0.65 | ||||
| @@ -127,6 +129,7 @@ class AlgorithmContext: | |||||
| precursor_table = {} | precursor_table = {} | ||||
| successor_table = {} | successor_table = {} | ||||
| outputs_table = {} | outputs_table = {} | ||||
| has_multi_inputs = False | |||||
| def set_init_node_collection(self, nd_col): | def set_init_node_collection(self, nd_col): | ||||
| """Init node_collection.""" | """Init node_collection.""" | ||||
| @@ -21,7 +21,6 @@ class Pattern: | |||||
| def __init__(self, pattern, pattern_length, in_degree, out_degree, ptn_items: list = None): | def __init__(self, pattern, pattern_length, in_degree, out_degree, ptn_items: list = None): | ||||
| self.pattern = pattern | self.pattern = pattern | ||||
| self.count = 0 | |||||
| self.start_index = [] | self.start_index = [] | ||||
| self.end_index = [] | self.end_index = [] | ||||
| self.module_name = None | self.module_name = None | ||||
| @@ -37,6 +36,11 @@ class Pattern: | |||||
| self.additional_score = 0 | self.additional_score = 0 | ||||
| self.known_module_name = None | self.known_module_name = None | ||||
| @property | |||||
| def count(self): | |||||
| """Count of the pattern.""" | |||||
| return len(self.start_index) | |||||
| def insert(self, idx, seq_len): | def insert(self, idx, seq_len): | ||||
| """ | """ | ||||
| Insert a new position. | Insert a new position. | ||||
| @@ -49,7 +53,6 @@ class Pattern: | |||||
| return | return | ||||
| self.start_index.append(idx) | self.start_index.append(idx) | ||||
| self.end_index.append(idx + seq_len) | self.end_index.append(idx + seq_len) | ||||
| self.count += 1 | |||||
| def __str__(self): | def __str__(self): | ||||
| """Override `str()` method.""" | """Override `str()` method.""" | ||||
| @@ -21,7 +21,7 @@ from typing import Dict, List, Callable, Union | |||||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import BUILT_IN_PATTERN, \ | from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import BUILT_IN_PATTERN, \ | ||||
| is_built_in_pattern | is_built_in_pattern | ||||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, gen_hash_key, DagGraph, \ | from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, gen_hash_key, DagGraph, \ | ||||
| MAX_OUT_DEGREE, cal_matching_score | |||||
| MAX_DEGREE, cal_matching_score | |||||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.known_module_name import BUILT_IN_MODULE_NAME | from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.known_module_name import BUILT_IN_MODULE_NAME | ||||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern, scope_name_mapping | from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern, scope_name_mapping | ||||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern_fuzzy_matching import \ | from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern_fuzzy_matching import \ | ||||
| @@ -390,7 +390,7 @@ def generate_pattern(topo_order: List[BaseNode], dag: DagGraph, | |||||
| dag=dag) | dag=dag) | ||||
| in_degree, out_degree, _, _ = _get_pattern_degree(found_sequence, dag) | in_degree, out_degree, _, _ = _get_pattern_degree(found_sequence, dag) | ||||
| if out_degree > MAX_OUT_DEGREE: | |||||
| if out_degree > MAX_DEGREE or (not context.has_multi_inputs and in_degree > MAX_DEGREE): | |||||
| cur_idx += 1 | cur_idx += 1 | ||||
| continue | continue | ||||
| @@ -419,6 +419,7 @@ def _post_process_overlap(patterns) -> Dict: | |||||
| patterns[name].start_index.pop(idx) | patterns[name].start_index.pop(idx) | ||||
| patterns[name].end_index.pop(idx) | patterns[name].end_index.pop(idx) | ||||
| continue | continue | ||||
| prev_end = patterns[name].end_index[idx] | |||||
| idx += 1 | idx += 1 | ||||
| return patterns | return patterns | ||||
| @@ -17,9 +17,9 @@ from queue import PriorityQueue | |||||
| from typing import Dict, List | from typing import Dict, List | ||||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, DagGraph, gen_hash_key, \ | from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, DagGraph, gen_hash_key, \ | ||||
| ACCEPTABLE_RESULT_COUNT | |||||
| ACCEPTABLE_RESULT_COUNT, MAX_ITERATION_DEPTH_OF_SINGLE_IPT | |||||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import MINI_FREQUENCY, \ | from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import MINI_FREQUENCY, \ | ||||
| MAX_ITERATION_DEPTH, SATISFIED_SCORE | |||||
| MAX_ITERATION_DEPTH_OF_MULTI_IPT, SATISFIED_SCORE | |||||
| from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext | from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext | ||||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode | from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode | ||||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.search_path import SearchPath, Pattern, \ | from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.search_path import SearchPath, Pattern, \ | ||||
| @@ -37,7 +37,9 @@ def _is_satisfied(path): | |||||
| Returns: | Returns: | ||||
| bool, True or False. | bool, True or False. | ||||
| """ | """ | ||||
| if len(path.recursion_path) > MAX_ITERATION_DEPTH: | |||||
| recursion_depth = MAX_ITERATION_DEPTH_OF_MULTI_IPT if context.has_multi_inputs \ | |||||
| else MAX_ITERATION_DEPTH_OF_SINGLE_IPT | |||||
| if len(path.recursion_path) > recursion_depth: | |||||
| return True | return True | ||||
| candidate_eval = any([is_pattern_satisfied(p, path) for p in path.new_pattern.values()]) | candidate_eval = any([is_pattern_satisfied(p, path) for p in path.new_pattern.values()]) | ||||
| if not path.new_pattern or not candidate_eval: | if not path.new_pattern or not candidate_eval: | ||||
| @@ -262,6 +264,8 @@ def _build_connection(loader): | |||||
| context.successor_table[node_name] = list(node.get_successor_dict().keys()) | context.successor_table[node_name] = list(node.get_successor_dict().keys()) | ||||
| context.outputs_table[node_name] = node.output_name_list | context.outputs_table[node_name] = node.output_name_list | ||||
| # Record the model inputs count, use it to control the search algorithm. | |||||
| context.has_multi_inputs = len(loader.input_nodes) > 1 | |||||
| dag = DagGraph(nodes=context.node_collection.copy(), | dag = DagGraph(nodes=context.node_collection.copy(), | ||||
| precursor=context.precursor_table.copy(), | precursor=context.precursor_table.copy(), | ||||
| successor=context.successor_table.copy()) | successor=context.successor_table.copy()) | ||||
| @@ -202,6 +202,8 @@ class OnnxGraph(Graph): | |||||
| else: | else: | ||||
| onnx_model = PyTorchGraphParser.parse(graph_path, **kwargs) | onnx_model = PyTorchGraphParser.parse(graph_path, **kwargs) | ||||
| onnx_inputs = [onnx_input.name for onnx_input in onnx_model.graph.input] | onnx_inputs = [onnx_input.name for onnx_input in onnx_model.graph.input] | ||||
| if input_nodes not in onnx_inputs: | |||||
| raise ModelNotSupportError(f"input nodes({input_nodes}) is not in model inputs ({onnx_inputs}).") | |||||
| for ipt in input_nodes: | |||||
| if ipt not in onnx_inputs: | |||||
| raise ModelNotSupportError(f"input nodes({input_nodes}) is not " | |||||
| f"in model inputs ({onnx_inputs}).") | |||||
| return onnx_model | return onnx_model | ||||
| @@ -93,7 +93,8 @@ class OnnxSimplify: | |||||
| self._constant_nodes = copy.deepcopy(const_nodes) | self._constant_nodes = copy.deepcopy(const_nodes) | ||||
| @ModelNotSupportError.check_except( | @ModelNotSupportError.check_except( | ||||
| "Error occurs in loading model, please check your model or runtime environment integrity." | |||||
| "Error occurs when loading model with given params, please check `--shape`, " | |||||
| "`--input_nodes`, `--output_nodes`, `--model_file` or runtime environment integrity." | |||||
| ) | ) | ||||
| def _onnx_infer(self, infer_inputs_shape): | def _onnx_infer(self, infer_inputs_shape): | ||||
| """ | """ | ||||
| @@ -27,7 +27,8 @@ class PyTorchGraphParser(GraphParser): | |||||
| @classmethod | @classmethod | ||||
| @ModelNotSupportError.check_except( | @ModelNotSupportError.check_except( | ||||
| "Error occurs in loading model, please check your model or runtime environment integrity." | |||||
| "Error occurs when loading model with given params, please check `--shape`, " | |||||
| "`--input_nodes`, `--output_nodes`, `--model_file` or runtime environment integrity." | |||||
| ) | ) | ||||
| def parse(cls, model_path: str, **kwargs): | def parse(cls, model_path: str, **kwargs): | ||||
| """ | """ | ||||
| @@ -47,8 +48,9 @@ class PyTorchGraphParser(GraphParser): | |||||
| raise error | raise error | ||||
| try: | try: | ||||
| sample_shape = list(kwargs.get("input_nodes").values())[0] | |||||
| onnx_model_sim = cls._convert_pytorch_graph_to_onnx( | onnx_model_sim = cls._convert_pytorch_graph_to_onnx( | ||||
| model_path, kwargs['sample_shape'], opset_version=11) | |||||
| model_path, sample_shape, opset_version=11) | |||||
| return onnx_model_sim | return onnx_model_sim | ||||
| except ModuleNotFoundError: | except ModuleNotFoundError: | ||||
| @@ -27,7 +27,8 @@ class TFGraphParser(GraphParser): | |||||
| @classmethod | @classmethod | ||||
| @ModelNotSupportError.check_except( | @ModelNotSupportError.check_except( | ||||
| "Error occurs in loading model, please check your model or runtime environment integrity." | |||||
| "Error occurs when loading model with given params, please check `--shape`, " | |||||
| "`--input_nodes`, `--output_nodes`, `--model_file` or runtime environment integrity." | |||||
| ) | ) | ||||
| def parse(cls, model_path: str, **kwargs): | def parse(cls, model_path: str, **kwargs): | ||||
| """ | """ | ||||