diff --git a/mindinsight/mindconverter/graph_based_converter/common/utils.py b/mindinsight/mindconverter/graph_based_converter/common/utils.py index 811da4e2..2be81a37 100644 --- a/mindinsight/mindconverter/graph_based_converter/common/utils.py +++ b/mindinsight/mindconverter/graph_based_converter/common/utils.py @@ -35,6 +35,25 @@ def is_converted(operation: str): return operation and SEPARATOR_IN_ONNX_OP not in operation +def _add_outputs_of_onnx_model(model, output_nodes: List[str]): + """ + Add output nodes of onnx model. + + Args: + model (ModelProto): ONNX model. + output_nodes (list[str]): Output nodes list. + + Returns: + ModelProto, edited ONNX model. + """ + onnx = import_module("onnx") + for opt_name in output_nodes: + intermediate_layer_value_info = onnx.helper.ValueInfoProto() + intermediate_layer_value_info.name = opt_name + model.graph.output.append(intermediate_layer_value_info) + return model + + def fetch_output_from_onnx_model(model, feed_dict: dict, output_nodes: List[str]): """ Fetch specific nodes output from onnx model. @@ -54,13 +73,10 @@ def fetch_output_from_onnx_model(model, feed_dict: dict, output_nodes: List[str] raise TypeError("`feed_dict` should be type of dict, and `output_nodes` " "should be type of List[str].") - ort = import_module("onnxruntime") - - input_nodes = list(feed_dict.keys()) + edit_model = _add_outputs_of_onnx_model(model, output_nodes) - extractor = getattr(import_module("onnx.utils"), "Extractor")(model) - extracted_model = extractor.extract_model(input_nodes, output_nodes) - sess = ort.InferenceSession(path_or_bytes=bytes(extracted_model.SerializeToString())) + ort = import_module("onnxruntime") + sess = ort.InferenceSession(path_or_bytes=bytes(edit_model.SerializeToString())) fetched_res = sess.run(output_names=output_nodes, input_feed=feed_dict) run_result = dict() for idx, opt in enumerate(output_nodes):