| @@ -35,6 +35,25 @@ def is_converted(operation: str): | |||||
| return operation and SEPARATOR_IN_ONNX_OP not in operation | return operation and SEPARATOR_IN_ONNX_OP not in operation | ||||
| def _add_outputs_of_onnx_model(model, output_nodes: List[str]): | |||||
| """ | |||||
| Add output nodes of onnx model. | |||||
| Args: | |||||
| model (ModelProto): ONNX model. | |||||
| output_nodes (list[str]): Output nodes list. | |||||
| Returns: | |||||
| ModelProto, edited ONNX model. | |||||
| """ | |||||
| onnx = import_module("onnx") | |||||
| for opt_name in output_nodes: | |||||
| intermediate_layer_value_info = onnx.helper.ValueInfoProto() | |||||
| intermediate_layer_value_info.name = opt_name | |||||
| model.graph.output.append(intermediate_layer_value_info) | |||||
| return model | |||||
| def fetch_output_from_onnx_model(model, feed_dict: dict, output_nodes: List[str]): | def fetch_output_from_onnx_model(model, feed_dict: dict, output_nodes: List[str]): | ||||
| """ | """ | ||||
| Fetch specific nodes output from onnx model. | Fetch specific nodes output from onnx model. | ||||
| @@ -54,13 +73,10 @@ def fetch_output_from_onnx_model(model, feed_dict: dict, output_nodes: List[str] | |||||
| raise TypeError("`feed_dict` should be type of dict, and `output_nodes` " | raise TypeError("`feed_dict` should be type of dict, and `output_nodes` " | ||||
| "should be type of List[str].") | "should be type of List[str].") | ||||
| ort = import_module("onnxruntime") | |||||
| input_nodes = list(feed_dict.keys()) | |||||
| edit_model = _add_outputs_of_onnx_model(model, output_nodes) | |||||
| extractor = getattr(import_module("onnx.utils"), "Extractor")(model) | |||||
| extracted_model = extractor.extract_model(input_nodes, output_nodes) | |||||
| sess = ort.InferenceSession(path_or_bytes=bytes(extracted_model.SerializeToString())) | |||||
| ort = import_module("onnxruntime") | |||||
| sess = ort.InferenceSession(path_or_bytes=bytes(edit_model.SerializeToString())) | |||||
| fetched_res = sess.run(output_names=output_nodes, input_feed=feed_dict) | fetched_res = sess.run(output_names=output_nodes, input_feed=feed_dict) | ||||
| run_result = dict() | run_result = dict() | ||||
| for idx, opt in enumerate(output_nodes): | for idx, opt in enumerate(output_nodes): | ||||