Browse Source

Optimize onnx model extractor performance with editing onnx graph directly.

tags/v1.1.0
liuchongming 5 years ago
parent
commit
0cefc0f558
1 changed files with 22 additions and 6 deletions
  1. +22
    -6
      mindinsight/mindconverter/graph_based_converter/common/utils.py

+ 22
- 6
mindinsight/mindconverter/graph_based_converter/common/utils.py View File

@@ -35,6 +35,25 @@ def is_converted(operation: str):
return operation and SEPARATOR_IN_ONNX_OP not in operation


def _add_outputs_of_onnx_model(model, output_nodes: List[str]):
"""
Add output nodes of onnx model.

Args:
model (ModelProto): ONNX model.
output_nodes (list[str]): Output nodes list.

Returns:
ModelProto, edited ONNX model.
"""
onnx = import_module("onnx")
for opt_name in output_nodes:
intermediate_layer_value_info = onnx.helper.ValueInfoProto()
intermediate_layer_value_info.name = opt_name
model.graph.output.append(intermediate_layer_value_info)
return model


def fetch_output_from_onnx_model(model, feed_dict: dict, output_nodes: List[str]):
"""
Fetch specific nodes output from onnx model.
@@ -54,13 +73,10 @@ def fetch_output_from_onnx_model(model, feed_dict: dict, output_nodes: List[str]
raise TypeError("`feed_dict` should be type of dict, and `output_nodes` "
"should be type of List[str].")

ort = import_module("onnxruntime")

input_nodes = list(feed_dict.keys())
edit_model = _add_outputs_of_onnx_model(model, output_nodes)

extractor = getattr(import_module("onnx.utils"), "Extractor")(model)
extracted_model = extractor.extract_model(input_nodes, output_nodes)
sess = ort.InferenceSession(path_or_bytes=bytes(extracted_model.SerializeToString()))
ort = import_module("onnxruntime")
sess = ort.InferenceSession(path_or_bytes=bytes(edit_model.SerializeToString()))
fetched_res = sess.run(output_names=output_nodes, input_feed=feed_dict)
run_result = dict()
for idx, opt in enumerate(output_nodes):


Loading…
Cancel
Save