@@ -91,6 +91,30 @@ kernel::KernelBuildInfoPtr CreateKernelBuildInfo() {
   builder.SetOutputsDeviceType({kNumberTypeFloat16, kNumberTypeInt32});
   return builder.Build();
 }
+
+bool CheckInputNamesSize(const CNodePtr &cnode) {
+  auto input_names_vec = AnfAlgo::GetNodeAttr<std::vector<std::string>>(cnode, kAttrInputNames);
+  if (input_names_vec.size() < kTopkIndexK + 1) {
+    MS_LOG(INFO) << "The input k of topk has been converted to attr";
+    return false;
+  }
+  return true;
+}
+
+bool CheckOutputShape(const AnfNodePtr &node) {
+  auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
+  if (shape.empty()) {
+    MS_LOG(INFO) << "The output shape of topk to split must not be empty";
+    return false;
+  }
+  auto last_dim = shape[shape.size() - 1];
+  const size_t kMaxFloat16 = 65500;
+  if (last_dim > kMaxFloat16) {
+    MS_LOG(INFO) << "The last dim is more than " << kMaxFloat16 << ", switch to aicpu ops.";
+    return false;
+  }
+  return true;
+}
 }  // namespace
 
 const BaseRef TopKSplit::DefinePattern() const {
@@ -107,16 +131,10 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod
   // set value node as topk's input
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
-  auto input_names_vec = AnfAlgo::GetNodeAttr<std::vector<std::string>>(cnode, kAttrInputNames);
-  if (input_names_vec.size() < kTopkIndexK + 1) {
-    MS_LOG(INFO) << "The input k of topk has been converted to attr";
+  if (!CheckInputNamesSize(cnode)) {
     return nullptr;
   }
-  auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
-  auto last_dim = shape[shape.size() - 1];
-  const size_t kMaxFloat16 = 65500;
-  if (last_dim > kMaxFloat16) {
-    MS_LOG(INFO) << "The last dim is more than 65500, switch to aicpu ops.";
+  if (!CheckOutputShape(cnode)) {
     return nullptr;
   }
   // Copy a new node to check supported.
@@ -253,6 +253,13 @@ bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) {
     if (it.second.communication_op_nodes.size() <= 1) {
      continue;
    }
+    auto first_node = it.second.communication_op_nodes[0];
+    if (AnfAlgo::HasNodeAttr(kAttrIndex, first_node) && AnfAlgo::GetNodeAttr<int>(first_node, kAttrIndex) > 0) {
+      std::stable_sort(it.second.communication_op_nodes.begin(), it.second.communication_op_nodes.end(),
+                       [](const CNodePtr &a, const CNodePtr &b) {
+                         return AnfAlgo::GetNodeAttr<int>(a, kAttrIndex) < AnfAlgo::GetNodeAttr<int>(b, kAttrIndex);
+                       });
+    }
     size_t segment_num = 0;
     std::vector<size_t> segment_index;
     if (GetSplitSegments(it.second, &segment_num, &segment_index, it.first)) {
@@ -209,6 +209,7 @@ constexpr auto kAttrRecordEvent = "record_event";
 constexpr auto kAttrWaitEvent = "wait_event";
 constexpr auto kAttrRecordEventStream = "record_event_stream";
 constexpr auto kAttrWaitEventStream = "wait_event_stream";
+constexpr auto kAttrIndex = "index";
 
 // attr value
 constexpr auto kValueTargetSwitch = "target_switch";
| @@ -58,7 +58,7 @@ TEST_F(TestHWAllReduceFusion, test_fusion_all) { | |||||
| builder.SetProcessor(kernel::Processor::AICORE); | builder.SetProcessor(kernel::Processor::AICORE); | ||||
| builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL); | builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL); | ||||
| auto node_list = TopoSort(func_graph->get_return()); | auto node_list = TopoSort(func_graph->get_return()); | ||||
| for (auto& node : node_list) { | |||||
| for (auto &node : node_list) { | |||||
| if (node == nullptr) { | if (node == nullptr) { | ||||
| continue; | continue; | ||||
| } | } | ||||
@@ -99,7 +99,7 @@ TEST_F(TestHWAllReduceFusion, test_fusion_group) {
   builder.SetProcessor(kernel::Processor::AICORE);
   builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL);
   auto node_list = TopoSort(func_graph->get_return());
-  for (auto& node : node_list) {
+  for (auto &node : node_list) {
     if (node == nullptr) {
       continue;
     }
@@ -141,7 +141,7 @@ TEST_F(TestHWAllReduceFusion, test_fusion_op) {
   builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL);
   auto node_list = TopoSort(func_graph->get_return());
   int count = 0;
-  for (auto& node : node_list) {
+  for (auto &node : node_list) {
     if (node == nullptr) {
       continue;
     }
@@ -171,5 +171,52 @@ TEST_F(TestHWAllReduceFusion, test_fusion_op) {
   EXPECT_NE(g_after, nullptr);
   EXPECT_TRUE(CheckEqualGraph(new_graph, g_after));
 }
+
+TEST_F(TestHWAllReduceFusion, test_fusion_sorted) {
+  getPyFun_.SetDoResolve(true);
+  FuncGraphPtr g = getPyFun_.CallAndParseRet("test_all_reduce_fusion_all", "before");
+  EXPECT_NE(g, nullptr);
+  std::vector<int> shp_x{1, 64, 112, 112};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
+  AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract, x_abstract, x_abstract};
+  auto func_graph = GetKernelGraph(g, args_spec_list);
+  EXPECT_NE(func_graph, nullptr);
+  auto ret = func_graph->get_return();
+  auto make_tuple = ret->input(1);
+  auto make_tuple1 = make_tuple->cast<CNodePtr>()->input(1)->cast<CNodePtr>();
+  for (size_t i = 1; i < make_tuple1->inputs().size(); ++i) {
+    AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue(SizeToInt(i)), make_tuple1->input(i));
+  }
+  // set kernel build info
+  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
+  builder.SetInputsFormat({"NC1HWC0"});
+  builder.SetOutputsFormat({"NC1HWC0"});
+  builder.SetInputsDeviceType({kFloat32->type_id()});
+  builder.SetOutputsDeviceType({kFloat32->type_id()});
+  builder.SetFusionType(kernel::FusionType::ELEMWISE);
+  builder.SetProcessor(kernel::Processor::AICORE);
+  builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL);
+  auto node_list = TopoSort(func_graph->get_return());
+  for (auto &node : node_list) {
+    if (node == nullptr) {
+      continue;
+    }
+    if ((node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kAllReduceOpName) || node->isa<Parameter>()) {
+      node->set_kernel_info(std::make_shared<device::KernelInfo>());
+      AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), node.get());
+    }
+  }
+  // do all reduce fusion
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::AllReduceFusion>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
+  EXPECT_NE(new_graph, nullptr);
+  // check result
+  FuncGraphPtr g_after = getPyFun_.CallAndParseRet("test_all_reduce_fusion_all", "after1");
+  EXPECT_NE(g_after, nullptr);
+  EXPECT_TRUE(CheckEqualGraph(new_graph, g_after));
+}
 }  // namespace opt
 }  // namespace mindspore
@@ -140,6 +140,17 @@ def test_all_reduce_fusion_all(tag):
         res = make_tuple(y1, y2, y3, y4, y5)
         return make_tuple(res)
 
+    @fns
+    def after1(x1, x2, x3, x4, x5):
+        ar = allreduce(x1, x2, x3, x4, x5)
+        y1 = tuple_getitem(ar, 0)
+        y2 = tuple_getitem(ar, 1)
+        y3 = tuple_getitem(ar, 2)
+        y4 = tuple_getitem(ar, 3)
+        y5 = tuple_getitem(ar, 4)
+        res = make_tuple(y1, y2, y3, y4, y5)
+        return make_tuple(res)
+
     return fns[tag]