From f15cb6b7c9ad75224c902d3255d687bde1185d00 Mon Sep 17 00:00:00 2001 From: yujianfeng Date: Tue, 16 Jun 2020 19:47:30 +0800 Subject: [PATCH] Add sort by index for each group of AllReduce --- .../ascend/ir_fission/topk_split.cc | 34 +++++++++--- .../pass/communication_op_fusion.cc | 7 +++ mindspore/ccsrc/utils/utils.h | 1 + .../pass/allreduce_fusion_test.cc | 53 +++++++++++++++++-- .../pre_activate/ir_fusion_test.py | 11 ++++ 5 files changed, 95 insertions(+), 11 deletions(-) diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc index 95bcb9f210..1cace41fc4 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc @@ -91,6 +91,30 @@ kernel::KernelBuildInfoPtr CreateKernelBuildInfo() { builder.SetOutputsDeviceType({kNumberTypeFloat16, kNumberTypeInt32}); return builder.Build(); } + +bool CheckInputNamesSize(const CNodePtr &cnode) { + auto input_names_vec = AnfAlgo::GetNodeAttr>(cnode, kAttrInputNames); + if (input_names_vec.size() < kTopkIndexK + 1) { + MS_LOG(INFO) << "The input k of topk has been converted to attr"; + return false; + } + return true; +} + +bool CheckOutputShape(const AnfNodePtr &node) { + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0); + if (shape.empty()) { + MS_LOG(INFO) << "The output shape of topk to split must not be empty"; + return false; + } + auto last_dim = shape[shape.size() - 1]; + const size_t kMaxFloat16 = 65500; + if (last_dim > kMaxFloat16) { + MS_LOG(INFO) << "The last dim is more than " << kMaxFloat16 << ", switch to aicpu ops."; + return false; + } + return true; +} } // namespace const BaseRef TopKSplit::DefinePattern() const { @@ -107,16 +131,10 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod // set value node as topk's input auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); - auto input_names_vec = AnfAlgo::GetNodeAttr>(cnode, kAttrInputNames); - if (input_names_vec.size() < kTopkIndexK + 1) { - MS_LOG(INFO) << "The input k of topk has been converted to attr"; + if (!CheckInputNamesSize(cnode)) { return nullptr; } - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0); - auto last_dim = shape[shape.size() - 1]; - const size_t kMaxFloat16 = 65500; - if (last_dim > kMaxFloat16) { - MS_LOG(INFO) << "The last dim is more than 65500, switch to aicpu ops."; + if (!CheckOutputShape(cnode)) { return nullptr; } // Copy a new node to check supported. diff --git a/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.cc b/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.cc index fc878dd881..aa4690abcb 100644 --- a/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.cc +++ b/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.cc @@ -253,6 +253,13 @@ bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) { if (it.second.communication_op_nodes.size() <= 1) { continue; } + auto first_node = it.second.communication_op_nodes[0]; + if (AnfAlgo::HasNodeAttr(kAttrIndex, first_node) && AnfAlgo::GetNodeAttr(first_node, kAttrIndex) > 0) { + std::stable_sort(it.second.communication_op_nodes.begin(), it.second.communication_op_nodes.end(), + [](const CNodePtr &a, const CNodePtr &b) { + return AnfAlgo::GetNodeAttr(a, kAttrIndex) < AnfAlgo::GetNodeAttr(b, kAttrIndex); + }); + } size_t segment_num = 0; std::vector segment_index; if (GetSplitSegments(it.second, &segment_num, &segment_index, it.first)) { diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index f337dca329..0ccf3ba524 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -209,6 +209,7 @@ constexpr auto kAttrRecordEvent = "record_event"; constexpr auto kAttrWaitEvent = "wait_event"; constexpr auto kAttrRecordEventStream = "record_event_stream"; constexpr auto kAttrWaitEventStream = "wait_event_stream"; +constexpr auto kAttrIndex = "index"; // attr value constexpr auto kValueTargetSwitch = "target_switch"; diff --git a/tests/ut/cpp/pre_activate/pass/allreduce_fusion_test.cc b/tests/ut/cpp/pre_activate/pass/allreduce_fusion_test.cc index 3208e0b48e..077a9f0723 100644 --- a/tests/ut/cpp/pre_activate/pass/allreduce_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/pass/allreduce_fusion_test.cc @@ -58,7 +58,7 @@ TEST_F(TestHWAllReduceFusion, test_fusion_all) { builder.SetProcessor(kernel::Processor::AICORE); builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL); auto node_list = TopoSort(func_graph->get_return()); - for (auto& node : node_list) { + for (auto &node : node_list) { if (node == nullptr) { continue; } @@ -99,7 +99,7 @@ TEST_F(TestHWAllReduceFusion, test_fusion_group) { builder.SetProcessor(kernel::Processor::AICORE); builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL); auto node_list = TopoSort(func_graph->get_return()); - for (auto& node : node_list) { + for (auto &node : node_list) { if (node == nullptr) { continue; } @@ -141,7 +141,7 @@ TEST_F(TestHWAllReduceFusion, test_fusion_op) { builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL); auto node_list = TopoSort(func_graph->get_return()); int count = 0; - for (auto& node : node_list) { + for (auto &node : node_list) { if (node == nullptr) { continue; } @@ -171,5 +171,52 @@ TEST_F(TestHWAllReduceFusion, test_fusion_op) { EXPECT_NE(g_after, nullptr); EXPECT_TRUE(CheckEqualGraph(new_graph, g_after)); } + +TEST_F(TestHWAllReduceFusion, test_fusion_sorted) { + getPyFun_.SetDoResolve(true); + FuncGraphPtr g = getPyFun_.CallAndParseRet("test_all_reduce_fusion_all", "before"); + EXPECT_NE(g, nullptr); + std::vector shp_x{1, 64, 112, 112}; + auto x_abstract = std::make_shared(kFloat32, shp_x); + AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract, x_abstract, x_abstract}; + auto func_graph = GetKernelGraph(g, args_spec_list); + EXPECT_NE(func_graph, nullptr); + auto ret = func_graph->get_return(); + auto make_tuple = ret->input(1); + auto make_tuple1 = make_tuple->cast()->input(1)->cast(); + for (size_t i = 1; i < make_tuple1->inputs().size(); ++i) { + AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue(SizeToInt(i)), make_tuple1->input(i)); + } + // set kernel build info + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + builder.SetInputsFormat({"NC1HWC0"}); + builder.SetOutputsFormat({"NC1HWC0"}); + builder.SetInputsDeviceType({kFloat32->type_id()}); + builder.SetOutputsDeviceType({kFloat32->type_id()}); + builder.SetFusionType(kernel::FusionType::ELEMWISE); + builder.SetProcessor(kernel::Processor::AICORE); + builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL); + auto node_list = TopoSort(func_graph->get_return()); + for (auto &node : node_list) { + if (node == nullptr) { + continue; + } + if ((node->isa() && AnfAlgo::GetCNodeName(node) == kAllReduceOpName) || node->isa()) { + node->set_kernel_info(std::make_shared()); + AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), node.get()); + } + } + // do all reduce fusion + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(func_graph); + EXPECT_NE(new_graph, nullptr); + // check result + FuncGraphPtr g_after = getPyFun_.CallAndParseRet("test_all_reduce_fusion_all", "after1"); + EXPECT_NE(g_after, nullptr); + EXPECT_TRUE(CheckEqualGraph(new_graph, g_after)); +} } // namespace opt } // namespace mindspore diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/ir_fusion_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/ir_fusion_test.py index 3c59dd0d00..195402c92b 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/ir_fusion_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/ir_fusion_test.py @@ -140,6 +140,17 @@ def test_all_reduce_fusion_all(tag): res = make_tuple(y1, y2, y3, y4, y5) return make_tuple(res) + @fns + def after1(x1, x2, x3, x4, x5): + ar = allreduce(x1, x2, x3, x4, x5) + y1 = tuple_getitem(ar, 0) + y2 = tuple_getitem(ar, 1) + y3 = tuple_getitem(ar, 2) + y4 = tuple_getitem(ar, 3) + y5 = tuple_getitem(ar, 4) + res = make_tuple(y1, y2, y3, y4, y5) + return make_tuple(res) + return fns[tag]