Browse Source

fix bug of nested loop expand

tags/v1.2.0-rc1
mengyuanli 4 years ago
parent
commit
a8f0f63e05
1 changed files with 6 additions and 8 deletions
  1. +6
    -8
      mindspore/lite/tools/converter/legacy_optimizer/graph/nested_loop_expand_pass.cc

+ 6
- 8
mindspore/lite/tools/converter/legacy_optimizer/graph/nested_loop_expand_pass.cc View File

@@ -45,21 +45,19 @@ bool NestedLoopExpandPass::IsNestedPartial(const std::unique_ptr<CNodeT> &node)

void NestedLoopExpandPass::ReplacePartialNodeWithSubgraph(const std::unique_ptr<SubGraphT> &main_graph) {
bool is_changed = false;
for (auto &node_idx : main_graph->nodeIndices) {
auto &node = graph_->nodes.at(node_idx);
for (auto iter = main_graph->nodeIndices.begin(); iter != main_graph->nodeIndices.end();) {
auto &node = graph_->nodes.at(*iter);
if (!IsNestedPartial(node)) {
iter++;
continue;
}
is_changed = true;
auto subgraph_idx = ((schema::PartialT *)(node->primitive->value.value))->subGraphIndex;
auto &this_subgraph = graph_->subGraph.at(subgraph_idx);
subgraph_to_drop_.push_back(subgraph_idx);
auto partial_pos = std::find(main_graph->nodeIndices.begin(), main_graph->nodeIndices.end(), node_idx);
std::vector<uint32_t> tmp;
tmp.assign(main_graph->nodeIndices.begin(), partial_pos);
tmp.insert(tmp.end(), this_subgraph->nodeIndices.begin(), this_subgraph->nodeIndices.end());
tmp.insert(tmp.end(), partial_pos + 1, main_graph->nodeIndices.end());
main_graph->nodeIndices.assign(tmp.begin(), tmp.end());
iter = main_graph->nodeIndices.erase(iter);
main_graph->nodeIndices.insert(iter, this_subgraph->nodeIndices.begin(), this_subgraph->nodeIndices.end());
break;
}

if (is_changed) {


Loading…
Cancel
Save