| @@ -58,14 +58,32 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c | |||||
| node->extra_seen_ = seen; | node->extra_seen_ = seen; | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (node->seen_ == seen && node->extra_seen_ != seen) { | |||||
| MS_LOG(EXCEPTION) << "Graph exists cycle, node " << node->DebugString(2); | |||||
| } | |||||
| node->seen_ = seen; | node->seen_ = seen; | ||||
| if (incl == FOLLOW) { | if (incl == FOLLOW) { | ||||
| auto succs = succ(node); | auto succs = succ(node); | ||||
| (void)std::copy_if(succs.begin(), succs.end(), std::back_inserter(todo), | |||||
| [seen](const AnfNodePtr &next) { return next != nullptr && next->seen_ != seen; }); | |||||
| (void)std::copy_if(succs.begin(), succs.end(), std::back_inserter(todo), [seen, todo](const AnfNodePtr &next) { | |||||
| if (next == nullptr || next->extra_seen_ == seen) { | |||||
| return false; | |||||
| } | |||||
| if (next->seen_ != seen) { | |||||
| return true; | |||||
| } | |||||
| if (next->func_graph()->get_return() == next) { | |||||
| return false; | |||||
| } | |||||
| // To dump all nodes in a circle. | |||||
| MS_LOG(ERROR) << "Graph cycle exists. Circle is: "; | |||||
| size_t pos = 0; | |||||
| auto circle_node_it = std::find(todo.begin(), todo.end(), next); | |||||
| for (; circle_node_it != todo.end(); circle_node_it++) { | |||||
| auto circle_node = *circle_node_it; | |||||
| if (circle_node->seen_) { | |||||
| MS_LOG(ERROR) << "#" << pos << ": " << circle_node->DebugString(); | |||||
| pos++; | |||||
| } | |||||
| } | |||||
| MS_LOG(EXCEPTION) << "Graph cycle exists, node " << next->DebugString(2); | |||||
| }); | |||||
| } else if (incl > EXCLUDE) { // Not NOFOLLOW or EXCLUDE | } else if (incl > EXCLUDE) { // Not NOFOLLOW or EXCLUDE | ||||
| MS_LOG(EXCEPTION) << "The result of include(node) must be one of: \"follow\", \"nofollow\", \"exclude\""; | MS_LOG(EXCEPTION) << "The result of include(node) must be one of: \"follow\", \"nofollow\", \"exclude\""; | ||||
| } | } | ||||
| @@ -138,10 +156,6 @@ std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) { | |||||
| auto &inputs = node->cast<CNodePtr>()->inputs(); | auto &inputs = node->cast<CNodePtr>()->inputs(); | ||||
| (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end()); | (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end()); | ||||
| } | } | ||||
| auto graph = node->func_graph(); | |||||
| if (graph->get_return() != nullptr) { | |||||
| vecs.push_back(graph->get_return()); | |||||
| } | |||||
| return vecs; | return vecs; | ||||
| } | } | ||||