|
|
|
@@ -201,17 +201,21 @@ void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNod |
|
|
|
} |
|
|
|
|
|
|
|
void KernelGraph::SetExecOrderByDefault() { |
|
|
|
std::queue<AnfNodePtr> zero_input_nodes; |
|
|
|
UpdateNodeEdgeList(&zero_input_nodes); |
|
|
|
std::queue<AnfNodePtr> seed_nodes; |
|
|
|
UpdateNodeEdgeList(&seed_nodes); |
|
|
|
execution_order_.clear(); |
|
|
|
std::unordered_set<AnfNodePtr> visited_nodes; |
|
|
|
std::queue<AnfNodePtr> zero_input_nodes; |
|
|
|
AnfNodePtr last_communication_node = nullptr; |
|
|
|
std::queue<AnfNodePtr> communication_descendants; |
|
|
|
while (!zero_input_nodes.empty() || last_communication_node != nullptr) { |
|
|
|
while (!seed_nodes.empty() || last_communication_node != nullptr) { |
|
|
|
// Handle seed nodes first, then visit the descendants of the last all reduce node. |
|
|
|
if (last_communication_node != nullptr) { |
|
|
|
if (seed_nodes.empty()) { |
|
|
|
VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes); |
|
|
|
last_communication_node = nullptr; |
|
|
|
} else { |
|
|
|
zero_input_nodes.push(seed_nodes.front()); |
|
|
|
seed_nodes.pop(); |
|
|
|
} |
|
|
|
// all reduce node descendant first, then common queue |
|
|
|
while (!zero_input_nodes.empty() || !communication_descendants.empty()) { |
|
|
|
@@ -900,11 +904,14 @@ void KernelGraph::UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes) { |
|
|
|
seed_nodes->push(node); |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
auto cnode = dyn_cast<CNode>(node); |
|
|
|
if (cnode == nullptr) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
for (auto &input : cnode->inputs()) { |
|
|
|
auto &inputs = cnode->inputs(); |
|
|
|
// We push inputs from right to left, so that they can be evaluated from left to right. |
|
|
|
for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) { |
|
|
|
auto &input = *iter; |
|
|
|
PushNoVisitedNode(input, &que, &visited_nodes); |
|
|
|
AddDependEdge(node, input, 1); |
|
|
|
} |
|
|
|
|