| @@ -16,6 +16,7 @@ | |||||
| #include "backend/optimizer/pass/communication_op_fusion.h" | #include "backend/optimizer/pass/communication_op_fusion.h" | ||||
| #include <vector> | #include <vector> | ||||
| #include <set> | |||||
| #include <memory> | #include <memory> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| @@ -89,6 +90,13 @@ std::string GetFusionGroupKey(const AnfNodePtr &node) { | |||||
| } | } | ||||
| return group + op + std::to_string(fusion); | return group + op + std::to_string(fusion); | ||||
| } | } | ||||
| void CheckInputs(const std::vector<AnfNodePtr> &fusion_inputs) { | |||||
| std::set<AnfNodePtr> inputs_set(fusion_inputs.begin(), fusion_inputs.end()); | |||||
| if (inputs_set.size() < fusion_inputs.size()) { | |||||
| MS_LOG(EXCEPTION) << "Different communication op in one segment cannot share the same input"; | |||||
| } | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num, | bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num, | ||||
| @@ -163,6 +171,7 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr | |||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| fusion_inputs.insert(fusion_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); | fusion_inputs.insert(fusion_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); | ||||
| } | } | ||||
| CheckInputs(fusion_inputs); | |||||
| AnfNodePtr fused_node = func_graph->NewCNode(fusion_inputs); | AnfNodePtr fused_node = func_graph->NewCNode(fusion_inputs); | ||||
| MS_EXCEPTION_IF_NULL(fused_node); | MS_EXCEPTION_IF_NULL(fused_node); | ||||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | auto kernel_info = std::make_shared<device::KernelInfo>(); | ||||
| @@ -172,9 +181,6 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr | |||||
| for (size_t idx = start_index; idx <= end_index; ++idx) { | for (size_t idx = start_index; idx <= end_index; ++idx) { | ||||
| auto cnode = communication_op_info.communication_op_nodes[idx]; | auto cnode = communication_op_info.communication_op_nodes[idx]; | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| AnfAlgo::CopyNodeAttr("fusion", cnode, fused_node); | |||||
| AnfAlgo::CopyNodeAttr("op", cnode, fused_node); | |||||
| AnfAlgo::CopyNodeAttr("group", cnode, fused_node); | |||||
| abstract_list.push_back(cnode->abstract()); | abstract_list.push_back(cnode->abstract()); | ||||
| } | } | ||||
| auto kernel_build_info = GenerateKernelBuildInfo(communication_op_info, start_index, end_index); | auto kernel_build_info = GenerateKernelBuildInfo(communication_op_info, start_index, end_index); | ||||
| @@ -182,6 +188,9 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr | |||||
| auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list); | auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list); | ||||
| MS_EXCEPTION_IF_NULL(abstract_tuple); | MS_EXCEPTION_IF_NULL(abstract_tuple); | ||||
| fused_node->set_abstract(abstract_tuple); | fused_node->set_abstract(abstract_tuple); | ||||
| AnfAlgo::CopyNodeAttr("fusion", communication_op_info.communication_op_nodes[end_index], fused_node); | |||||
| AnfAlgo::CopyNodeAttr("op", communication_op_info.communication_op_nodes[end_index], fused_node); | |||||
| AnfAlgo::CopyNodeAttr("group", communication_op_info.communication_op_nodes[end_index], fused_node); | |||||
| return fused_node; | return fused_node; | ||||
| } | } | ||||