|
|
|
@@ -34,18 +34,17 @@ |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, const IncludeFunc &include) { |
|
|
|
std::vector<AnfNodePtr> res; |
|
|
|
if (root == nullptr) { |
|
|
|
return res; |
|
|
|
} |
|
|
|
size_t seen = NewSeenGeneration(); |
|
|
|
std::deque<AnfNodePtr> todo(1024); |
|
|
|
std::vector<AnfNodePtr> res; |
|
|
|
todo.clear(); |
|
|
|
todo.push_back(root); |
|
|
|
|
|
|
|
while (!todo.empty()) { |
|
|
|
AnfNodePtr node = todo.back(); |
|
|
|
if (node == nullptr) { |
|
|
|
todo.pop_back(); |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (node->extra_seen_ == seen) { // We use extra_seen_ as finish flag |
|
|
|
todo.pop_back(); |
|
|
|
continue; |
|
|
|
@@ -65,10 +64,8 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c |
|
|
|
node->seen_ = seen; |
|
|
|
if (incl == FOLLOW) { |
|
|
|
auto succs = succ(node); |
|
|
|
(void)std::copy_if(succs.begin(), succs.end(), std::back_inserter(todo), [seen](const AnfNodePtr &next) { |
|
|
|
return next != nullptr && next->seen_ != seen && |
|
|
|
(next->func_graph() == nullptr || next->func_graph()->get_return() != next); |
|
|
|
}); |
|
|
|
(void)std::copy_if(succs.begin(), succs.end(), std::back_inserter(todo), |
|
|
|
[seen](const AnfNodePtr &next) { return next != nullptr && next->seen_ != seen; }); |
|
|
|
} else if (incl > EXCLUDE) { // Not NOFOLLOW or EXCLUDE |
|
|
|
MS_LOG(EXCEPTION) << "The result of include(node) must be one of: \"follow\", \"nofollow\", \"exclude\""; |
|
|
|
} |
|
|
|
|