|
|
|
@@ -45,21 +45,19 @@ bool NestedLoopExpandPass::IsNestedPartial(const std::unique_ptr<CNodeT> &node) |
|
|
|
|
|
|
|
void NestedLoopExpandPass::ReplacePartialNodeWithSubgraph(const std::unique_ptr<SubGraphT> &main_graph) { |
|
|
|
bool is_changed = false; |
|
|
|
for (auto &node_idx : main_graph->nodeIndices) { |
|
|
|
auto &node = graph_->nodes.at(node_idx); |
|
|
|
for (auto iter = main_graph->nodeIndices.begin(); iter != main_graph->nodeIndices.end();) { |
|
|
|
auto &node = graph_->nodes.at(*iter); |
|
|
|
if (!IsNestedPartial(node)) { |
|
|
|
iter++; |
|
|
|
continue; |
|
|
|
} |
|
|
|
is_changed = true; |
|
|
|
auto subgraph_idx = ((schema::PartialT *)(node->primitive->value.value))->subGraphIndex; |
|
|
|
auto &this_subgraph = graph_->subGraph.at(subgraph_idx); |
|
|
|
subgraph_to_drop_.push_back(subgraph_idx); |
|
|
|
auto partial_pos = std::find(main_graph->nodeIndices.begin(), main_graph->nodeIndices.end(), node_idx); |
|
|
|
std::vector<uint32_t> tmp; |
|
|
|
tmp.assign(main_graph->nodeIndices.begin(), partial_pos); |
|
|
|
tmp.insert(tmp.end(), this_subgraph->nodeIndices.begin(), this_subgraph->nodeIndices.end()); |
|
|
|
tmp.insert(tmp.end(), partial_pos + 1, main_graph->nodeIndices.end()); |
|
|
|
main_graph->nodeIndices.assign(tmp.begin(), tmp.end()); |
|
|
|
iter = main_graph->nodeIndices.erase(iter); |
|
|
|
main_graph->nodeIndices.insert(iter, this_subgraph->nodeIndices.begin(), this_subgraph->nodeIndices.end()); |
|
|
|
break; |
|
|
|
} |
|
|
|
|
|
|
|
if (is_changed) { |
|
|
|
|