|
|
@@ -286,6 +286,9 @@ void KernelGraph::GetLoopNodesByDFS(AnfNodePtr node, uint32_t *loop_num) { |
|
|
MS_LOG(DEBUG) << "Node [" << node->DebugString() << "] don't have input edges."; |
|
|
MS_LOG(DEBUG) << "Node [" << node->DebugString() << "] don't have input edges."; |
|
|
return; |
|
|
return; |
|
|
} |
|
|
} |
|
|
|
|
|
if (*loop_num != 0) { |
|
|
|
|
|
return; |
|
|
|
|
|
} |
|
|
visited_nodes_.insert(node); |
|
|
visited_nodes_.insert(node); |
|
|
for (auto input_edge : node_input_edges_[node]) { |
|
|
for (auto input_edge : node_input_edges_[node]) { |
|
|
size_t input_num = node_input_num_[input_edge.first]; |
|
|
size_t input_num = node_input_num_[input_edge.first]; |
|
|
@@ -300,19 +303,24 @@ void KernelGraph::GetLoopNodesByDFS(AnfNodePtr node, uint32_t *loop_num) { |
|
|
AnfNodePtr node_iter = node; |
|
|
AnfNodePtr node_iter = node; |
|
|
MS_EXCEPTION_IF_NULL(node_iter); |
|
|
MS_EXCEPTION_IF_NULL(node_iter); |
|
|
MS_LOG(DEBUG) << "Print loop nodes start:"; |
|
|
MS_LOG(DEBUG) << "Print loop nodes start:"; |
|
|
for (; node_iter != input_edge.first; node_iter = edge_to_[node_iter]) { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node_iter); |
|
|
|
|
|
|
|
|
for (; node_iter != input_edge.first && node_iter != nullptr; node_iter = edge_to_[node_iter]) { |
|
|
loop_nodes_.push(node_iter); |
|
|
loop_nodes_.push(node_iter); |
|
|
node_input_num_[node_iter]--; |
|
|
node_input_num_[node_iter]--; |
|
|
MS_LOG(DEBUG) << "Get loop node:" << node_iter->DebugString(); |
|
|
MS_LOG(DEBUG) << "Get loop node:" << node_iter->DebugString(); |
|
|
} |
|
|
} |
|
|
loop_nodes_.push(node_iter); |
|
|
|
|
|
loop_nodes_.push(node); |
|
|
|
|
|
(*loop_num)++; |
|
|
|
|
|
node_input_num_[node_iter]--; |
|
|
|
|
|
MS_LOG(DEBUG) << "Get loop node:" << node_iter->DebugString(); |
|
|
|
|
|
MS_LOG(DEBUG) << "Get loop node:" << node->DebugString(); |
|
|
|
|
|
MS_LOG(DEBUG) << "Print loop nodes end, Loop num:" << *loop_num; |
|
|
|
|
|
|
|
|
if (node_iter != nullptr) { |
|
|
|
|
|
loop_nodes_.push(node_iter); |
|
|
|
|
|
loop_nodes_.push(node); |
|
|
|
|
|
(*loop_num)++; |
|
|
|
|
|
node_input_num_[node_iter]--; |
|
|
|
|
|
MS_LOG(DEBUG) << "Get loop node:" << node_iter->DebugString(); |
|
|
|
|
|
MS_LOG(DEBUG) << "Get loop node:" << node->DebugString(); |
|
|
|
|
|
MS_LOG(DEBUG) << "Print loop nodes end, Loop num:" << *loop_num; |
|
|
|
|
|
while (!loop_nodes_.empty()) { |
|
|
|
|
|
loop_nodes_.pop(); |
|
|
|
|
|
} |
|
|
|
|
|
return; |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|