Browse Source

fix bug of graphdef transform

tags/v1.2.0-rc1
mengyuanli 4 years ago
parent
commit
a61b5b56d1
2 changed files with 25 additions and 6 deletions
  1. +11
    -0
      mindspore/lite/tools/converter/graphdef_transform.cc
  2. +14
    -6
      mindspore/lite/tools/converter/legacy_optimizer/graph/nested_loop_expand_pass.cc

+ 11
- 0
mindspore/lite/tools/converter/graphdef_transform.cc View File

@@ -42,6 +42,7 @@
#include "tools/converter/legacy_optimizer/graph/select_pass.h"
#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
#include "tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.h"
#include "tools/converter/legacy_optimizer/graph/nested_loop_expand_pass.h"

using std::string;
namespace mindspore::lite {
@@ -276,6 +277,16 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
}
}

{
Optimizer nestedLoopOptimizer;
nestedLoopOptimizer.AddPass(new (std::nothrow) NestedLoopExpandPass());
status = nestedLoopOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run nestedLoopOptimizer graphPasses Failed";
return status;
}
}

return RET_OK;
} // namespace mindspore::lite
} // namespace mindspore::lite

+ 14
- 6
mindspore/lite/tools/converter/legacy_optimizer/graph/nested_loop_expand_pass.cc View File

@@ -77,6 +77,20 @@ STATUS NestedLoopExpandPass::Run(schema::MetaGraphT *graph) {
graph_->subGraph.at(idx) = nullptr;
}



for (auto &node_idx : main_graph->nodeIndices) {
auto &node = graph_->nodes.at(node_idx);
if (node->primitive->value.type == PrimitiveType_Partial) {
auto &subgraph_idx = ((schema::PartialT *)(node->primitive->value.value))->subGraphIndex;
for (auto i = 0; i < subgraph_idx; ++i) {
if (graph_->subGraph.at(subgraph_idx) == nullptr) {
subgraph_idx--;
}
}
}
}

for (auto it = graph_->subGraph.begin(); it != graph_->subGraph.end();) {
if ((*it) == nullptr) {
it = graph_->subGraph.erase(it);
@@ -85,12 +99,6 @@ STATUS NestedLoopExpandPass::Run(schema::MetaGraphT *graph) {
}
}

for (auto &node : graph_->nodes) {
if (node->primitive->value.type == PrimitiveType_Partial) {
((schema::PartialT *)(node->primitive->value.value))->subGraphIndex -= subgraph_to_drop_.size();
}
}

return RET_OK;
}



Loading…
Cancel
Save