diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.cc index 827619c2e5..66164d283d 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.cc @@ -115,6 +115,7 @@ STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) { auto &node = graph->nodes.at(i); std::vector contain_node_input_subgraphs{}; std::vector contain_node_output_subgraphs{}; + std::vector contain_subgraphs{}; for (auto &subgraph : graph->subGraph) { std::set tensors_indices{}; int ret = GetSubgraphAllTensorIndices(subgraph, graph, &tensors_indices); @@ -129,26 +130,26 @@ STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) { contain_node_output_subgraphs.push_back(subgraph.get()); } } - if (contain_node_input_subgraphs.size() == 1 && contain_node_output_subgraphs.size() == 1 && - contain_node_output_subgraphs[0] != contain_node_input_subgraphs[0]) { - MS_LOG(ERROR) << "not support single node index insert."; - return RET_ERROR; - } - if (contain_node_input_subgraphs.size() == 1 && contain_node_output_subgraphs.size() == 1 && - contain_node_output_subgraphs[0] == contain_node_input_subgraphs[0]) { + std::set_intersection(contain_node_input_subgraphs.begin(), contain_node_input_subgraphs.end(), + contain_node_output_subgraphs.begin(), contain_node_output_subgraphs.end(), + inserter(contain_subgraphs, contain_subgraphs.begin())); + if (contain_subgraphs.size() == 1) { IncreaseSubgraphNodeIndices(i, graph); - contain_node_input_subgraphs[0]->nodeIndices.push_back(i); + contain_subgraphs[0]->nodeIndices.push_back(i); continue; } - if (contain_node_input_subgraphs.size() == 1) { + if (contain_node_input_subgraphs.size() == 1 && contain_node_output_subgraphs.empty()) { IncreaseSubgraphNodeIndices(i, graph); contain_node_input_subgraphs[0]->nodeIndices.push_back(i); continue; } - if (contain_node_output_subgraphs.size() == 1) { + if (contain_node_output_subgraphs.size() == 1 && contain_node_input_subgraphs.empty()) { IncreaseSubgraphNodeIndices(i, graph); contain_node_output_subgraphs[0]->nodeIndices.push_back(i); continue; + } else { + MS_LOG(ERROR) << "Not able to find which subgraph to insert node: " << node->name; + return RET_ERROR; } } }