diff --git a/mindspore/ccsrc/frontend/optimizer/opt.cc b/mindspore/ccsrc/frontend/optimizer/opt.cc index 3d6a8719c4..5474bb5c1e 100644 --- a/mindspore/ccsrc/frontend/optimizer/opt.cc +++ b/mindspore/ccsrc/frontend/optimizer/opt.cc @@ -184,9 +184,6 @@ bool SubstitutionList::ApplyIRToSubstitutions(const OptimizerPtr &optimizer, con for (auto &substitution : list_) { auto res = DoTransform(optimizer, node, substitution); if (res != nullptr) { - if (is_once_) { - return true; - } change = true; changes = true; node = res; @@ -228,9 +225,6 @@ bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, cons bool change = false; auto res = DoTransform(optimizer, node, substitution); if (res != nullptr) { - if (is_once_) { - return true; - } change = true; changes = true; node = res; @@ -316,9 +310,13 @@ bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const Optimize : kOptTraverseFromSubstitutionsToIR); if (traverse_mode == kOptTraverseFromIRToSubstitutions && MsContext::GetInstance()->get_param(MS_CTX_EXECUTION_MODE) != kPynativeMode && - optimizer->traverse_nodes_first()) { + optimizer->traverse_nodes_first() && !is_once_ && !global_sensitive_) { + MS_LOG(DEBUG) << "IR >> SUB, " << optimizer->name() << "(r" << optimizer->CurPass_.counter << ")_" + << optimizer->CurPass_.name; changes = ApplyIRToSubstitutions(optimizer, func_graph); } else { + MS_LOG(DEBUG) << "SUB >> IR, " << optimizer->name() << "(r" << optimizer->CurPass_.counter << ")_" + << optimizer->CurPass_.name; changes = ApplySubstitutionsToIR(optimizer, func_graph); } return changes; diff --git a/mindspore/ccsrc/frontend/optimizer/opt.h b/mindspore/ccsrc/frontend/optimizer/opt.h index 7c9f2f5e69..01f21d5df6 100644 --- a/mindspore/ccsrc/frontend/optimizer/opt.h +++ b/mindspore/ccsrc/frontend/optimizer/opt.h @@ -63,8 +63,9 @@ enum OptTraverseSubstitutionsMode { kOptTraverseFromIRToSubstitutions = 0, kOptT class SubstitutionList { public: - explicit SubstitutionList(const std::vector &patterns, bool is_once = false) - : list_(patterns), is_once_(is_once) {} + explicit SubstitutionList(const std::vector &patterns, bool is_once = false, + bool global_sensitive = false) + : list_(patterns), is_once_(is_once), global_sensitive_(global_sensitive) {} ~SubstitutionList() = default; bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const; @@ -77,6 +78,7 @@ class SubstitutionList { std::vector list_; // a flag to mark this list of Substitution can only be executed only once bool is_once_; + bool global_sensitive_; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/optimizer.h b/mindspore/ccsrc/frontend/optimizer/optimizer.h index 60a7482ca1..89362a5a66 100644 --- a/mindspore/ccsrc/frontend/optimizer/optimizer.h +++ b/mindspore/ccsrc/frontend/optimizer/optimizer.h @@ -43,10 +43,10 @@ using OptimizeGraphFunc = std::function &list, bool is_once = false) - : list_(list), is_once_(is_once) {} - OptPassConfig(const std::initializer_list &list, bool is_once = false) - : list_(list), is_once_(is_once) {} + explicit OptPassConfig(const std::vector &list, bool is_once = false, bool global_sensitive = false) + : list_(list), is_once_(is_once), global_sensitive_(global_sensitive) {} + OptPassConfig(const std::initializer_list &list, bool is_once = false, bool global_sensitive = false) + : list_(list), is_once_(is_once), global_sensitive_(global_sensitive) {} ~OptPassConfig() = default; const std::vector &list() const { return list_; } @@ -57,6 +57,8 @@ class OptPassConfig { const bool is_once() const { return is_once_; } + const bool global_sensitive() const { return global_sensitive_; } + private: OptPassConfig() : is_renormalize_(true) {} @@ -64,6 +66,7 @@ class OptPassConfig { std::vector list_; bool is_renormalize_{false}; bool is_once_{false}; + bool global_sensitive_{false}; }; class OptPass { @@ -115,7 +118,7 @@ class Optimizer : public std::enable_shared_from_this { } if (config.list().size() > 0) { - OptimizeGraphFunc func = SubstitutionList(config.list(), config.is_once()); + OptimizeGraphFunc func = SubstitutionList(config.list(), config.is_once(), config.global_sensitive()); passes_.push_back(OptPass(func)); continue; } diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index af0b2d6cf6..7b26234f85 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -136,19 +136,21 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { irpass.stopgrad_eliminater_, irpass.sparse_tensor_eliminate_, }); - opt::OptPassConfig a_2 = opt::OptPassConfig({ - irpass.merge_addn_, - irpass.float_tuple_getitem_switch_, - irpass.float_env_getitem_switch_, - irpass.incorporate_getitem_set_, - irpass.incorporate_call_, - irpass.incorporate_call_switch_, - irpass.incorporate_env_getitem_bypass_recursive_, - irpass.incorporate_env_getitem_switch_, - irpass.new_env_get_item_, - irpass.depend_value_elim_, - irpass.all_reduce_const_elim_, - }); + opt::OptPassConfig a_2 = opt::OptPassConfig( + { + irpass.merge_addn_, + irpass.float_tuple_getitem_switch_, + irpass.float_env_getitem_switch_, + irpass.incorporate_getitem_set_, + irpass.incorporate_call_, + irpass.incorporate_call_switch_, + irpass.incorporate_env_getitem_bypass_recursive_, + irpass.incorporate_env_getitem_switch_, + irpass.new_env_get_item_, + irpass.depend_value_elim_, + irpass.all_reduce_const_elim_, + }, + false, true); opt::OptPassConfig a_after_grad = opt::OptPassConfig({ irpass.inline_without_move_, }); @@ -229,7 +231,8 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { irpass.reset_defer_inline_, irpass.inline_, irpass.updatestate_eliminater_, irpass.load_eliminater_, irpass.stopgrad_eliminater_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, - irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_}); + irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_}, + false, true); opt::OptPassConfig b_2 = opt::OptPassConfig({ irpass.replace_refkey_by_param_, irpass.make_ref_eliminate_,