From 8af78cd5ce4f1da0fd589bede2c39cfa339f6818 Mon Sep 17 00:00:00 2001 From: dayschan Date: Wed, 30 Dec 2020 16:15:54 +0800 Subject: [PATCH] Added ExpandDims into GPU fusion list what's more: remove one restriction of getitem in ops fusion. add a while loop for the ShapeOpsSplitter pass. add ExpandDims into shape_ops list. --- .../graph_kernel/basic_ops_fusion.cc | 15 +---------- .../graph_kernel/graph_kernel_helper.cc | 2 +- .../graph_kernel/shape_ops_splitter.cc | 27 ++++++++++++------- .../graph_kernel/shape_ops_splitter.h | 9 ++++++- .../ccsrc/backend/session/gpu_session.cc | 8 +++--- 5 files changed, 32 insertions(+), 29 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc index f1057685d6..4bf455e0fe 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc @@ -66,36 +66,23 @@ IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNode return EXCLUDE; } -// The GetItem node should be fused with its real input and users. +// The GetItem node should be fused with its real input. // If its real input is not in the fuse_list, the GetItem should be excluded. AnfNodePtrList RemoveWildGetitem(const AnfNodePtrList &fused_op) { if (fused_op.empty()) return AnfNodePtrList(); std::set fused_op_set(fused_op.begin(), fused_op.end()); auto check_include = [&fused_op_set](const AnfNodePtr &node) { return fused_op_set.count(node) ? FOLLOW : EXCLUDE; }; - auto mng = fused_op[0]->func_graph()->manager(); - MS_EXCEPTION_IF_NULL(mng); bool changed = true; while (changed) { changed = false; AnfNodePtrList remove_list; for (auto getitem : fused_op_set) { if (!AnfAlgo::CheckPrimitiveType(getitem, prim::kPrimTupleGetItem)) continue; - // GetItem should be fused with its real input. auto prev_node = getitem->cast()->input(kRealInputNodeIndexInTupleGetItem); if (check_include(prev_node) == EXCLUDE) { remove_list.push_back(getitem); - break; - } - - // GetItem should be fused with its all users. - const auto &users = mng->node_users()[getitem]; - if (std::any_of(users.begin(), users.end(), [check_include](const std::pair &user) { - return check_include(user.first) == EXCLUDE; - })) { - remove_list = DeepLinkedGraphSearch(getitem, check_include); - break; } } if (!remove_list.empty()) { diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc index 8818498d3b..3f4948e9ce 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc @@ -753,7 +753,7 @@ std::vector GetFusibleOpList() { prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater, prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose, - prim::kPrimCast}; + prim::kPrimCast, prim::kPrimExpandDims}; #else std::vector fusible_basic_ops; #endif diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/shape_ops_splitter.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/shape_ops_splitter.cc index 006d0e4961..4323ac582c 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/shape_ops_splitter.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/shape_ops_splitter.cc @@ -33,14 +33,6 @@ namespace mindspore { namespace opt { namespace { -bool IsMultiUserShapeOps(AnfNodePtr node, const FuncGraphManagerPtr &mng) { - std::vector shape_ops = {prim::kPrimReshape, prim::kPrimCast}; - auto &users = mng->node_users(); - return std::any_of(shape_ops.begin(), shape_ops.end(), - [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }) && - users[node].size() > 1; -} - AnfNodePtr CloneCNode(const AnfNodePtr &anf_node) { auto kernel_graph = std::dynamic_pointer_cast(anf_node->func_graph()); MS_EXCEPTION_IF_NULL(kernel_graph); @@ -75,7 +67,14 @@ void SplitNode(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) { } } // namespace -bool ShapeOpsSplitter::Run(const FuncGraphPtr &func_graph) { +bool ShapeOpsSplitter::IsMultiUserShapeOps(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) { + auto &users = mng->node_users(); + return users[node].size() > 1 && std::any_of(shape_ops_.begin(), shape_ops_.end(), [&node](const PrimitivePtr &prim) { + return IsPrimitiveCNode(node, prim); + }); +} + +bool ShapeOpsSplitter::Process(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); auto mng = func_graph->manager(); if (mng == nullptr) { @@ -96,5 +95,15 @@ bool ShapeOpsSplitter::Run(const FuncGraphPtr &func_graph) { mng->KeepRoots({func_graph}); return changed; } + +bool ShapeOpsSplitter::Run(const FuncGraphPtr &func_graph) { + bool result = false; + bool changed; + do { + changed = Process(func_graph); + result |= changed; + } while (changed); + return result; +} } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/shape_ops_splitter.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/shape_ops_splitter.h index 618d5e194d..36a030dfdc 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/shape_ops_splitter.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/shape_ops_splitter.h @@ -16,6 +16,7 @@ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SHAPE_OPS_SPLITTER_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SHAPE_OPS_SPLITTER_H_ #include +#include #include "ir/func_graph.h" #include "backend/optimizer/common/pass.h" @@ -23,9 +24,15 @@ namespace mindspore { namespace opt { class ShapeOpsSplitter : public Pass { public: - ShapeOpsSplitter() : Pass("shape_ops_splitter") {} + explicit ShapeOpsSplitter(const std::vector &shape_ops) + : Pass("shape_ops_splitter"), shape_ops_(shape_ops) {} ~ShapeOpsSplitter() override = default; bool Run(const FuncGraphPtr &func_graph); + + private: + bool Process(const FuncGraphPtr &func_graph); + bool IsMultiUserShapeOps(const AnfNodePtr &node, const FuncGraphManagerPtr &mng); + const std::vector &shape_ops_; }; using ShapeOpsSplitterPtr = std::shared_ptr; } // namespace opt diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc index ba946aa876..54b9912977 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.cc +++ b/mindspore/ccsrc/backend/session/gpu_session.cc @@ -177,14 +177,14 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr &kernel_ } auto optimizer = std::make_shared(); auto pm = std::make_shared("graph_kernel_pm"); - std::vector black_list = {prim::kPrimReshape, prim::kPrimCast}; + std::vector duplicated_ops = {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimCast}; pm->AddPass(std::make_shared()); - pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared(duplicated_ops)); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); - pm->AddPass(std::make_shared(black_list)); + pm->AddPass(std::make_shared(duplicated_ops)); pm->AddPass(std::make_shared()); - pm->AddPass(std::make_shared(black_list)); + pm->AddPass(std::make_shared(duplicated_ops)); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared());