Browse Source

!10781 [MS][LITE]fix bug of subgraph node pass

From: @mengyuanli
Reviewed-by: @hangangqiang,@zhanghaibo5
Signed-off-by: @hangangqiang
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
02a0eb9764
1 changed files with 11 additions and 10 deletions
  1. +11
    -10
      mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.cc

+ 11
- 10
mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.cc View File

@@ -115,6 +115,7 @@ STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) {
auto &node = graph->nodes.at(i); auto &node = graph->nodes.at(i);
std::vector<SubGraphT *> contain_node_input_subgraphs{}; std::vector<SubGraphT *> contain_node_input_subgraphs{};
std::vector<SubGraphT *> contain_node_output_subgraphs{}; std::vector<SubGraphT *> contain_node_output_subgraphs{};
std::vector<SubGraphT *> contain_subgraphs{};
for (auto &subgraph : graph->subGraph) { for (auto &subgraph : graph->subGraph) {
std::set<uint32_t> tensors_indices{}; std::set<uint32_t> tensors_indices{};
int ret = GetSubgraphAllTensorIndices(subgraph, graph, &tensors_indices); int ret = GetSubgraphAllTensorIndices(subgraph, graph, &tensors_indices);
@@ -129,26 +130,26 @@ STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) {
contain_node_output_subgraphs.push_back(subgraph.get()); contain_node_output_subgraphs.push_back(subgraph.get());
} }
} }
if (contain_node_input_subgraphs.size() == 1 && contain_node_output_subgraphs.size() == 1 &&
contain_node_output_subgraphs[0] != contain_node_input_subgraphs[0]) {
MS_LOG(ERROR) << "not support single node index insert.";
return RET_ERROR;
}
if (contain_node_input_subgraphs.size() == 1 && contain_node_output_subgraphs.size() == 1 &&
contain_node_output_subgraphs[0] == contain_node_input_subgraphs[0]) {
std::set_intersection(contain_node_input_subgraphs.begin(), contain_node_input_subgraphs.end(),
contain_node_output_subgraphs.begin(), contain_node_output_subgraphs.end(),
inserter(contain_subgraphs, contain_subgraphs.begin()));
if (contain_subgraphs.size() == 1) {
IncreaseSubgraphNodeIndices(i, graph); IncreaseSubgraphNodeIndices(i, graph);
contain_node_input_subgraphs[0]->nodeIndices.push_back(i);
contain_subgraphs[0]->nodeIndices.push_back(i);
continue; continue;
} }
if (contain_node_input_subgraphs.size() == 1) {
if (contain_node_input_subgraphs.size() == 1 && contain_node_output_subgraphs.empty()) {
IncreaseSubgraphNodeIndices(i, graph); IncreaseSubgraphNodeIndices(i, graph);
contain_node_input_subgraphs[0]->nodeIndices.push_back(i); contain_node_input_subgraphs[0]->nodeIndices.push_back(i);
continue; continue;
} }
if (contain_node_output_subgraphs.size() == 1) {
if (contain_node_output_subgraphs.size() == 1 && contain_node_input_subgraphs.empty()) {
IncreaseSubgraphNodeIndices(i, graph); IncreaseSubgraphNodeIndices(i, graph);
contain_node_output_subgraphs[0]->nodeIndices.push_back(i); contain_node_output_subgraphs[0]->nodeIndices.push_back(i);
continue; continue;
} else {
MS_LOG(ERROR) << "Not able to find which subgraph to insert node: " << node->name;
return RET_ERROR;
} }
} }
} }


Loading…
Cancel
Save