From c4c69bf5e84360a52293b826637f68a04f987c1c Mon Sep 17 00:00:00 2001 From: r1chardf1d0 Date: Thu, 29 Apr 2021 18:20:40 +0800 Subject: [PATCH] [GraphKernel] clean code for graph_kernel_splitter* & add_stitch_atomic_clean_gpu* --- .../add_stitch_atomic_clean_gpu.cc | 14 ++-- .../add_stitch_atomic_clean_gpu.h | 11 +-- .../graph_kernel/graph_kernel_splitter.cc | 72 ++++++++++--------- .../graph_kernel/graph_kernel_splitter.h | 2 +- 4 files changed, 52 insertions(+), 47 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.cc index d5cc0c000a..9eaf0878e9 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.cc @@ -75,8 +75,8 @@ void StitchAtomicCleanInsertter::CorrectKernelBuildInfo(const AnfNodePtr &compos AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get()); } -CNodePtr StitchAtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, - const AnfNodePtr &new_parameter) { +CNodePtr StitchAtomicCleanInsertter::CreateInplaceAssignNode(const FuncGraphPtr &sub_graph, + const AnfNodePtr &new_parameter) { // add inplaceassign AnfNodePtr out_node = atomic_add_node_; // Use result data itself, and set attr "fake_out" true. auto inplace_assign_node = @@ -88,8 +88,8 @@ CNodePtr StitchAtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn(con return inplace_assign_node; } -void StitchAtomicCleanInsertter::ProcessOriginCNode(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node, - const AnfNodePtr &new_input, const FuncGraphManagerPtr &mng) { +void StitchAtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input, + const FuncGraphManagerPtr &mng) { auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(composite_node); auto mng_sub = sub_graph->manager(); if (mng_sub == nullptr) { @@ -107,7 +107,7 @@ void StitchAtomicCleanInsertter::ProcessOriginCNode(const KernelGraphPtr &main_g parameter->set_abstract(new_input->abstract()); parameter->set_kernel_info(new_input->kernel_info_ptr()); - auto inplace_assign = CreateInplaceAssignNodeAndCorrectReturn(sub_graph, parameter); + auto inplace_assign = CreateInplaceAssignNode(sub_graph, parameter); // Replace atomic ReduceSum's user with atomic clean output, and add depend op after inplaceassign to avoid // elimination. @@ -116,7 +116,7 @@ void StitchAtomicCleanInsertter::ProcessOriginCNode(const KernelGraphPtr &main_g for (const auto &[user_node, index] : reduce_user_nodes) { auto user_cnode = user_node->cast(); MS_EXCEPTION_IF_NULL(user_cnode); - user_cnode->set_input(index, parameter); + user_cnode->set_input(static_cast(index), parameter); if (!connected) { std::vector> user_user = FindInnerCNodeUsers(stitch_node_, user_cnode); if (!user_user.empty()) { @@ -135,7 +135,7 @@ void StitchAtomicCleanInsertter::ProcessOriginCNode(const KernelGraphPtr &main_g } std::vector> StitchAtomicCleanInsertter::FindInnerCNodeUsers(const AnfNodePtr &inner_node, - const CNodePtr &target) { + const CNodePtr &target) const { auto node = inner_node->cast(); MS_EXCEPTION_IF_NULL(node); auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h index 4e8ca8d27d..bb2f3ee9c4 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h @@ -34,11 +34,12 @@ class StitchAtomicCleanInsertter : public AtomicCleanInsertter { bool Run(const FuncGraphPtr &func_graph) override; private: - void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input); - CNodePtr CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter); - std::vector> FindInnerCNodeUsers(const AnfNodePtr &inner_node, const CNodePtr &target); - void ProcessOriginCNode(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node, - const AnfNodePtr &new_input, const FuncGraphManagerPtr &mng); + void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input) override; + CNodePtr CreateInplaceAssignNode(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter); + std::vector> FindInnerCNodeUsers(const AnfNodePtr &inner_node, + const CNodePtr &target) const; + void ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input, + const FuncGraphManagerPtr &mng); bool IsStitchWithAtomic(const AnfNodePtr &anf_node); AnfNodePtr stitch_node_{nullptr}; diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc index 02c354aafd..aa21e45ddd 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc @@ -34,7 +34,7 @@ namespace mindspore { namespace opt { namespace { -void TraverseFuncGraphFromCNode(const CNodePtr &cnode, std::function callback) { +void TraverseFuncGraphFromCNode(const CNodePtr &cnode, const std::function &callback) { std::unordered_set visited; std::queue que; que.push(cnode); @@ -55,7 +55,7 @@ void TraverseFuncGraphFromCNode(const CNodePtr &cnode, std::function callback) { +inline void TraverseFuncGraph(const FuncGraphPtr &root, const std::function &callback) { TraverseFuncGraphFromCNode(root->get_return(), callback); } @@ -68,7 +68,7 @@ class Area { if (cnode == nullptr) continue; const auto &inputs = cnode->inputs(); if (std::any_of(inputs.begin(), inputs.end(), [this](const AnfNodePtr &node) { return IsExternalCNode(node); })) { - spy_cnodes_.push_back(node); + spy_cnodes_.emplace_back(node); } } } @@ -126,15 +126,15 @@ class Area { size_t i = 0; for (auto &traitor : traitor_nodes_) { tuple_node_index->insert(std::make_pair(traitor, i++)); - maketuple_inputs.push_back(traitor); - abstracts.push_back(traitor->abstract()); + maketuple_inputs.emplace_back(traitor); + abstracts.emplace_back(traitor->abstract()); } auto maketuple_node = func_graph->NewCNode(maketuple_inputs); maketuple_node->set_abstract(std::make_shared(abstracts)); nodes_.insert(maketuple_node); - return_inputs.push_back(maketuple_node); + return_inputs.emplace_back(maketuple_node); } else { - return_inputs.push_back(traitor_nodes_[0]); + return_inputs.emplace_back(traitor_nodes_[0]); } auto return_node = func_graph->NewCNode(return_inputs); return_node->set_abstract(return_inputs.back()->abstract()); @@ -146,7 +146,7 @@ class Area { void AddTraitor(const AnfNodePtr &node) { if (std::find(traitor_nodes_.begin(), traitor_nodes_.end(), node) == traitor_nodes_.end()) { - traitor_nodes_.push_back(node); + traitor_nodes_.emplace_back(node); } } @@ -185,7 +185,7 @@ class AreaGraph { // The output `main_cnodes` is a topo-sorted cnode list in main graph, holding the new sub_func_graphs. // The output `cnode_group_id` represents the indices of main_cnodes before topo-sorting. void SplitGraph(const FuncGraphPtr &main_func_graph, std::vector *main_cnodes, - std::vector *cnode_group_id, std::function expand_callback) { + std::vector *cnode_group_id, const std::function &expand_callback) { main_cnodes->clear(); main_cnodes->resize(areas_.size(), nullptr); @@ -228,7 +228,7 @@ class AreaGraph { if (u == v) continue; areas_[u].AddTraitor(in_node); if (std::find(edge_prev_[v].begin(), edge_prev_[v].end(), u) == edge_prev_[v].end()) { - edge_prev_[v].push_back(u); + edge_prev_[v].emplace_back(u); } } } @@ -252,7 +252,7 @@ class AreaGraph { while (!que.empty()) { size_t u = que.front(); que.pop(); - topo_order_.push_back(u); + topo_order_.emplace_back(u); for (size_t i : edge_prev_[u]) { if (--out_degree[i] == 0) que.push(i); } @@ -280,9 +280,9 @@ class AreaGraph { TraceGuard g_sub(std::make_shared(main_cnodes[input_area]->debug_info())); auto getitem_node = main_func_graph->NewCNode(getitem_inputs); getitem_node->set_abstract(main_cnodes[input_area]->abstract()); - main_cnode_inputs.push_back(getitem_node); + main_cnode_inputs.emplace_back(getitem_node); } else { - main_cnode_inputs.push_back(main_cnodes[input_area]); + main_cnode_inputs.emplace_back(main_cnodes[input_area]); } } auto new_main_cnode = main_func_graph->NewCNode(main_cnode_inputs); @@ -293,7 +293,7 @@ class AreaGraph { void SortCNodes(std::vector *main_cnodes) { std::vector main_cnodes_sorted; std::transform(topo_order_.begin(), topo_order_.end(), std::back_inserter(main_cnodes_sorted), - [main_cnodes](int index) { return main_cnodes->at(index); }); + [main_cnodes](size_t index) { return main_cnodes->at(index); }); *main_cnodes = std::move(main_cnodes_sorted); } @@ -309,17 +309,20 @@ class AreaGraph { std::unordered_map node_index_in_returned_tuple_; }; +class SplitSchemer { + public: + SplitSchemer() = default; + virtual ~SplitSchemer() = default; + virtual bool Split(const FuncGraphPtr &func_graph) = 0; + virtual bool NeedInline(size_t group_id) const { return false; } + const std::vector &split_plan() const { return split_plan_; } + + protected: + std::vector split_plan_; +}; + class Splitter { public: - class SplitSchemer { - public: - virtual bool Split(const FuncGraphPtr &func_graph) = 0; - virtual bool NeedInline(size_t group_id) const { return false; } - const std::vector &split_plan() const { return split_plan_; } - - protected: - std::vector split_plan_; - }; using SplitSchemerPtr = std::shared_ptr; using SplitterPtr = std::shared_ptr; @@ -345,14 +348,14 @@ class Splitter { return true; } - static SplitterPtr MakeSplitter(const CNodePtr &main_cnode, SplitSchemerPtr split_schemer) { + static SplitterPtr MakeSplitter(const CNodePtr &main_cnode, const SplitSchemerPtr &split_schemer) { MS_EXCEPTION_IF_NULL(main_cnode); MS_EXCEPTION_IF_NULL(main_cnode->func_graph()); MS_EXCEPTION_IF_NULL(split_schemer); return std::make_shared(main_cnode, split_schemer); } - Splitter(const CNodePtr &main_cnode, SplitSchemerPtr split_schemer) + Splitter(const CNodePtr &main_cnode, const SplitSchemerPtr &split_schemer) : main_func_graph_(main_cnode->func_graph()), old_subgraph_cnode_(main_cnode), split_schemer_(split_schemer) {} ~Splitter() = default; @@ -376,7 +379,7 @@ class Splitter { void BindFuncGraph() { for (const auto &cnode : new_subgraph_cnodes_) { auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode); - auto callback = [&sub_func_graph, this](const AnfNodePtr &node) { + auto callback = [&sub_func_graph](const AnfNodePtr &node) { if (!node->isa()) { node->set_func_graph(sub_func_graph); } @@ -425,7 +428,7 @@ class Splitter { } } if (AnfAlgo::IsRealKernel(node)) { - inlined_nodes_.push_back(node); + inlined_nodes_.emplace_back(node); } } } @@ -454,7 +457,7 @@ class Splitter { if (i + 1 == new_subgraph_cnodes_.size()) { replace_map[this->old_subgraph_cnode_] = new_subgraph_cnodes_.back(); } - tmp_subgraph_cnodes.push_back(new_subgraph_cnodes_[i]); + tmp_subgraph_cnodes.emplace_back(new_subgraph_cnodes_[i]); } } new_subgraph_cnodes_ = std::move(tmp_subgraph_cnodes); @@ -543,8 +546,9 @@ class Splitter { std::unordered_map param_to_main_graph_node_map_; }; -class CostModelSplitSchemer : public Splitter::SplitSchemer { +class CostModelSplitSchemer : public SplitSchemer { public: + virtual ~CostModelSplitSchemer() = default; bool Split(const FuncGraphPtr &func_graph) override { if (!func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { MS_EXCEPTION(NotSupportError) << "func_graph must be a GraphKernel node."; @@ -620,7 +624,7 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer { MS_LOG(ERROR) << "Failed decode sub graph, " << graph_desc; return false; } - split_plan_.push_back(std::move(res_graph)); + split_plan_.emplace_back(std::move(res_graph)); } // ops to be inlined. @@ -687,14 +691,14 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer { if (IsValidKernelNode(output)) { auto group_id = node_group_[ret_node] = node_group_[output]; - split_plan_[group_id].push_back(ret_node); + split_plan_[group_id].emplace_back(ret_node); return; } // assign the make_tuple node to a new group. if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimMakeTuple)) { auto group_id = split_plan_.size(); - split_plan_.push_back({output, ret_node}); - need_inline_.push_back(1); + split_plan_.emplace_back(AnfNodePtrList{output, ret_node}); + need_inline_.emplace_back(1); node_group_[ret_node] = node_group_[output] = group_id; return; } @@ -711,7 +715,7 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer { auto iter = node_group_.find(input); if (iter != node_group_.end()) { node_group_[node] = iter->second; - split_plan_[iter->second].push_back(node); + split_plan_[iter->second].emplace_back(node); found = true; break; } diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.h index 90ded3ba57..cab14e3376 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.h @@ -25,7 +25,7 @@ class GraphKernelSplitter : public Pass { public: GraphKernelSplitter() : Pass("graph_kernel_splitter") {} ~GraphKernelSplitter() override = default; - bool Run(const FuncGraphPtr &func_graph); + bool Run(const FuncGraphPtr &func_graph) override; }; using GraphKernelSplitterPtr = std::shared_ptr; } // namespace opt