From: @dayschan Reviewed-by: @gaoxiong1,@ckey_dou Signed-off-by: @ckey_dou tags/v1.2.0-rc1
| @@ -66,36 +66,23 @@ IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNode | |||||
| return EXCLUDE; | return EXCLUDE; | ||||
| } | } | ||||
| // The GetItem node should be fused with its real input and users. | |||||
| // The GetItem node should be fused with its real input. | |||||
| // If its real input is not in the fuse_list, the GetItem should be excluded. | // If its real input is not in the fuse_list, the GetItem should be excluded. | ||||
| AnfNodePtrList RemoveWildGetitem(const AnfNodePtrList &fused_op) { | AnfNodePtrList RemoveWildGetitem(const AnfNodePtrList &fused_op) { | ||||
| if (fused_op.empty()) return AnfNodePtrList(); | if (fused_op.empty()) return AnfNodePtrList(); | ||||
| std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end()); | std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end()); | ||||
| auto check_include = [&fused_op_set](const AnfNodePtr &node) { return fused_op_set.count(node) ? FOLLOW : EXCLUDE; }; | auto check_include = [&fused_op_set](const AnfNodePtr &node) { return fused_op_set.count(node) ? FOLLOW : EXCLUDE; }; | ||||
| auto mng = fused_op[0]->func_graph()->manager(); | |||||
| MS_EXCEPTION_IF_NULL(mng); | |||||
| bool changed = true; | bool changed = true; | ||||
| while (changed) { | while (changed) { | ||||
| changed = false; | changed = false; | ||||
| AnfNodePtrList remove_list; | AnfNodePtrList remove_list; | ||||
| for (auto getitem : fused_op_set) { | for (auto getitem : fused_op_set) { | ||||
| if (!AnfAlgo::CheckPrimitiveType(getitem, prim::kPrimTupleGetItem)) continue; | if (!AnfAlgo::CheckPrimitiveType(getitem, prim::kPrimTupleGetItem)) continue; | ||||
| // GetItem should be fused with its real input. | // GetItem should be fused with its real input. | ||||
| auto prev_node = getitem->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem); | auto prev_node = getitem->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem); | ||||
| if (check_include(prev_node) == EXCLUDE) { | if (check_include(prev_node) == EXCLUDE) { | ||||
| remove_list.push_back(getitem); | remove_list.push_back(getitem); | ||||
| break; | |||||
| } | |||||
| // GetItem should be fused with its all users. | |||||
| const auto &users = mng->node_users()[getitem]; | |||||
| if (std::any_of(users.begin(), users.end(), [check_include](const std::pair<AnfNodePtr, int> &user) { | |||||
| return check_include(user.first) == EXCLUDE; | |||||
| })) { | |||||
| remove_list = DeepLinkedGraphSearch(getitem, check_include); | |||||
| break; | |||||
| } | } | ||||
| } | } | ||||
| if (!remove_list.empty()) { | if (!remove_list.empty()) { | ||||
| @@ -753,7 +753,7 @@ std::vector<PrimitivePtr> GetFusibleOpList() { | |||||
| prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN, | prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN, | ||||
| prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater, | prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater, | ||||
| prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose, | prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose, | ||||
| prim::kPrimCast}; | |||||
| prim::kPrimCast, prim::kPrimExpandDims}; | |||||
| #else | #else | ||||
| std::vector<PrimitivePtr> fusible_basic_ops; | std::vector<PrimitivePtr> fusible_basic_ops; | ||||
| #endif | #endif | ||||
| @@ -33,14 +33,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace { | namespace { | ||||
| // Returns true when `node` is a CNode of one of the hard-coded shape primitives (Reshape, Cast) | |||||
| // and the manager records more than one user for it. | |||||
| bool IsMultiUserShapeOps(AnfNodePtr node, const FuncGraphManagerPtr &mng) { | |||||
| std::vector<PrimitivePtr> candidate_ops = {prim::kPrimReshape, prim::kPrimCast}; | |||||
| auto &node_users = mng->node_users(); | |||||
| auto is_candidate = [&node](const PrimitivePtr &candidate) { return IsPrimitiveCNode(node, candidate); }; | |||||
| // The primitive check comes first so that the user table is only indexed for matching nodes. | |||||
| if (!std::any_of(candidate_ops.begin(), candidate_ops.end(), is_candidate)) { | |||||
| return false; | |||||
| } | |||||
| return node_users[node].size() > 1; | |||||
| } | |||||
| AnfNodePtr CloneCNode(const AnfNodePtr &anf_node) { | AnfNodePtr CloneCNode(const AnfNodePtr &anf_node) { | ||||
| auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(anf_node->func_graph()); | auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(anf_node->func_graph()); | ||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | MS_EXCEPTION_IF_NULL(kernel_graph); | ||||
| @@ -75,7 +67,14 @@ void SplitNode(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) { | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| bool ShapeOpsSplitter::Run(const FuncGraphPtr &func_graph) { | |||||
| // Returns true when the manager records more than one user for `node` and `node` is a CNode | |||||
| // of one of the primitives listed in shape_ops_. | |||||
| bool ShapeOpsSplitter::IsMultiUserShapeOps(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) { | |||||
| auto &node_users = mng->node_users(); | |||||
| // The user count is checked first; a node with at most one user is rejected without scanning shape_ops_. | |||||
| if (node_users[node].size() <= 1) { | |||||
| return false; | |||||
| } | |||||
| auto is_listed_op = [&node](const PrimitivePtr &shape_op) { return IsPrimitiveCNode(node, shape_op); }; | |||||
| return std::any_of(shape_ops_.begin(), shape_ops_.end(), is_listed_op); | |||||
| } | |||||
| bool ShapeOpsSplitter::Process(const FuncGraphPtr &func_graph) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| auto mng = func_graph->manager(); | auto mng = func_graph->manager(); | ||||
| if (mng == nullptr) { | if (mng == nullptr) { | ||||
| @@ -96,5 +95,15 @@ bool ShapeOpsSplitter::Run(const FuncGraphPtr &func_graph) { | |||||
| mng->KeepRoots({func_graph}); | mng->KeepRoots({func_graph}); | ||||
| return changed; | return changed; | ||||
| } | } | ||||
| // Calls Process on `func_graph` repeatedly until a round reports no change. | |||||
| // Returns true if at least one round changed the graph. | |||||
| bool ShapeOpsSplitter::Run(const FuncGraphPtr &func_graph) { | |||||
| bool any_round_changed = false; | |||||
| // Process is always invoked at least once; each `true` result triggers another round. | |||||
| while (Process(func_graph)) { | |||||
| any_round_changed = true; | |||||
| } | |||||
| return any_round_changed; | |||||
| } | |||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,7 @@ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SHAPE_OPS_SPLITTER_H_ | #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SHAPE_OPS_SPLITTER_H_ | ||||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SHAPE_OPS_SPLITTER_H_ | #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SHAPE_OPS_SPLITTER_H_ | ||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "backend/optimizer/common/pass.h" | #include "backend/optimizer/common/pass.h" | ||||
| @@ -23,9 +24,15 @@ namespace mindspore { | |||||
| namespace opt { | namespace opt { | ||||
| // Pass named "shape_ops_splitter". It is configured with a list of primitives ("shape ops"); | |||||
| // IsMultiUserShapeOps tests nodes against that list, and Run drives Process over a func graph. | |||||
| // The list is copied into the pass, so the pass does not depend on the lifetime of the caller's vector. | |||||
| class ShapeOpsSplitter : public Pass { | class ShapeOpsSplitter : public Pass { | ||||
| public: | public: | ||||
| ShapeOpsSplitter() : Pass("shape_ops_splitter") {} | |||||
| // shape_ops: primitives this pass treats as shape ops. The vector is copied, so temporaries | |||||
| // and short-lived locals are safe to pass. | |||||
| explicit ShapeOpsSplitter(const std::vector<PrimitivePtr> &shape_ops) | |||||
| : Pass("shape_ops_splitter"), shape_ops_(shape_ops) {} | |||||
| ~ShapeOpsSplitter() override = default; | ~ShapeOpsSplitter() override = default; | ||||
| bool Run(const FuncGraphPtr &func_graph); | bool Run(const FuncGraphPtr &func_graph); | ||||
| private: | |||||
| bool Process(const FuncGraphPtr &func_graph); | |||||
| bool IsMultiUserShapeOps(const AnfNodePtr &node, const FuncGraphManagerPtr &mng); | |||||
| // Held by value: a reference member would dangle once the vector passed to the constructor | |||||
| // is destroyed (e.g. a temporary, or a local in the function that registered the pass). | |||||
| std::vector<PrimitivePtr> shape_ops_; | |||||
| }; | }; | ||||
| using ShapeOpsSplitterPtr = std::shared_ptr<ShapeOpsSplitter>; | using ShapeOpsSplitterPtr = std::shared_ptr<ShapeOpsSplitter>; | ||||
| } // namespace opt | } // namespace opt | ||||
| @@ -177,14 +177,14 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_ | |||||
| } | } | ||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | auto optimizer = std::make_shared<opt::GraphOptimizer>(); | ||||
| auto pm = std::make_shared<opt::PassManager>("graph_kernel_pm"); | auto pm = std::make_shared<opt::PassManager>("graph_kernel_pm"); | ||||
| std::vector<PrimitivePtr> black_list = {prim::kPrimReshape, prim::kPrimCast}; | |||||
| std::vector<PrimitivePtr> duplicated_ops = {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimCast}; | |||||
| pm->AddPass(std::make_shared<opt::GraphKernelExpander>()); | pm->AddPass(std::make_shared<opt::GraphKernelExpander>()); | ||||
| pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>()); | |||||
| pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>(duplicated_ops)); | |||||
| pm->AddPass(std::make_shared<opt::BasicOpsFusion>()); | pm->AddPass(std::make_shared<opt::BasicOpsFusion>()); | ||||
| pm->AddPass(std::make_shared<opt::EliminateRedundantOutput>()); | pm->AddPass(std::make_shared<opt::EliminateRedundantOutput>()); | ||||
| pm->AddPass(std::make_shared<opt::GraphKernelCSE>(black_list)); | |||||
| pm->AddPass(std::make_shared<opt::GraphKernelCSE>(duplicated_ops)); | |||||
| pm->AddPass(std::make_shared<opt::ArithmeticSimplify>()); | pm->AddPass(std::make_shared<opt::ArithmeticSimplify>()); | ||||
| pm->AddPass(std::make_shared<opt::GraphKernelCSE>(black_list)); | |||||
| pm->AddPass(std::make_shared<opt::GraphKernelCSE>(duplicated_ops)); | |||||
| pm->AddPass(std::make_shared<opt::TensorPromotion>()); | pm->AddPass(std::make_shared<opt::TensorPromotion>()); | ||||
| pm->AddPass(std::make_shared<opt::GraphKernelSplitter>()); | pm->AddPass(std::make_shared<opt::GraphKernelSplitter>()); | ||||
| pm->AddPass(std::make_shared<opt::GraphKernelCSE>()); | pm->AddPass(std::make_shared<opt::GraphKernelCSE>()); | ||||