|
|
|
@@ -35,6 +35,25 @@ def is_converted(operation: str): |
|
|
|
return operation and SEPARATOR_IN_ONNX_OP not in operation |
|
|
|
|
|
|
|
|
|
|
|
def _add_outputs_of_onnx_model(model, output_nodes: List[str]): |
|
|
|
""" |
|
|
|
Add output nodes of onnx model. |
|
|
|
|
|
|
|
Args: |
|
|
|
model (ModelProto): ONNX model. |
|
|
|
output_nodes (list[str]): Output nodes list. |
|
|
|
|
|
|
|
Returns: |
|
|
|
ModelProto, edited ONNX model. |
|
|
|
""" |
|
|
|
onnx = import_module("onnx") |
|
|
|
for opt_name in output_nodes: |
|
|
|
intermediate_layer_value_info = onnx.helper.ValueInfoProto() |
|
|
|
intermediate_layer_value_info.name = opt_name |
|
|
|
model.graph.output.append(intermediate_layer_value_info) |
|
|
|
return model |
|
|
|
|
|
|
|
|
|
|
|
def fetch_output_from_onnx_model(model, feed_dict: dict, output_nodes: List[str]): |
|
|
|
""" |
|
|
|
Fetch specific nodes output from onnx model. |
|
|
|
@@ -54,13 +73,10 @@ def fetch_output_from_onnx_model(model, feed_dict: dict, output_nodes: List[str] |
|
|
|
raise TypeError("`feed_dict` should be type of dict, and `output_nodes` " |
|
|
|
"should be type of List[str].") |
|
|
|
|
|
|
|
ort = import_module("onnxruntime") |
|
|
|
|
|
|
|
input_nodes = list(feed_dict.keys()) |
|
|
|
edit_model = _add_outputs_of_onnx_model(model, output_nodes) |
|
|
|
|
|
|
|
extractor = getattr(import_module("onnx.utils"), "Extractor")(model) |
|
|
|
extracted_model = extractor.extract_model(input_nodes, output_nodes) |
|
|
|
sess = ort.InferenceSession(path_or_bytes=bytes(extracted_model.SerializeToString())) |
|
|
|
ort = import_module("onnxruntime") |
|
|
|
sess = ort.InferenceSession(path_or_bytes=bytes(edit_model.SerializeToString())) |
|
|
|
fetched_res = sess.run(output_names=output_nodes, input_feed=feed_dict) |
|
|
|
run_result = dict() |
|
|
|
for idx, opt in enumerate(output_nodes): |
|
|
|
|