// Patch overview. This text sits before the first "diff --git" header, so it is not part of any hunk.
// NOTE(review): in this copy the original line breaks of the patch are collapsed into spaces, and
// angle-bracket template arguments appear to be missing (e.g. `isa()`, `cast()`, a bare `#include`,
// `std::unordered_map>`) - confirm against the upstream commit before trying to apply it.
//
// frontend/optimizer/cse.cc
//   * Adds WithRecomputedScope(node): null-checks the node, returns false when the `isa` check fails,
//     otherwise returns true only if node->fullname_with_scope() begins with kAttrRecompute.
//   * Adds IsSetRecomputed(a, b): true when either CNode passes WithRecomputedScope() and does not
//     carry the kAttrNeedCseAfterRecompute attribute.
//   * CSE::CheckReplace returns false for such a pair, before the side-effect check, so those nodes
//     are not merged by CSE.
//   * Fixes the log message typo "Unknow node type" -> "Unknown node type".
//
// frontend/optimizer/recompute.cc
//   * Drops the file-local kAttrRecompute constant and includes utils/utils.h instead.
//   * NewRecomputedNode() adds kAttrNeedCseAfterRecompute = true to each node it creates.
//   * At the end of InsertRecomputedNodes(), every node in `orders` that passes WithRecomputedScope()
//     gets kAttrNeedCseAfterRecompute = true.
//   * NOTE(review): WithRecomputedScope() is defined in cse.cc and no declaration is visible in this
//     patch - confirm recompute.cc can see one.
//
// pipeline/jit/pass.cc
//   * Adds GetAfterRecomputePass(), a pass-group map with a single "cse" entry built from
//     opt::CSEPass(false).
//   * InitOpt() registers it as g_pass_opts["opt_after_recompute"]; OptAfterRecomputeGroup() runs it.
//   * kVmPasses gains {"cse_after_recomputation", OptAfterRecomputeGroup} right after
//     "add_recomputation".
//
// utils/utils.h
//   * Defines kAttrRecompute = "recompute" and kAttrNeedCseAfterRecompute = "need_cse_after_recompute".
diff --git a/mindspore/ccsrc/frontend/optimizer/cse.cc b/mindspore/ccsrc/frontend/optimizer/cse.cc index 621dc3e1ce..35ce7e53b2 100644 --- a/mindspore/ccsrc/frontend/optimizer/cse.cc +++ b/mindspore/ccsrc/frontend/optimizer/cse.cc @@ -24,6 +24,7 @@ #include "abstract/abstract_function.h" #include "utils/flags.h" +#include "utils/utils.h" namespace mindspore { /* namespace to support opt */ @@ -32,6 +33,20 @@ using mindspore::abstract::AbstractBase; using mindspore::abstract::AbstractFunction; using mindspore::abstract::AbstractFunctionPtr; +bool WithRecomputedScope(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + return false; + } + auto full_name_with_scope = node->fullname_with_scope(); + return full_name_with_scope.find(kAttrRecompute) == 0; +} + +bool IsSetRecomputed(const CNodePtr &a, const CNodePtr &b) { + return (WithRecomputedScope(a) && !a->HasAttr(kAttrNeedCseAfterRecompute)) || + (WithRecomputedScope(b) && !b->HasAttr(kAttrNeedCseAfterRecompute)); +} + BasePtr AbsOf(const AnfNodePtr &node, bool ignore_fg_abs_tracking_id) { MS_EXCEPTION_IF_NULL(node); auto node_abs = node->abstract(); @@ -83,7 +98,7 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const { } else if (node->isa()) { h = node->hash(); } else { - MS_LOG(ERROR) << "Unknow node type"; + MS_LOG(ERROR) << "Unknown node type"; } hashes[node] = h; @@ -142,6 +157,10 @@ bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool chec } else if (main->isa() && node->isa()) { auto c_main = main->cast(); auto c_node = node->cast(); + // Not do cse for the node set recompute before the recompute pass. + if (IsSetRecomputed(c_main, c_node)) { + return false; + } // When appsame is true, check if has side effect, do not merge. 
if (check_side_effect && HasSideEffect(main)) { return false; diff --git a/mindspore/ccsrc/frontend/optimizer/recompute.cc b/mindspore/ccsrc/frontend/optimizer/recompute.cc index d4f7f23f86..4b2aeeea39 100644 --- a/mindspore/ccsrc/frontend/optimizer/recompute.cc +++ b/mindspore/ccsrc/frontend/optimizer/recompute.cc @@ -25,12 +25,12 @@ #include #include "ir/func_graph.h" #include "mindspore/core/base/core_ops.h" +#include "utils/utils.h" namespace mindspore { namespace opt { namespace { constexpr auto kGradientsFlag = "Gradients"; -constexpr auto kAttrRecompute = "recompute"; bool IsBpropNode(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { @@ -339,6 +339,7 @@ CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_nod auto recomputed_node = graph->NewCNode(new_inputs); MS_EXCEPTION_IF_NULL(recomputed_node); recomputed_node->AddAttr("duplicated", MakeValue(true)); + recomputed_node->AddAttr(kAttrNeedCseAfterRecompute, MakeValue(true)); recomputed_node->set_abstract(origin_node->abstract()); recomputed_node->set_scope(origin_node->scope()); origin_to_recomputed_nodes->insert(std::make_pair(origin_node, recomputed_node)); @@ -415,6 +416,12 @@ void InsertRecomputedNodes(const FuncGraphPtr &graph) { DuplicateRecomputedNodes(graph, target_nodes, origin_recomputed_nodes, first_target_inputs, &origin_to_recomputed_nodes); } + // Set need cse attr for doing cse after recompute. 
+ for (const auto &node : orders) { + if (WithRecomputedScope(node)) { + node->AddAttr(kAttrNeedCseAfterRecompute, MakeValue(true)); + } + } } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index 2a8a2b9bc5..b062651755 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -302,6 +302,11 @@ OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib &irpass) { return map; } +OptPassGroupMap GetAfterRecomputePass(const opt::irpass::OptimizeIRPassLib &irpass) { + OptPassGroupMap map({{"cse", opt::OptPassConfig(opt::CSEPass(false))}}); + return map; +} + static std::unordered_map> g_pass_opts = {}; void InitOpt(const ResourcePtr &res) { @@ -323,6 +328,8 @@ void InitOpt(const ResourcePtr &res) { g_pass_opts["opt_grad_epilogue"] = Optimizer::MakeOptimizer("opt_grad_epilogue", res, GetOptPynativeGradEpiloguePhases(irpass), true, false); g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass)); + g_pass_opts["opt_after_recompute"] = + Optimizer::MakeOptimizer("opt_after_recompute", res, GetAfterRecomputePass(irpass)); auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); if (!(context_ptr->get_param(MS_CTX_ENABLE_GRAPH_KERNEL))) { @@ -367,6 +374,7 @@ bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res, bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); } bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); } bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); } +bool OptAfterRecomputeGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_recompute"); } bool OptPassRNGroup(const ResourcePtr &res) { return OptPassGroup(res, "renormal"); } @@ -525,7 +533,8 @@ std::vector kVmPasses = {{"simplify_data_structures", 
SimplifyDataStru {"opt_graph_kernel_b", OptPassGraphKernelGroupB}, {"add_cache_embedding", AddCacheEmbeddingPass}, {"add_control_depend", AddControlDependPass}, - {"add_recomputation", AddRecomputationPass}}; + {"add_recomputation", AddRecomputationPass}, + {"cse_after_recomputation", OptAfterRecomputeGroup}}; std::vector kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, {"opt_a", OptPassAGroup}, diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 236771ac63..5971ecc91e 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -380,6 +380,8 @@ constexpr auto kAttrPadMode = "pad_mode"; constexpr auto kAttrPad = "pad"; constexpr auto kAttrPadding = "padding"; constexpr auto kAttrIsGrad = "is_grad"; +constexpr auto kAttrRecompute = "recompute"; +constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute"; // attr value constexpr auto kValueTargetSwitch = "target_switch";