From fb4bd856567dce45e58c9e7c9847650e5e96ec70 Mon Sep 17 00:00:00 2001 From: yefeng Date: Tue, 22 Dec 2020 15:35:51 +0800 Subject: [PATCH] 0115-fix-convert-2 --- .../graph/subgraph_node_pass.cc | 63 ++++++++++++++++--- .../graph/subgraph_node_pass.h | 8 ++- 2 files changed, 60 insertions(+), 11 deletions(-) diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.cc index 99507ba665..e2ad9c11a4 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.cc @@ -15,6 +15,7 @@ */ #include +#include #include #include #include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h" @@ -27,13 +28,53 @@ namespace mindspore { namespace lite { -void SubgraphNodePass::UpdateSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph) { - for (auto &subgraph : graph->subGraph) { - for (auto &idx : subgraph->nodeIndices) { - if (idx > node_idx) { - idx--; - } +std::set SubgraphNodePass::GetSubgraphAllTensorIndices(const std::unique_ptr &subgraph, + schema::MetaGraphT *graph) { + std::set tensors_indices{}; + for (auto &node_idx : subgraph->nodeIndices) { + auto &node = graph->nodes.at(node_idx); + for (auto &input_idx : node->inputIndex) { + tensors_indices.insert(input_idx); } + for (auto &output_idx : node->outputIndex) { + tensors_indices.insert(output_idx); + } + } + return tensors_indices; +} + +bool SubgraphNodePass::IsNodeInSubgraph(const std::set &tensors_indices, const std::unique_ptr &node, + const std::unique_ptr &subgraph) { + return (std::any_of(node->inputIndex.begin(), node->inputIndex.end(), + [&tensors_indices, &subgraph](uint32_t idx) { + return tensors_indices.count(idx) > 0 || IsContain(subgraph->inputIndices, idx); + })) && + (std::any_of(node->outputIndex.begin(), node->outputIndex.end(), [&tensors_indices, &subgraph](uint32_t idx) { + return tensors_indices.count(idx) > 0 || IsContain(subgraph->outputIndices, idx); + })); +} + +void SubgraphNodePass::DecreaseSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph) { + for (auto &subgraph : graph->subGraph) { + std::transform(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), subgraph->nodeIndices.begin(), + [&node_idx](uint32_t idx) { + if (idx > node_idx) { + return --idx; + } + return idx; + }); + } +} + +void SubgraphNodePass::IncreaseSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph) { + for (auto &subgraph : graph->subGraph) { + std::transform(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), subgraph->nodeIndices.begin(), + [&node_idx](uint32_t idx) { + if (idx >= node_idx) { + return ++idx; + } + return idx; + }); } } @@ -50,7 +91,7 @@ STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) { auto node_idx_pos = std::find(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), node_idx); if (node_idx_pos != subgraph->nodeIndices.end()) { subgraph->nodeIndices.erase(node_idx_pos); - UpdateSubgraphNodeIndices(node_idx, graph); + DecreaseSubgraphNodeIndices(node_idx, graph); break; } } @@ -62,10 +103,12 @@ STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) { for (uint32_t i = 0; i < new_nodes.size(); i++) { if (!IsContain(old_nodes_, new_nodes[i])) { + auto &node = graph->nodes.at(i); for (auto &subgraph : graph->subGraph) { - if (IsContain(subgraph->nodeIndices, i - 1) || IsContain(subgraph->nodeIndices, i + 1)) { - subgraph->nodeIndices.push_back(old_nodes_.size()); - old_nodes_.push_back(new_nodes[i]); + auto tensors_indices = GetSubgraphAllTensorIndices(subgraph, graph); + if (IsNodeInSubgraph(tensors_indices, node, subgraph)) { + IncreaseSubgraphNodeIndices(i, graph); + subgraph->nodeIndices.push_back(i); } } } diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.h index 412cbe6211..303310b9c3 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.h +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.h @@ -19,6 +19,8 @@ #include #include +#include +#include #include "tools/converter/optimizer.h" namespace mindspore { @@ -32,7 +34,11 @@ class SubgraphNodePass : public GraphPass { STATUS Run(schema::MetaGraphT *graph) override; private: - void UpdateSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph); + void DecreaseSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph); + void IncreaseSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph); + std::set GetSubgraphAllTensorIndices(const std::unique_ptr &subgraph, schema::MetaGraphT *graph); + bool IsNodeInSubgraph(const std::set &tensors_indices, const std::unique_ptr &node, + const std::unique_ptr &subgraph); std::vector old_nodes_; }; } // namespace lite