| @@ -286,6 +286,9 @@ void KernelGraph::GetLoopNodesByDFS(AnfNodePtr node, uint32_t *loop_num) { | |||||
| MS_LOG(DEBUG) << "Node [" << node->DebugString() << "] don't have input edges."; | MS_LOG(DEBUG) << "Node [" << node->DebugString() << "] don't have input edges."; | ||||
| return; | return; | ||||
| } | } | ||||
| if (*loop_num != 0) { | |||||
| return; | |||||
| } | |||||
| visited_nodes_.insert(node); | visited_nodes_.insert(node); | ||||
| for (auto input_edge : node_input_edges_[node]) { | for (auto input_edge : node_input_edges_[node]) { | ||||
| size_t input_num = node_input_num_[input_edge.first]; | size_t input_num = node_input_num_[input_edge.first]; | ||||
| @@ -300,19 +303,24 @@ void KernelGraph::GetLoopNodesByDFS(AnfNodePtr node, uint32_t *loop_num) { | |||||
| AnfNodePtr node_iter = node; | AnfNodePtr node_iter = node; | ||||
| MS_EXCEPTION_IF_NULL(node_iter); | MS_EXCEPTION_IF_NULL(node_iter); | ||||
| MS_LOG(DEBUG) << "Print loop nodes start:"; | MS_LOG(DEBUG) << "Print loop nodes start:"; | ||||
| for (; node_iter != input_edge.first; node_iter = edge_to_[node_iter]) { | |||||
| MS_EXCEPTION_IF_NULL(node_iter); | |||||
| for (; node_iter != input_edge.first && node_iter != nullptr; node_iter = edge_to_[node_iter]) { | |||||
| loop_nodes_.push(node_iter); | loop_nodes_.push(node_iter); | ||||
| node_input_num_[node_iter]--; | node_input_num_[node_iter]--; | ||||
| MS_LOG(DEBUG) << "Get loop node:" << node_iter->DebugString(); | MS_LOG(DEBUG) << "Get loop node:" << node_iter->DebugString(); | ||||
| } | } | ||||
| loop_nodes_.push(node_iter); | |||||
| loop_nodes_.push(node); | |||||
| (*loop_num)++; | |||||
| node_input_num_[node_iter]--; | |||||
| MS_LOG(DEBUG) << "Get loop node:" << node_iter->DebugString(); | |||||
| MS_LOG(DEBUG) << "Get loop node:" << node->DebugString(); | |||||
| MS_LOG(DEBUG) << "Print loop nodes end, Loop num:" << *loop_num; | |||||
| if (node_iter != nullptr) { | |||||
| loop_nodes_.push(node_iter); | |||||
| loop_nodes_.push(node); | |||||
| (*loop_num)++; | |||||
| node_input_num_[node_iter]--; | |||||
| MS_LOG(DEBUG) << "Get loop node:" << node_iter->DebugString(); | |||||
| MS_LOG(DEBUG) << "Get loop node:" << node->DebugString(); | |||||
| MS_LOG(DEBUG) << "Print loop nodes end, Loop num:" << *loop_num; | |||||
| while (!loop_nodes_.empty()) { | |||||
| loop_nodes_.pop(); | |||||
| } | |||||
| return; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||