|
|
|
@@ -135,23 +135,6 @@ std::string GetNodeGroup(const AnfNodePtr &node) { |
|
|
|
} |
|
|
|
return ""; |
|
|
|
} |
|
|
|
|
|
|
|
bool NeedOptimizeCommOp(const AnfNodePtr &node, std::map<std::string, std::string> *optimized_comm_group) { |
|
|
|
MS_EXCEPTION_IF_NULL(optimized_comm_group); |
|
|
|
auto node_group = GetNodeGroup(node); |
|
|
|
if (node_group.find(kSyncBnGroup) != string::npos) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto node_name = AnfAlgo::GetCNodeName(node); |
|
|
|
auto iter = optimized_comm_group->find(node_name); |
|
|
|
if (iter == optimized_comm_group->end()) { |
|
|
|
(*optimized_comm_group)[node_name] = node_group; |
|
|
|
return true; |
|
|
|
} else if (iter->second == node_group) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
AnfNodePtr KernelGraph::MakeValueNode(const AnfNodePtr &node) const { |
|
|
|
@@ -188,7 +171,6 @@ void KernelGraph::EnqueueActiveNodes(const AnfNodePtr &node, std::queue<AnfNodeP |
|
|
|
} |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
// visit all reduce node first, then other nodes |
|
|
|
std::vector<AnfNodePtr> active_nodes; |
|
|
|
for (const auto &output_edge : it->second) { |
|
|
|
@@ -209,7 +191,9 @@ void KernelGraph::EnqueueActiveNodes(const AnfNodePtr &node, std::queue<AnfNodeP |
|
|
|
if (node_input_num_[next_node] == 0 && visited_nodes->find(next_node) == visited_nodes->end()) { |
|
|
|
(void)visited_nodes->insert(next_node); |
|
|
|
bool is_comm_node = AnfAlgo::IsCommunicationOp(next_node); |
|
|
|
if ((is_comm_node && comm_first) || (!is_comm_node && !comm_first)) { |
|
|
|
if (AnfAlgo::CheckPrimitiveType(next_node, prim::kPrimLoad)) { |
|
|
|
EnqueueActiveNodes(next_node, visit_queue, visited_nodes); |
|
|
|
} else if ((is_comm_node && comm_first) || (!is_comm_node && !comm_first)) { |
|
|
|
MS_LOG(DEBUG) << "Visit node:" << next_node->DebugString(); |
|
|
|
visit_queue->push(next_node); |
|
|
|
} else { |
|
|
|
@@ -217,10 +201,7 @@ void KernelGraph::EnqueueActiveNodes(const AnfNodePtr &node, std::queue<AnfNodeP |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
for (auto &active_node : active_nodes) { |
|
|
|
MS_EXCEPTION_IF_NULL(active_node); |
|
|
|
MS_LOG(DEBUG) << "Visit node:" << active_node->DebugString(); |
|
|
|
visit_queue->push(active_node); |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -233,7 +214,7 @@ void KernelGraph::SetExecOrderByDefault() { |
|
|
|
std::queue<AnfNodePtr> zero_input_nodes; |
|
|
|
std::stack<AnfNodePtr> delay_comm_stack; |
|
|
|
std::queue<AnfNodePtr> communication_descendants; |
|
|
|
std::map<std::string, std::string> optimized_comm_group; |
|
|
|
std::string optimized_comm_group; |
|
|
|
while (!seed_nodes.empty() || !delay_comm_stack.empty()) { |
|
|
|
// seed nodes first, then delay comm nodes |
|
|
|
if (seed_nodes.empty()) { |
|
|
|
@@ -262,9 +243,13 @@ void KernelGraph::SetExecOrderByDefault() { |
|
|
|
} |
|
|
|
// delay execute comm ops that need optimize |
|
|
|
bool is_fused_comm = AnfAlgo::IsFusedCommunicationOp(node); |
|
|
|
bool optimize_comm = is_fused_comm; |
|
|
|
if (optimize_comm) { |
|
|
|
optimize_comm = NeedOptimizeCommOp(node, &optimized_comm_group); |
|
|
|
bool optimize_comm = false; |
|
|
|
if (is_fused_comm && optimized_comm_group.empty()) { |
|
|
|
auto node_group = GetNodeGroup(node); |
|
|
|
if (node_group.find(kSyncBnGroup) == string::npos) { |
|
|
|
optimized_comm_group = node_group; |
|
|
|
optimize_comm = true; |
|
|
|
} |
|
|
|
} |
|
|
|
if (optimize_comm) { |
|
|
|
while (!delay_comm_stack.empty()) { |
|
|
|
|