@@ -91,6 +91,30 @@ kernel::KernelBuildInfoPtr CreateKernelBuildInfo() {
   builder.SetOutputsDeviceType({kNumberTypeFloat16, kNumberTypeInt32});
   return builder.Build();
 }
+
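+// Returns false when topk's k input has already been converted to an attribute, so the pass skips splitting.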
+bool CheckInputNamesSize(const CNodePtr &cnode) {
+  auto input_names_vec = AnfAlgo::GetNodeAttr<std::vector<std::string>>(cnode, kAttrInputNames);
+  if (input_names_vec.size() < kTopkIndexK + 1) {
+    MS_LOG(INFO) << "The input k of topk has been converted to attr";
+    return false;
+  }
+  return true;
+}
+
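+// Returns false when the previous node's output shape is empty or its last dim exceeds kMaxFloat16 (65500);
+// oversized last dims are handled by aicpu ops instead.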
+bool CheckOutputShape(const AnfNodePtr &node) {
+  auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
+  if (shape.empty()) {
+    MS_LOG(INFO) << "The output shape of topk to split must not be empty";
+    return false;
+  }
+  auto last_dim = shape[shape.size() - 1];
+  const size_t kMaxFloat16 = 65500;
+  if (last_dim > kMaxFloat16) {
+    MS_LOG(INFO) << "The last dim is more than " << kMaxFloat16 << ", switch to aicpu ops.";
+    return false;
+  }
+  return true;
+}
 
 }  // namespace
 const BaseRef TopKSplit::DefinePattern() const {
@@ -107,16 +131,10 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod
   // set value node as topk's input
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
-  auto input_names_vec = AnfAlgo::GetNodeAttr<std::vector<std::string>>(cnode, kAttrInputNames);
-  if (input_names_vec.size() < kTopkIndexK + 1) {
-    MS_LOG(INFO) << "The input k of topk has been converted to attr";
+  if (!CheckInputNamesSize(cnode)) {
     return nullptr;
   }
-  auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
-  auto last_dim = shape[shape.size() - 1];
-  const size_t kMaxFloat16 = 65500;
-  if (last_dim > kMaxFloat16) {
-    MS_LOG(INFO) << "The last dim is more than 65500, switch to aicpu ops.";
+  if (!CheckOutputShape(cnode)) {
     return nullptr;
   }
   // Copy a new node to check supported.
@@ -253,6 +253,13 @@ bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) {
     if (it.second.communication_op_nodes.size() <= 1) {
       continue;
     }
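+    // If the communication ops carry an index attr, sort them so fusion follows that order.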
+    auto first_node = it.second.communication_op_nodes[0];
+    if (AnfAlgo::HasNodeAttr(kAttrIndex, first_node) && AnfAlgo::GetNodeAttr<int>(first_node, kAttrIndex) > 0) {
+      std::stable_sort(it.second.communication_op_nodes.begin(), it.second.communication_op_nodes.end(),
+                       [](const CNodePtr &a, const CNodePtr &b) {
+                         return AnfAlgo::GetNodeAttr<int>(a, kAttrIndex) < AnfAlgo::GetNodeAttr<int>(b, kAttrIndex);
+                       });
+    }
     size_t segment_num = 0;
     std::vector<size_t> segment_index;
     if (GetSplitSegments(it.second, &segment_num, &segment_index, it.first)) {
@@ -209,6 +209,7 @@ constexpr auto kAttrRecordEvent = "record_event";
 constexpr auto kAttrWaitEvent = "wait_event";
 constexpr auto kAttrRecordEventStream = "record_event_stream";
 constexpr auto kAttrWaitEventStream = "wait_event_stream";
+constexpr auto kAttrIndex = "index";
 // attr value
 constexpr auto kValueTargetSwitch = "target_switch";
@@ -58,7 +58,7 @@ TEST_F(TestHWAllReduceFusion, test_fusion_all) {
   builder.SetProcessor(kernel::Processor::AICORE);
   builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL);
   auto node_list = TopoSort(func_graph->get_return());
-  for (auto& node : node_list) {
+  for (auto &node : node_list) {
     if (node == nullptr) {
       continue;
     }
@@ -99,7 +99,7 @@ TEST_F(TestHWAllReduceFusion, test_fusion_group) {
   builder.SetProcessor(kernel::Processor::AICORE);
   builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL);
   auto node_list = TopoSort(func_graph->get_return());
-  for (auto& node : node_list) {
+  for (auto &node : node_list) {
     if (node == nullptr) {
       continue;
     }
@@ -141,7 +141,7 @@ TEST_F(TestHWAllReduceFusion, test_fusion_op) {
   builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL);
   auto node_list = TopoSort(func_graph->get_return());
   int count = 0;
-  for (auto& node : node_list) {
+  for (auto &node : node_list) {
     if (node == nullptr) {
       continue;
     }
@@ -171,5 +171,52 @@ TEST_F(TestHWAllReduceFusion, test_fusion_op) {
   EXPECT_NE(g_after, nullptr);
   EXPECT_TRUE(CheckEqualGraph(new_graph, g_after));
 }
+
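+// Checks AllReduce fusion when each AllReduce node is tagged with an index attr (kAttrIndex).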
+TEST_F(TestHWAllReduceFusion, test_fusion_sorted) {
+  getPyFun_.SetDoResolve(true);
+  FuncGraphPtr g = getPyFun_.CallAndParseRet("test_all_reduce_fusion_all", "before");
+  EXPECT_NE(g, nullptr);
+  std::vector<int> shp_x{1, 64, 112, 112};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
+  AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract, x_abstract, x_abstract};
+  auto func_graph = GetKernelGraph(g, args_spec_list);
+  EXPECT_NE(func_graph, nullptr);
+  auto ret = func_graph->get_return();
+  auto make_tuple = ret->input(1);
+  auto make_tuple1 = make_tuple->cast<CNodePtr>()->input(1)->cast<CNodePtr>();
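+  // Attach an index attr (1..5) to each AllReduce node so the fusion pass can sort by it.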
+  for (size_t i = 1; i < make_tuple1->inputs().size(); ++i) {
+    AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue(SizeToInt(i)), make_tuple1->input(i));
+  }
+  // set kernel build info
+  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
+  builder.SetInputsFormat({"NC1HWC0"});
+  builder.SetOutputsFormat({"NC1HWC0"});
+  builder.SetInputsDeviceType({kFloat32->type_id()});
+  builder.SetOutputsDeviceType({kFloat32->type_id()});
+  builder.SetFusionType(kernel::FusionType::ELEMWISE);
+  builder.SetProcessor(kernel::Processor::AICORE);
+  builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL);
+  auto node_list = TopoSort(func_graph->get_return());
+  for (auto &node : node_list) {
+    if (node == nullptr) {
+      continue;
+    }
+    if ((node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kAllReduceOpName) || node->isa<Parameter>()) {
+      node->set_kernel_info(std::make_shared<device::KernelInfo>());
+      AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), node.get());
+    }
+  }
+  // do all reduce fusion
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::AllReduceFusion>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
+  EXPECT_NE(new_graph, nullptr);
+  // check result
+  FuncGraphPtr g_after = getPyFun_.CallAndParseRet("test_all_reduce_fusion_all", "after1");
+  EXPECT_NE(g_after, nullptr);
+  EXPECT_TRUE(CheckEqualGraph(new_graph, g_after));
+}
 }  // namespace opt
 }  // namespace mindspore
@@ -140,6 +140,17 @@ def test_all_reduce_fusion_all(tag):
         res = make_tuple(y1, y2, y3, y4, y5)
         return make_tuple(res)
+
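+    # expected graph for test_fusion_sorted: the five AllReduce ops fused into a single op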
+    @fns
+    def after1(x1, x2, x3, x4, x5):
+        ar = allreduce(x1, x2, x3, x4, x5)
+        y1 = tuple_getitem(ar, 0)
+        y2 = tuple_getitem(ar, 1)
+        y3 = tuple_getitem(ar, 2)
+        y4 = tuple_getitem(ar, 3)
+        y5 = tuple_getitem(ar, 4)
+        res = make_tuple(y1, y2, y3, y4, y5)
+        return make_tuple(res)
+
     return fns[tag]