| @@ -14,6 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include <unordered_set> | |||||
| #include "if_subgraph_adapter.h" | #include "if_subgraph_adapter.h" | ||||
| #include "subgraph_adapter_factory.h" | #include "subgraph_adapter_factory.h" | ||||
| #include "common/util.h" | #include "common/util.h" | ||||
| @@ -95,8 +96,8 @@ domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( | |||||
| domi::Status IfSubgraphAdapter::GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, | domi::Status IfSubgraphAdapter::GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, | ||||
| std::set<std::string> &all_inputs) const { | std::set<std::string> &all_inputs) const { | ||||
| std::set<std::string> graph_inputs; | |||||
| std::set<std::string> graph_outputs; | |||||
| std::unordered_set<std::string> graph_inputs; | |||||
| std::unordered_set<std::string> graph_outputs; | |||||
| for (int i = 0; i < onnx_graph.node_size(); i++) { | for (int i = 0; i < onnx_graph.node_size(); i++) { | ||||
| ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i); | ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i); | ||||
| for (int j = 0; j < node_proto->input_size(); j++) { | for (int j = 0; j < node_proto->input_size(); j++) { | ||||
| @@ -106,10 +107,12 @@ domi::Status IfSubgraphAdapter::GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx | |||||
| graph_outputs.emplace(node_proto->output(j)); | graph_outputs.emplace(node_proto->output(j)); | ||||
| } | } | ||||
| } | } | ||||
| std::unordered_set<std::string> graph_initializer_tensors; | |||||
| for (int32_t i = 0; i < onnx_graph.initializer_size(); i++) { | |||||
| graph_initializer_tensors.emplace(onnx_graph.initializer(i).name()); | |||||
| } | |||||
| for (const auto &input : graph_inputs) { | for (const auto &input : graph_inputs) { | ||||
| std::set<std::string>::const_iterator out_iter = graph_outputs.find(input); | |||||
| if (out_iter == graph_outputs.end()) { | |||||
| if (graph_outputs.count(input) == 0 && graph_initializer_tensors.count(input) == 0) { | |||||
| // Record input node need to be constructed | // Record input node need to be constructed | ||||
| all_inputs.emplace(input); | all_inputs.emplace(input); | ||||
| } | } | ||||