|
|
|
@@ -180,8 +180,8 @@ std::vector<AnfNodePtr> KernelGraph::outputs() const { |
|
|
|
return std::vector<AnfNodePtr>(1, graph_output); |
|
|
|
} |
|
|
|
|
|
|
|
void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue, |
|
|
|
std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first) { |
|
|
|
void KernelGraph::EnqueueActiveNodes(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue, |
|
|
|
std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first) { |
|
|
|
MS_EXCEPTION_IF_NULL(visit_queue); |
|
|
|
MS_EXCEPTION_IF_NULL(visited_nodes); |
|
|
|
auto it = node_output_edges_.find(node); |
|
|
|
@@ -241,7 +241,7 @@ void KernelGraph::SetExecOrderByDefault() { |
|
|
|
while (!seed_nodes.empty() || !delay_comm_stack.empty()) { |
|
|
|
// seed nodes first, then delay comm nodes |
|
|
|
if (seed_nodes.empty()) { |
|
|
|
VisitNodeDescendants(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false); |
|
|
|
EnqueueActiveNodes(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false); |
|
|
|
delay_comm_stack.pop(); |
|
|
|
} else { |
|
|
|
zero_input_nodes.push(seed_nodes.front()); |
|
|
|
@@ -272,16 +272,16 @@ void KernelGraph::SetExecOrderByDefault() { |
|
|
|
} |
|
|
|
if (optimize_comm) { |
|
|
|
while (!delay_comm_stack.empty()) { |
|
|
|
VisitNodeDescendants(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false); |
|
|
|
EnqueueActiveNodes(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false); |
|
|
|
delay_comm_stack.pop(); |
|
|
|
} |
|
|
|
delay_comm_stack.push(node); |
|
|
|
} else if (is_fused_comm) { |
|
|
|
delay_comm_stack.push(node); |
|
|
|
} else if (is_communication_descendant) { |
|
|
|
VisitNodeDescendants(node, &communication_descendants, &visited_nodes); |
|
|
|
EnqueueActiveNodes(node, &communication_descendants, &visited_nodes); |
|
|
|
} else { |
|
|
|
VisitNodeDescendants(node, &zero_input_nodes, &visited_nodes); |
|
|
|
EnqueueActiveNodes(node, &zero_input_nodes, &visited_nodes); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|