|
|
|
@@ -50,90 +50,127 @@ std::vector<AnfNodePtr> KernelGraph::outputs() const { |
|
|
|
} |
|
|
|
|
|
|
|
void KernelGraph::SetExecOrderByDefault() { |
|
|
|
BfsToUpdateNodeOutput(); |
|
|
|
std::stack<AnfNodePtr> seed_nodes; |
|
|
|
UpdateNodeEdgeList(&seed_nodes); |
|
|
|
execution_order_.clear(); |
|
|
|
std::queue<AnfNodePtr> allreduce_nodes; |
|
|
|
std::queue<AnfNodePtr> zero_output_nodes; |
|
|
|
std::unordered_set<AnfNodePtr> visited_nodes; |
|
|
|
auto clear_output = [&zero_output_nodes, &allreduce_nodes, &visited_nodes, this](const AnfNodePtr &input) -> void { |
|
|
|
if (node_output_num_[input] == 0 && visited_nodes.find(input) == visited_nodes.end()) { |
|
|
|
MS_EXCEPTION_IF_NULL(input); |
|
|
|
MS_LOG(DEBUG) << "Clear output num:" << input->DebugString(); |
|
|
|
(void)visited_nodes.insert(input); |
|
|
|
if (input->isa<CNode>() && AnfAlgo::GetCNodeName(input) == kAllReduceOpName) { |
|
|
|
allreduce_nodes.push(input); |
|
|
|
} else { |
|
|
|
zero_output_nodes.push(input); |
|
|
|
std::queue<AnfNodePtr> zero_input_nodes; |
|
|
|
|
|
|
|
auto visit_node_descendant = [&visited_nodes, this](const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue) { |
|
|
|
auto it = node_output_edges_.find(node); |
|
|
|
if (it == node_output_edges_.end()) { |
|
|
|
// value node and parameter has no input,no need to print log |
|
|
|
if (node->isa<CNode>()) { |
|
|
|
MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]"; |
|
|
|
} |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
// visit all reduce node first, then other nodes |
|
|
|
std::vector<AnfNodePtr> active_nodes; |
|
|
|
for (const auto &output_edge : it->second) { |
|
|
|
auto next_node = output_edge.first; |
|
|
|
if (node_input_num_.find(next_node) == node_input_num_.end()) { |
|
|
|
MS_EXCEPTION_IF_NULL(next_node); |
|
|
|
MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]"; |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(next_node); |
|
|
|
MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString() |
|
|
|
<< ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second; |
|
|
|
if (node_input_num_[next_node] < output_edge.second) { |
|
|
|
MS_LOG(EXCEPTION) << "Input node:" << next_node->DebugString() << ",node_output_num" |
|
|
|
<< node_input_num_[next_node] << ",depend edge:" << output_edge.second; |
|
|
|
} |
|
|
|
node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second; |
|
|
|
// allreduce first |
|
|
|
if (node_input_num_[next_node] == 0 && visited_nodes.find(next_node) == visited_nodes.end()) { |
|
|
|
(void)visited_nodes.insert(next_node); |
|
|
|
if (AnfAlgo::IsAllReduceOp(next_node)) { |
|
|
|
MS_LOG(DEBUG) << "visit node:" << next_node->DebugString(); |
|
|
|
visit_queue->push(next_node); |
|
|
|
} else { |
|
|
|
active_nodes.emplace_back(next_node); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
for (auto &node : active_nodes) { |
|
|
|
MS_LOG(DEBUG) << "visit node:" << node->DebugString(); |
|
|
|
visit_queue->push(node); |
|
|
|
} |
|
|
|
}; |
|
|
|
zero_output_nodes.emplace(get_return()); |
|
|
|
while (!zero_output_nodes.empty() || !allreduce_nodes.empty()) { |
|
|
|
AnfNodePtr node; |
|
|
|
if (!zero_output_nodes.empty()) { |
|
|
|
node = zero_output_nodes.front(); |
|
|
|
zero_output_nodes.pop(); |
|
|
|
|
|
|
|
AnfNodePtr last_allreduce_node = nullptr; |
|
|
|
std::queue<AnfNodePtr> allreduce_descendants; |
|
|
|
while (!seed_nodes.empty() || last_allreduce_node != nullptr) { |
|
|
|
// seed nodes first, then visit last all reduce node descendant |
|
|
|
if (seed_nodes.empty()) { |
|
|
|
visit_node_descendant(last_allreduce_node, &allreduce_descendants); |
|
|
|
last_allreduce_node = nullptr; |
|
|
|
} else { |
|
|
|
node = allreduce_nodes.front(); |
|
|
|
allreduce_nodes.pop(); |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
if (node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) { |
|
|
|
execution_order_.push_back(node->cast<CNodePtr>()); |
|
|
|
zero_input_nodes.push(seed_nodes.top()); |
|
|
|
seed_nodes.pop(); |
|
|
|
} |
|
|
|
auto it = node_input_edges_.find(node); |
|
|
|
if (it == node_input_edges_.end()) { |
|
|
|
// value node and parameter has no input,no need to print log |
|
|
|
if (node->isa<CNode>()) { |
|
|
|
MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]"; |
|
|
|
|
|
|
|
// all reduce node descendant first, then common queue |
|
|
|
while (!zero_input_nodes.empty() || !allreduce_descendants.empty()) { |
|
|
|
AnfNodePtr node = nullptr; |
|
|
|
bool is_allreduce_descendant = false; |
|
|
|
if (allreduce_descendants.empty()) { |
|
|
|
node = zero_input_nodes.front(); |
|
|
|
zero_input_nodes.pop(); |
|
|
|
} else { |
|
|
|
node = allreduce_descendants.front(); |
|
|
|
allreduce_descendants.pop(); |
|
|
|
is_allreduce_descendant = true; |
|
|
|
} |
|
|
|
continue; |
|
|
|
} |
|
|
|
for (const auto &input_edge : it->second) { |
|
|
|
if (node_output_num_.find(input_edge.first) == node_output_num_.end()) { |
|
|
|
MS_EXCEPTION_IF_NULL(input_edge.first); |
|
|
|
MS_LOG(EXCEPTION) << "Can't find node[" << input_edge.first->DebugString() << "]"; |
|
|
|
// add execute node |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
if (node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) { |
|
|
|
execution_order_.push_back(node->cast<CNodePtr>()); |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(input_edge.first); |
|
|
|
MS_LOG(DEBUG) << "Decrease input:" << input_edge.first->DebugString() << ",node:" << node->DebugString() |
|
|
|
<< ",num: " << node_output_num_[input_edge.first] << ",decrease num:" << input_edge.second; |
|
|
|
if (node_output_num_[input_edge.first] < input_edge.second) { |
|
|
|
MS_LOG(EXCEPTION) << "Input node:" << input_edge.first->DebugString() << ",node_output_num" |
|
|
|
<< node_output_num_[input_edge.first] << "depend edge:" << input_edge.second; |
|
|
|
// for all reduce node, visit last all reduce node descendant |
|
|
|
if (AnfAlgo::IsAllReduceOp(node)) { |
|
|
|
if (last_allreduce_node != nullptr) { |
|
|
|
visit_node_descendant(last_allreduce_node, &allreduce_descendants); |
|
|
|
} |
|
|
|
last_allreduce_node = node; |
|
|
|
} else if (is_allreduce_descendant) { |
|
|
|
visit_node_descendant(node, &allreduce_descendants); |
|
|
|
} else { |
|
|
|
visit_node_descendant(node, &zero_input_nodes); |
|
|
|
} |
|
|
|
node_output_num_[input_edge.first] = node_output_num_[input_edge.first] - input_edge.second; |
|
|
|
clear_output(input_edge.first); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
CheckLoop(); |
|
|
|
std::reverse(execution_order_.begin(), execution_order_.end()); |
|
|
|
} |
|
|
|
|
|
|
|
void KernelGraph::CheckLoop() { |
|
|
|
std::map<AnfNodePtr, size_t> none_zero_output; |
|
|
|
if (node_output_edges_.size() != node_output_num_.size()) { |
|
|
|
MS_LOG(EXCEPTION) << "node_output_edges_ size :" << node_output_edges_.size() |
|
|
|
<< "not equal to node_output_num_ size:" << node_output_num_.size(); |
|
|
|
std::map<AnfNodePtr, size_t> none_zero_nodes; |
|
|
|
if (node_input_edges_.size() != node_input_num_.size()) { |
|
|
|
MS_LOG(EXCEPTION) << "node_input_edges_ size :" << node_input_edges_.size() |
|
|
|
<< "not equal to node_input_num_ size:" << node_input_num_.size(); |
|
|
|
} |
|
|
|
for (auto &it : node_output_num_) { |
|
|
|
for (auto &it : node_input_num_) { |
|
|
|
MS_EXCEPTION_IF_NULL(it.first); |
|
|
|
string str; |
|
|
|
auto node_output_it = node_output_edges_.find(it.first); |
|
|
|
if (node_output_it == node_output_edges_.end()) { |
|
|
|
auto node_input_it = node_input_edges_.find(it.first); |
|
|
|
if (node_input_it == node_input_edges_.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Can't find node [" << it.first->DebugString() << "]"; |
|
|
|
} |
|
|
|
for (const auto &output_edge : node_output_edges_[it.first]) { |
|
|
|
MS_EXCEPTION_IF_NULL(output_edge.first); |
|
|
|
str = str.append(output_edge.first->DebugString()).append("|"); |
|
|
|
for (const auto &input_edge : node_input_edges_[it.first]) { |
|
|
|
MS_EXCEPTION_IF_NULL(input_edge.first); |
|
|
|
str = str.append(input_edge.first->DebugString()).append("|"); |
|
|
|
} |
|
|
|
if (it.second != 0) { |
|
|
|
MS_LOG(WARNING) << "Node:" << it.first->DebugString() << ",outputs:" << str << ",output num:" << it.second; |
|
|
|
none_zero_output[it.first] = it.second; |
|
|
|
MS_LOG(WARNING) << "Node:" << it.first->DebugString() << ",inputs:" << str << ",input num:" << it.second; |
|
|
|
none_zero_nodes[it.first] = it.second; |
|
|
|
} |
|
|
|
} |
|
|
|
// if don't consider control depend and loop exit,a exception will be throw |
|
|
|
if (!none_zero_output.empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_output.size(); |
|
|
|
if (!none_zero_nodes.empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -346,12 +383,13 @@ void KernelGraph::AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, |
|
|
|
} else { |
|
|
|
input_it->second.push_back(input_depend_edge); |
|
|
|
} |
|
|
|
// add the depend sum of node |
|
|
|
auto depend_it = node_output_num_.find(input); |
|
|
|
if (depend_it == node_output_num_.end()) { |
|
|
|
node_output_num_[input] = 0; |
|
|
|
// add node input depend num |
|
|
|
auto depend_it = node_input_num_.find(node); |
|
|
|
if (depend_it == node_input_num_.end()) { |
|
|
|
node_input_num_[node] = depend_edge_num; |
|
|
|
} else { |
|
|
|
depend_it->second += depend_edge_num; |
|
|
|
} |
|
|
|
node_output_num_[input] += depend_edge_num; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) { |
|
|
|
@@ -429,9 +467,9 @@ bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue<Anf |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
void KernelGraph::BfsToUpdateNodeOutput() { |
|
|
|
void KernelGraph::UpdateNodeEdgeList(std::stack<AnfNodePtr> *seed_nodes) { |
|
|
|
node_output_edges_.clear(); |
|
|
|
node_output_num_.clear(); |
|
|
|
node_input_num_.clear(); |
|
|
|
node_input_edges_.clear(); |
|
|
|
std::vector<AnfNodePtr> control_depends; |
|
|
|
std::unordered_set<AnfNodePtr> visited_nodes; |
|
|
|
@@ -441,6 +479,11 @@ void KernelGraph::BfsToUpdateNodeOutput() { |
|
|
|
auto node = que.front(); |
|
|
|
que.pop(); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
if (node->isa<Parameter>() || node->isa<ValueNode>()) { |
|
|
|
seed_nodes->push(node); |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
if (!node->isa<CNode>()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
@@ -454,10 +497,6 @@ void KernelGraph::BfsToUpdateNodeOutput() { |
|
|
|
control_depends.push_back(input); |
|
|
|
depend_edge_num = 0; |
|
|
|
} |
|
|
|
// the 2rd input of depend is no depend edge |
|
|
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) && input == cnode->input(kDependAttachNodeIndex)) { |
|
|
|
depend_edge_num = 0; |
|
|
|
} |
|
|
|
PushNoVisitedNode(input, &que, &visited_nodes); |
|
|
|
AddDependEdge(node, input, depend_edge_num); |
|
|
|
} |
|
|
|
|