From: @tronzhang Reviewed-by: @ryanww,@gaoxiong1 Signed-off-by: @gaoxiong1tags/v1.1.0
| @@ -16,6 +16,7 @@ | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_cse.h" | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| @@ -26,13 +27,15 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| bool IsCNodePrimitveEqual(const CNodePtr &main, const CNodePtr &node) { | |||
| bool IsCNodePrimitveEqual(const CNodePtr &main, const CNodePtr &node, const std::vector<PrimitivePtr> &black_list) { | |||
| auto main_primitive = AnfAlgo::GetCNodePrimitive(main); | |||
| auto node_primitive = AnfAlgo::GetCNodePrimitive(node); | |||
| if (main_primitive != nullptr && node_primitive != nullptr) { | |||
| // Some ops such as Reshape is not real op, cse these type will not get gain. And for ops fusion, keep these op | |||
| // alone can prevent some redundant output case (input -> reshape -> output). | |||
| if (main_primitive->name() != node_primitive->name() || IsPrimitiveCNode(node, prim::kPrimReshape)) { | |||
| if (main_primitive->name() != node_primitive->name() || | |||
| std::any_of(black_list.begin(), black_list.end(), | |||
| [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); })) { | |||
| return false; | |||
| } | |||
| @@ -125,12 +128,12 @@ bool GraphKernelBackendCSE::CheckEqualCnodeInputs(const AnfNodePtr &main, const | |||
| return false; | |||
| } | |||
| } | |||
| return IsCNodePrimitveEqual(c_main, c_node); | |||
| return IsCNodePrimitveEqual(c_main, c_node, black_list_); | |||
| } | |||
| bool GraphKernelCSE::Run(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto graphkernel_backend_cse = std::make_shared<GraphKernelBackendCSE>(); | |||
| auto graphkernel_backend_cse = std::make_shared<GraphKernelBackendCSE>(black_list_); | |||
| return graphkernel_backend_cse->Cse(func_graph, func_graph->manager()); | |||
| } | |||
| } // namespace opt | |||
| @@ -13,27 +13,35 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CSE_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CSE_H_ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_CSE_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_CSE_H_ | |||
| #include <vector> | |||
| #include "backend/optimizer/pass/common_subexpression_elimination.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class GraphKernelCSE : public Pass { | |||
| public: | |||
| GraphKernelCSE() : Pass("graph_kernel_cse") {} | |||
| explicit GraphKernelCSE(const std::vector<PrimitivePtr> &black_list = {}) | |||
| : Pass("graph_kernel_cse"), black_list_(black_list) {} | |||
| ~GraphKernelCSE() override = default; | |||
| bool Run(const FuncGraphPtr &func_graph) override; | |||
| private: | |||
| std::vector<PrimitivePtr> black_list_; | |||
| }; | |||
| class GraphKernelBackendCSE : public BackendCSE { | |||
| public: | |||
| GraphKernelBackendCSE() = default; | |||
| explicit GraphKernelBackendCSE(const std::vector<PrimitivePtr> &black_list = {}) : black_list_(black_list) {} | |||
| ~GraphKernelBackendCSE() override = default; | |||
| bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) const override; | |||
| bool CheckEqualCnodeInputs(const AnfNodePtr &main, const AnfNodePtr &node) const override; | |||
| private: | |||
| std::vector<PrimitivePtr> black_list_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CSE_H_ | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_CSE_H_ | |||
| @@ -34,7 +34,7 @@ namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| bool IsMultiUserShapeOps(AnfNodePtr node, const FuncGraphManagerPtr &mng) { | |||
| std::vector<PrimitivePtr> shape_ops = {prim::kPrimReshape}; | |||
| std::vector<PrimitivePtr> shape_ops = {prim::kPrimReshape, prim::kPrimCast}; | |||
| auto &users = mng->node_users(); | |||
| return std::any_of(shape_ops.begin(), shape_ops.end(), | |||
| [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }) && | |||
| @@ -120,7 +120,9 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) { | |||
| pm->AddPass(std::make_shared<opt::AdamFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>()); | |||
| pm->AddPass(std::make_shared<opt::CastAllFusion>("cast_all")); | |||
| if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) { | |||
| pm->AddPass(std::make_shared<opt::CastAllFusion>("cast_all")); | |||
| } | |||
| pm->AddPass(std::make_shared<opt::CombineMomentumFusion>("combine_momentum")); | |||
| pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>()); | |||
| @@ -165,15 +167,17 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_ | |||
| } | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>("graph_kernel_pm"); | |||
| std::vector<PrimitivePtr> black_list = {prim::kPrimReshape, prim::kPrimCast}; | |||
| pm->AddPass(std::make_shared<opt::GraphKernelExpander>()); | |||
| pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>()); | |||
| pm->AddPass(std::make_shared<opt::BasicOpsFusion>()); | |||
| pm->AddPass(std::make_shared<opt::CompositeOpsFusion>()); | |||
| pm->AddPass(std::make_shared<opt::GraphKernelCSE>()); | |||
| pm->AddPass(std::make_shared<opt::GraphKernelCSE>(black_list)); | |||
| pm->AddPass(std::make_shared<opt::ArithmeticSimplify>()); | |||
| pm->AddPass(std::make_shared<opt::GraphKernelCSE>()); | |||
| pm->AddPass(std::make_shared<opt::GraphKernelCSE>(black_list)); | |||
| pm->AddPass(std::make_shared<opt::TensorPromotion>()); | |||
| pm->AddPass(std::make_shared<opt::GraphKernelSplitter>()); | |||
| pm->AddPass(std::make_shared<opt::GraphKernelCSE>()); | |||
| // After Simplify and Splitter, a lot of redundant getitem/maketuple | |||
| // will be exposed, use GetitemTuple Pass to delete them. | |||
| pm->AddPass(std::make_shared<opt::GetitemTuple>()); | |||