| @@ -184,9 +184,6 @@ bool SubstitutionList::ApplyIRToSubstitutions(const OptimizerPtr &optimizer, con | |||
| for (auto &substitution : list_) { | |||
| auto res = DoTransform(optimizer, node, substitution); | |||
| if (res != nullptr) { | |||
| if (is_once_) { | |||
| return true; | |||
| } | |||
| change = true; | |||
| changes = true; | |||
| node = res; | |||
| @@ -228,9 +225,6 @@ bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, cons | |||
| bool change = false; | |||
| auto res = DoTransform(optimizer, node, substitution); | |||
| if (res != nullptr) { | |||
| if (is_once_) { | |||
| return true; | |||
| } | |||
| change = true; | |||
| changes = true; | |||
| node = res; | |||
| @@ -316,9 +310,13 @@ bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const Optimize | |||
| : kOptTraverseFromSubstitutionsToIR); | |||
| if (traverse_mode == kOptTraverseFromIRToSubstitutions && | |||
| MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && | |||
| optimizer->traverse_nodes_first()) { | |||
| optimizer->traverse_nodes_first() && !is_once_ && !global_sensitive_) { | |||
| MS_LOG(DEBUG) << "IR >> SUB, " << optimizer->name() << "(r" << optimizer->CurPass_.counter << ")_" | |||
| << optimizer->CurPass_.name; | |||
| changes = ApplyIRToSubstitutions(optimizer, func_graph); | |||
| } else { | |||
| MS_LOG(DEBUG) << "SUB >> IR, " << optimizer->name() << "(r" << optimizer->CurPass_.counter << ")_" | |||
| << optimizer->CurPass_.name; | |||
| changes = ApplySubstitutionsToIR(optimizer, func_graph); | |||
| } | |||
| return changes; | |||
| @@ -63,8 +63,9 @@ enum OptTraverseSubstitutionsMode { kOptTraverseFromIRToSubstitutions = 0, kOptT | |||
| class SubstitutionList { | |||
| public: | |||
| explicit SubstitutionList(const std::vector<SubstitutionPtr> &patterns, bool is_once = false) | |||
| : list_(patterns), is_once_(is_once) {} | |||
| explicit SubstitutionList(const std::vector<SubstitutionPtr> &patterns, bool is_once = false, | |||
| bool global_sensitive = false) | |||
| : list_(patterns), is_once_(is_once), global_sensitive_(global_sensitive) {} | |||
| ~SubstitutionList() = default; | |||
| bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const; | |||
| @@ -77,6 +78,7 @@ class SubstitutionList { | |||
| std::vector<SubstitutionPtr> list_; | |||
| // a flag to mark this list of Substitution can only be executed only once | |||
| bool is_once_; | |||
| bool global_sensitive_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -43,10 +43,10 @@ using OptimizeGraphFunc = std::function<bool(const FuncGraphPtr &func_graph, con | |||
| class OptPassConfig { | |||
| public: | |||
| explicit OptPassConfig(const OptimizeGraphFunc &func) : func_(func) {} | |||
| explicit OptPassConfig(const std::vector<SubstitutionPtr> &list, bool is_once = false) | |||
| : list_(list), is_once_(is_once) {} | |||
| OptPassConfig(const std::initializer_list<SubstitutionPtr> &list, bool is_once = false) | |||
| : list_(list), is_once_(is_once) {} | |||
| explicit OptPassConfig(const std::vector<SubstitutionPtr> &list, bool is_once = false, bool global_sensitive = false) | |||
| : list_(list), is_once_(is_once), global_sensitive_(global_sensitive) {} | |||
| OptPassConfig(const std::initializer_list<SubstitutionPtr> &list, bool is_once = false, bool global_sensitive = false) | |||
| : list_(list), is_once_(is_once), global_sensitive_(global_sensitive) {} | |||
| ~OptPassConfig() = default; | |||
| const std::vector<SubstitutionPtr> &list() const { return list_; } | |||
| @@ -57,6 +57,8 @@ class OptPassConfig { | |||
| const bool is_once() const { return is_once_; } | |||
| const bool global_sensitive() const { return global_sensitive_; } | |||
| private: | |||
| OptPassConfig() : is_renormalize_(true) {} | |||
| @@ -64,6 +66,7 @@ class OptPassConfig { | |||
| std::vector<SubstitutionPtr> list_; | |||
| bool is_renormalize_{false}; | |||
| bool is_once_{false}; | |||
| bool global_sensitive_{false}; | |||
| }; | |||
| class OptPass { | |||
| @@ -115,7 +118,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||
| } | |||
| if (config.list().size() > 0) { | |||
| OptimizeGraphFunc func = SubstitutionList(config.list(), config.is_once()); | |||
| OptimizeGraphFunc func = SubstitutionList(config.list(), config.is_once(), config.global_sensitive()); | |||
| passes_.push_back(OptPass(func)); | |||
| continue; | |||
| } | |||
| @@ -136,19 +136,21 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| irpass.stopgrad_eliminater_, | |||
| irpass.sparse_tensor_eliminate_, | |||
| }); | |||
| opt::OptPassConfig a_2 = opt::OptPassConfig({ | |||
| irpass.merge_addn_, | |||
| irpass.float_tuple_getitem_switch_, | |||
| irpass.float_env_getitem_switch_, | |||
| irpass.incorporate_getitem_set_, | |||
| irpass.incorporate_call_, | |||
| irpass.incorporate_call_switch_, | |||
| irpass.incorporate_env_getitem_bypass_recursive_, | |||
| irpass.incorporate_env_getitem_switch_, | |||
| irpass.new_env_get_item_, | |||
| irpass.depend_value_elim_, | |||
| irpass.all_reduce_const_elim_, | |||
| }); | |||
| opt::OptPassConfig a_2 = opt::OptPassConfig( | |||
| { | |||
| irpass.merge_addn_, | |||
| irpass.float_tuple_getitem_switch_, | |||
| irpass.float_env_getitem_switch_, | |||
| irpass.incorporate_getitem_set_, | |||
| irpass.incorporate_call_, | |||
| irpass.incorporate_call_switch_, | |||
| irpass.incorporate_env_getitem_bypass_recursive_, | |||
| irpass.incorporate_env_getitem_switch_, | |||
| irpass.new_env_get_item_, | |||
| irpass.depend_value_elim_, | |||
| irpass.all_reduce_const_elim_, | |||
| }, | |||
| false, true); | |||
| opt::OptPassConfig a_after_grad = opt::OptPassConfig({ | |||
| irpass.inline_without_move_, | |||
| }); | |||
| @@ -229,7 +231,8 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| irpass.reset_defer_inline_, irpass.inline_, irpass.updatestate_eliminater_, irpass.load_eliminater_, | |||
| irpass.stopgrad_eliminater_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, | |||
| irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, | |||
| irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_}); | |||
| irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_}, | |||
| false, true); | |||
| opt::OptPassConfig b_2 = opt::OptPassConfig({ | |||
| irpass.replace_refkey_by_param_, | |||
| irpass.make_ref_eliminate_, | |||