|
|
|
@@ -15,6 +15,7 @@ |
|
|
|
*/ |
|
|
|
|
|
|
|
#include <vector> |
|
|
|
#include <set> |
|
|
|
#include <algorithm> |
|
|
|
#include <memory> |
|
|
|
#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h" |
|
|
|
@@ -27,13 +28,53 @@ |
|
|
|
namespace mindspore { |
|
|
|
namespace lite { |
|
|
|
|
|
|
|
void SubgraphNodePass::UpdateSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph) { |
|
|
|
for (auto &subgraph : graph->subGraph) { |
|
|
|
for (auto &idx : subgraph->nodeIndices) { |
|
|
|
if (idx > node_idx) { |
|
|
|
idx--; |
|
|
|
} |
|
|
|
std::set<uint32_t> SubgraphNodePass::GetSubgraphAllTensorIndices(const std::unique_ptr<SubGraphT> &subgraph, |
|
|
|
schema::MetaGraphT *graph) { |
|
|
|
std::set<uint32_t> tensors_indices{}; |
|
|
|
for (auto &node_idx : subgraph->nodeIndices) { |
|
|
|
auto &node = graph->nodes.at(node_idx); |
|
|
|
for (auto &input_idx : node->inputIndex) { |
|
|
|
tensors_indices.insert(input_idx); |
|
|
|
} |
|
|
|
for (auto &output_idx : node->outputIndex) { |
|
|
|
tensors_indices.insert(output_idx); |
|
|
|
} |
|
|
|
} |
|
|
|
return tensors_indices; |
|
|
|
} |
|
|
|
|
|
|
|
bool SubgraphNodePass::IsNodeInSubgraph(const std::set<uint32_t> &tensors_indices, const std::unique_ptr<CNodeT> &node, |
|
|
|
const std::unique_ptr<SubGraphT> &subgraph) { |
|
|
|
return (std::any_of(node->inputIndex.begin(), node->inputIndex.end(), |
|
|
|
[&tensors_indices, &subgraph](uint32_t idx) { |
|
|
|
return tensors_indices.count(idx) > 0 || IsContain(subgraph->inputIndices, idx); |
|
|
|
})) && |
|
|
|
(std::any_of(node->outputIndex.begin(), node->outputIndex.end(), [&tensors_indices, &subgraph](uint32_t idx) { |
|
|
|
return tensors_indices.count(idx) > 0 || IsContain(subgraph->outputIndices, idx); |
|
|
|
})); |
|
|
|
} |
|
|
|
|
|
|
|
void SubgraphNodePass::DecreaseSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph) { |
|
|
|
for (auto &subgraph : graph->subGraph) { |
|
|
|
std::transform(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), subgraph->nodeIndices.begin(), |
|
|
|
[&node_idx](uint32_t idx) { |
|
|
|
if (idx > node_idx) { |
|
|
|
return --idx; |
|
|
|
} |
|
|
|
return idx; |
|
|
|
}); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void SubgraphNodePass::IncreaseSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph) { |
|
|
|
for (auto &subgraph : graph->subGraph) { |
|
|
|
std::transform(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), subgraph->nodeIndices.begin(), |
|
|
|
[&node_idx](uint32_t idx) { |
|
|
|
if (idx >= node_idx) { |
|
|
|
return ++idx; |
|
|
|
} |
|
|
|
return idx; |
|
|
|
}); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -50,7 +91,7 @@ STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) { |
|
|
|
auto node_idx_pos = std::find(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), node_idx); |
|
|
|
if (node_idx_pos != subgraph->nodeIndices.end()) { |
|
|
|
subgraph->nodeIndices.erase(node_idx_pos); |
|
|
|
UpdateSubgraphNodeIndices(node_idx, graph); |
|
|
|
DecreaseSubgraphNodeIndices(node_idx, graph); |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -62,10 +103,12 @@ STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) { |
|
|
|
|
|
|
|
for (uint32_t i = 0; i < new_nodes.size(); i++) { |
|
|
|
if (!IsContain(old_nodes_, new_nodes[i])) { |
|
|
|
auto &node = graph->nodes.at(i); |
|
|
|
for (auto &subgraph : graph->subGraph) { |
|
|
|
if (IsContain(subgraph->nodeIndices, i - 1) || IsContain(subgraph->nodeIndices, i + 1)) { |
|
|
|
subgraph->nodeIndices.push_back(old_nodes_.size()); |
|
|
|
old_nodes_.push_back(new_nodes[i]); |
|
|
|
auto tensors_indices = GetSubgraphAllTensorIndices(subgraph, graph); |
|
|
|
if (IsNodeInSubgraph(tensors_indices, node, subgraph)) { |
|
|
|
IncreaseSubgraphNodeIndices(i, graph); |
|
|
|
subgraph->nodeIndices.push_back(i); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|