diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean_gpu.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean_gpu.cc index ecc590f557..f753e8fba7 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean_gpu.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean_gpu.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include "base/core_ops.h" #include "ir/tensor.h" @@ -304,27 +305,52 @@ void AtomicCleanInsertter::AddDepend(const FuncGraphPtr &main_graph, const AnfNo user_cnode->set_input(index, depend_cnode); } -void AtomicCleanInsertter::AddControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &pre_node, - const AnfNodePtr &post_node, const FuncGraphManagerPtr &mng) { - // Collect use dependencies firstly. - auto post_users = mng->node_users()[post_node]; - +AnfNodePtr AtomicCleanInsertter::AddControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &prior_node, + const AnfNodePtr &behind_node, const AnfNodePtr &patron_node) { // Create control depend, first input is composite op, second is user - AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimControlDepend), pre_node, post_node}; + AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimControlDepend), prior_node, behind_node}; auto control_depend_cnode = main_graph->NewCNode(cd_inputs); main_graph->AddNode(control_depend_cnode); // Create depend node to hold new control depend node. - AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), post_node, control_depend_cnode}; + AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), patron_node, control_depend_cnode}; auto depend_cnode = main_graph->NewCNode(d_inputs); - depend_cnode->set_abstract(post_node->abstract()); + depend_cnode->set_abstract(patron_node->abstract()); main_graph->AddNode(depend_cnode); - for (const auto &[user_node, index] : post_users) { - auto user_cnode = user_node->cast(); - MS_EXCEPTION_IF_NULL(user_cnode); - user_cnode->set_input(index, depend_cnode); + return depend_cnode; +} + +std::tuple AtomicCleanInsertter::FindPatronNode(const KernelGraphPtr &main_graph) { + auto mng = main_graph->manager(); + if (mng == nullptr) { + mng = Manage(main_graph, true); + main_graph->set_manager(mng); } + + AnfNodePtr patron_node; + + auto return_cnode = main_graph->get_return()->cast(); + MS_EXCEPTION_IF_NULL(return_cnode); + auto output_node = return_cnode->input(kFirstDataInputIndex); + if (IsPrimitiveCNode(output_node, prim::kPrimMakeTuple)) { + auto output_cnode = output_node->cast(); + MS_EXCEPTION_IF_NULL(output_cnode); + patron_node = output_cnode->input(kFirstDataInputIndex); + } else { + patron_node = output_node; + } + + auto &user_nodes = mng->node_users()[patron_node]; + auto user = user_nodes.begin(); + return std::make_tuple(patron_node, user->first, user->second); +} + +void AtomicCleanInsertter::PostprocessForLastPatron(const AnfNodePtr &patron_node, const AnfNodePtr &patron_user, + int index) { + auto patron_user_cnode = patron_user->cast(); + MS_EXCEPTION_IF_NULL(patron_user_cnode); + patron_user_cnode->set_input(index, patron_node); } CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphPtr &main_graph, TypeId dst_type) { @@ -380,14 +406,14 @@ CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphP kernel::Processor::CUDA); auto graph_attr = ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "atomic_clean"); new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr)); - // mng->AddFuncGraph(new_sub_graph); return broadcast_to_composite_node; } -void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node, - const AnfNodePtr &broadcast_to_node, const FuncGraphManagerPtr &mng) { - // 1. find users, change getitem index if needed. +std::vector > AtomicCleanInsertter::FindOriginCNodeUsers(const KernelGraphPtr &main_graph, + const AnfNodePtr &composite_node, + const FuncGraphManagerPtr &mng, + bool correct_index) { std::vector > reduce_user_nodes; if (real_output_num_ <= 1) { auto users = mng->node_users()[composite_node]; @@ -409,7 +435,7 @@ void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_gra auto item_idx = GetValue(value_node->value()); if (item_idx == static_cast(reduce_real_output_index_)) { getitem_user_nodes.push_back(node_index); - } else { + } else if (correct_index) { if (real_output_num_ > 2) { // Recorrect other getitem index. int64_t new_item_idx = CalNewIndex(item_idx, reduce_real_output_index_); @@ -431,7 +457,6 @@ void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_gra } } } - for (auto &pair : getitem_user_nodes) { // dirctory to find real user. auto real_users = mng->node_users()[pair.first]; @@ -439,12 +464,16 @@ void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_gra } } + return reduce_user_nodes; +} + +void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node, + const AnfNodePtr &broadcast_to_node, const FuncGraphManagerPtr &mng) { + // 1. find users, change getitem index if needed. + std::vector > reduce_user_nodes = + FindOriginCNodeUsers(main_graph, composite_node, mng, true); for (const auto &[user_node, index] : reduce_user_nodes) { // 2. set ac output as user's input. - auto user_cnode = user_node->cast(); - MS_EXCEPTION_IF_NULL(user_cnode); - user_cnode->set_input(index, broadcast_to_node); - // mng->SetEdge(user_node, index, broadcast_to_node); // 3. Make sure modified composite node running first. // * To not change the origin node's dependency relation, add ControlDepend and Depend node. // * For Return node and output node, ControlDepend node will change the order of these two node, which will may @@ -452,7 +481,10 @@ void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_gra if (IsPrimitiveCNode(user_node, prim::kPrimReturn) || user_node == main_graph->output()) { AddDepend(main_graph, broadcast_to_node, composite_node, user_node, index); } else { - AddControlDepend(main_graph, composite_node, user_node, mng); + auto user_cnode = user_node->cast(); + MS_EXCEPTION_IF_NULL(user_cnode); + user_cnode->set_input(index, broadcast_to_node); + to_process_order_.emplace_back(composite_node, user_node); } } } @@ -473,6 +505,26 @@ void AtomicCleanInsertter::InsertAtomicClean(const KernelGraphPtr &main_graph, c // Replace origin ReduceSum's user with atomic clean output, and add control depend from composite op to user. ProcessOriginCNodeUser(main_graph, origin_composite_node, broadcast_to_node, mng); + MS_LOG(INFO) << "Target node: " << origin_composite_node->fullname_with_scope() + << ", clean node: " << broadcast_to_node->fullname_with_scope(); +} + +bool AtomicCleanInsertter::IsExistStructuralObstacle(const KernelGraphPtr &main_graph, const AnfNodePtr &node, + const FuncGraphManagerPtr &mng) { + auto reduce_users = FindOriginCNodeUsers(main_graph, node, mng, false); + // If reduce user is MakeTuple and not last node, there is no cheap method to set right running order between reduce + // node and user node. If reduce is Depend or ControlDepend node, the origin node may be wrong! + return std::all_of(reduce_users.cbegin(), reduce_users.cend(), + [&main_graph](const std::pair &user_info) -> bool { + auto &user = user_info.first; + if ((IsPrimitiveCNode(user, prim::kPrimMakeTuple) || IsPrimitiveCNode(user, prim::kPrimDepend) || + IsPrimitiveCNode(user, prim::kPrimControlDepend)) && + !(IsPrimitiveCNode(user, prim::kPrimReturn) || user == main_graph->output())) { + return false; + } else { + return true; + } + }); } bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) { @@ -487,7 +539,8 @@ bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) { bool changed = false; auto topo_nodes = TopoSort(kernel_graph->get_return()); for (const auto &node : topo_nodes) { - if (!AnfAlgo::IsGraphKernel(node) || !CanActivateAtomicAdd(node)) { + if (!AnfAlgo::IsGraphKernel(node) || !CanActivateAtomicAdd(node) || + !IsExistStructuralObstacle(kernel_graph, node, mng)) { continue; } InsertAtomicClean(kernel_graph, node, mng); @@ -495,6 +548,14 @@ bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) { } if (changed) { + if (!to_process_order_.empty()) { + auto [patron_node, patron_user, user_index] = FindPatronNode(kernel_graph); + for (const auto &[prior, behind] : to_process_order_) { + patron_node = AddControlDepend(kernel_graph, prior, behind, patron_node); + } + PostprocessForLastPatron(patron_node, patron_user, user_index); + } + mng->RemoveRoots(); mng->KeepRoots({func_graph}); } diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean_gpu.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean_gpu.h index 731505e7db..34f0ea11fa 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean_gpu.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean_gpu.h @@ -18,6 +18,8 @@ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_ATOMIC_CLEAN_GPU_H_ #include +#include +#include #include #include "backend/optimizer/common/optimizer.h" #include "backend/session/kernel_graph.h" @@ -37,18 +39,26 @@ class AtomicCleanInsertter : public Pass { void InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node, const FuncGraphManagerPtr &mng); void AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, const AnfNodePtr &composite_node, const AnfNodePtr &user_node, int index); - void AddControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &pre_node, const AnfNodePtr &post_node, - const FuncGraphManagerPtr &mng); void CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter); void CorrectAbstract(const AnfNodePtr &composite_node); void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input); CNodePtr CreateAtomicCleanCompositeNode(const KernelGraphPtr &main_graph, TypeId dst_type); void ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node, const AnfNodePtr &broadcast_to_node, const FuncGraphManagerPtr &mng); + std::tuple FindPatronNode(const KernelGraphPtr &main_graph); + AnfNodePtr AddControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &prior_node, + const AnfNodePtr &behind_node, const AnfNodePtr &patron_node); + void PostprocessForLastPatron(const AnfNodePtr &patron_node, const AnfNodePtr &patron_user, int index); + std::vector> FindOriginCNodeUsers(const KernelGraphPtr &main_graph, + const AnfNodePtr &composite_node, + const FuncGraphManagerPtr &mng, bool correct_index); + bool IsExistStructuralObstacle(const KernelGraphPtr &main_graph, const AnfNodePtr &node, + const FuncGraphManagerPtr &mng); CNodePtr atomic_add_node_{nullptr}; size_t reduce_real_output_index_{0}; size_t real_output_num_{0}; + std::vector> to_process_order_; }; using AtomicCleanInsertterPtr = std::shared_ptr; } // namespace opt