|
|
|
@@ -49,80 +49,81 @@ std::vector<AnfNodePtr> KernelGraph::outputs() const { |
|
|
|
return std::vector<AnfNodePtr>(); |
|
|
|
} |
|
|
|
|
|
|
|
void KernelGraph::SetExecOrderByDefault() { |
|
|
|
std::stack<AnfNodePtr> seed_nodes; |
|
|
|
UpdateNodeEdgeList(&seed_nodes); |
|
|
|
execution_order_.clear(); |
|
|
|
std::unordered_set<AnfNodePtr> visited_nodes; |
|
|
|
std::queue<AnfNodePtr> zero_input_nodes; |
|
|
|
|
|
|
|
auto visit_node_descendant = [&visited_nodes, this](const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue) { |
|
|
|
auto it = node_output_edges_.find(node); |
|
|
|
if (it == node_output_edges_.end()) { |
|
|
|
// value node and parameter has no input,no need to print log |
|
|
|
if (node->isa<CNode>()) { |
|
|
|
MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]"; |
|
|
|
} |
|
|
|
return; |
|
|
|
void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue, |
|
|
|
std::unordered_set<AnfNodePtr> *visited_nodes) { |
|
|
|
MS_EXCEPTION_IF_NULL(visit_queue); |
|
|
|
MS_EXCEPTION_IF_NULL(visited_nodes); |
|
|
|
auto it = node_output_edges_.find(node); |
|
|
|
if (it == node_output_edges_.end()) { |
|
|
|
// value node and parameter has no input,no need to print log |
|
|
|
if (node->isa<CNode>()) { |
|
|
|
MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]"; |
|
|
|
} |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
// visit all reduce node first, then other nodes |
|
|
|
std::vector<AnfNodePtr> active_nodes; |
|
|
|
for (const auto &output_edge : it->second) { |
|
|
|
auto next_node = output_edge.first; |
|
|
|
if (node_input_num_.find(next_node) == node_input_num_.end()) { |
|
|
|
MS_EXCEPTION_IF_NULL(next_node); |
|
|
|
MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]"; |
|
|
|
} |
|
|
|
// visit all reduce node first, then other nodes |
|
|
|
std::vector<AnfNodePtr> active_nodes; |
|
|
|
for (const auto &output_edge : it->second) { |
|
|
|
auto next_node = output_edge.first; |
|
|
|
if (node_input_num_.find(next_node) == node_input_num_.end()) { |
|
|
|
MS_EXCEPTION_IF_NULL(next_node); |
|
|
|
MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString() |
|
|
|
<< ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second; |
|
|
|
if (node_input_num_[next_node] < output_edge.second) { |
|
|
|
MS_LOG(EXCEPTION) << "Input node:" << next_node->DebugString() << ",node_output_num" |
|
|
|
<< node_input_num_[next_node] << ",depend edge:" << output_edge.second; |
|
|
|
} |
|
|
|
node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second; |
|
|
|
// allreduce first |
|
|
|
if (node_input_num_[next_node] == 0 && visited_nodes.find(next_node) == visited_nodes.end()) { |
|
|
|
(void)visited_nodes.insert(next_node); |
|
|
|
if (AnfAlgo::IsAllReduceOp(next_node)) { |
|
|
|
MS_LOG(DEBUG) << "visit node:" << next_node->DebugString(); |
|
|
|
visit_queue->push(next_node); |
|
|
|
} else { |
|
|
|
active_nodes.emplace_back(next_node); |
|
|
|
} |
|
|
|
MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]"; |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(next_node); |
|
|
|
MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString() |
|
|
|
<< ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second; |
|
|
|
if (node_input_num_[next_node] < output_edge.second) { |
|
|
|
MS_LOG(EXCEPTION) << "Input node:" << next_node->DebugString() << ",node_output_num" << node_input_num_[next_node] |
|
|
|
<< ",depend edge:" << output_edge.second; |
|
|
|
} |
|
|
|
node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second; |
|
|
|
// allreduce first |
|
|
|
if (node_input_num_[next_node] == 0 && visited_nodes->find(next_node) == visited_nodes->end()) { |
|
|
|
(void)visited_nodes->insert(next_node); |
|
|
|
if (AnfAlgo::IsCommunicationOp(next_node)) { |
|
|
|
MS_LOG(DEBUG) << "visit node:" << next_node->DebugString(); |
|
|
|
visit_queue->push(next_node); |
|
|
|
} else { |
|
|
|
active_nodes.emplace_back(next_node); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
for (auto &node : active_nodes) { |
|
|
|
MS_LOG(DEBUG) << "visit node:" << node->DebugString(); |
|
|
|
visit_queue->push(node); |
|
|
|
} |
|
|
|
}; |
|
|
|
for (auto &node : active_nodes) { |
|
|
|
MS_LOG(DEBUG) << "visit node:" << node->DebugString(); |
|
|
|
visit_queue->push(node); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr last_allreduce_node = nullptr; |
|
|
|
std::queue<AnfNodePtr> allreduce_descendants; |
|
|
|
while (!seed_nodes.empty() || last_allreduce_node != nullptr) { |
|
|
|
void KernelGraph::SetExecOrderByDefault() { |
|
|
|
std::queue<AnfNodePtr> seed_nodes; |
|
|
|
UpdateNodeEdgeList(&seed_nodes); |
|
|
|
execution_order_.clear(); |
|
|
|
std::unordered_set<AnfNodePtr> visited_nodes; |
|
|
|
std::queue<AnfNodePtr> zero_input_nodes; |
|
|
|
AnfNodePtr last_communication_node = nullptr; |
|
|
|
std::queue<AnfNodePtr> communication_descendants; |
|
|
|
while (!seed_nodes.empty() || last_communication_node != nullptr) { |
|
|
|
// seed nodes first, then visit last all reduce node descendant |
|
|
|
if (seed_nodes.empty()) { |
|
|
|
visit_node_descendant(last_allreduce_node, &allreduce_descendants); |
|
|
|
last_allreduce_node = nullptr; |
|
|
|
VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes); |
|
|
|
last_communication_node = nullptr; |
|
|
|
} else { |
|
|
|
zero_input_nodes.push(seed_nodes.top()); |
|
|
|
zero_input_nodes.push(seed_nodes.front()); |
|
|
|
seed_nodes.pop(); |
|
|
|
} |
|
|
|
|
|
|
|
// all reduce node descendant first, then common queue |
|
|
|
while (!zero_input_nodes.empty() || !allreduce_descendants.empty()) { |
|
|
|
while (!zero_input_nodes.empty() || !communication_descendants.empty()) { |
|
|
|
AnfNodePtr node = nullptr; |
|
|
|
bool is_allreduce_descendant = false; |
|
|
|
if (allreduce_descendants.empty()) { |
|
|
|
bool is_communication_descendant = false; |
|
|
|
if (communication_descendants.empty()) { |
|
|
|
node = zero_input_nodes.front(); |
|
|
|
zero_input_nodes.pop(); |
|
|
|
} else { |
|
|
|
node = allreduce_descendants.front(); |
|
|
|
allreduce_descendants.pop(); |
|
|
|
is_allreduce_descendant = true; |
|
|
|
node = communication_descendants.front(); |
|
|
|
communication_descendants.pop(); |
|
|
|
is_communication_descendant = true; |
|
|
|
} |
|
|
|
// add execute node |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
@@ -130,19 +131,18 @@ void KernelGraph::SetExecOrderByDefault() { |
|
|
|
execution_order_.push_back(node->cast<CNodePtr>()); |
|
|
|
} |
|
|
|
// for all reduce node, visit last all reduce node descendant |
|
|
|
if (AnfAlgo::IsAllReduceOp(node)) { |
|
|
|
if (last_allreduce_node != nullptr) { |
|
|
|
visit_node_descendant(last_allreduce_node, &allreduce_descendants); |
|
|
|
if (AnfAlgo::IsCommunicationOp(node)) { |
|
|
|
if (last_communication_node != nullptr) { |
|
|
|
VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes); |
|
|
|
} |
|
|
|
last_allreduce_node = node; |
|
|
|
} else if (is_allreduce_descendant) { |
|
|
|
visit_node_descendant(node, &allreduce_descendants); |
|
|
|
last_communication_node = node; |
|
|
|
} else if (is_communication_descendant) { |
|
|
|
VisitNodeDescendants(node, &communication_descendants, &visited_nodes); |
|
|
|
} else { |
|
|
|
visit_node_descendant(node, &zero_input_nodes); |
|
|
|
VisitNodeDescendants(node, &zero_input_nodes, &visited_nodes); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
CheckLoop(); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -467,7 +467,7 @@ bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue<Anf |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
void KernelGraph::UpdateNodeEdgeList(std::stack<AnfNodePtr> *seed_nodes) { |
|
|
|
void KernelGraph::UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes) { |
|
|
|
node_output_edges_.clear(); |
|
|
|
node_input_num_.clear(); |
|
|
|
node_input_edges_.clear(); |
|
|
|
@@ -483,7 +483,6 @@ void KernelGraph::UpdateNodeEdgeList(std::stack<AnfNodePtr> *seed_nodes) { |
|
|
|
seed_nodes->push(node); |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
if (!node->isa<CNode>()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|