| @@ -184,9 +184,6 @@ bool SubstitutionList::ApplyIRToSubstitutions(const OptimizerPtr &optimizer, con | |||||
| for (auto &substitution : list_) { | for (auto &substitution : list_) { | ||||
| auto res = DoTransform(optimizer, node, substitution); | auto res = DoTransform(optimizer, node, substitution); | ||||
| if (res != nullptr) { | if (res != nullptr) { | ||||
| if (is_once_) { | |||||
| return true; | |||||
| } | |||||
| change = true; | change = true; | ||||
| changes = true; | changes = true; | ||||
| node = res; | node = res; | ||||
| @@ -228,9 +225,6 @@ bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, cons | |||||
| bool change = false; | bool change = false; | ||||
| auto res = DoTransform(optimizer, node, substitution); | auto res = DoTransform(optimizer, node, substitution); | ||||
| if (res != nullptr) { | if (res != nullptr) { | ||||
| if (is_once_) { | |||||
| return true; | |||||
| } | |||||
| change = true; | change = true; | ||||
| changes = true; | changes = true; | ||||
| node = res; | node = res; | ||||
| @@ -316,9 +310,13 @@ bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const Optimize | |||||
| : kOptTraverseFromSubstitutionsToIR); | : kOptTraverseFromSubstitutionsToIR); | ||||
| if (traverse_mode == kOptTraverseFromIRToSubstitutions && | if (traverse_mode == kOptTraverseFromIRToSubstitutions && | ||||
| MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && | MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && | ||||
| optimizer->traverse_nodes_first()) { | |||||
| optimizer->traverse_nodes_first() && !is_once_ && !global_sensitive_) { | |||||
| MS_LOG(DEBUG) << "IR >> SUB, " << optimizer->name() << "(r" << optimizer->CurPass_.counter << ")_" | |||||
| << optimizer->CurPass_.name; | |||||
| changes = ApplyIRToSubstitutions(optimizer, func_graph); | changes = ApplyIRToSubstitutions(optimizer, func_graph); | ||||
| } else { | } else { | ||||
| MS_LOG(DEBUG) << "SUB >> IR, " << optimizer->name() << "(r" << optimizer->CurPass_.counter << ")_" | |||||
| << optimizer->CurPass_.name; | |||||
| changes = ApplySubstitutionsToIR(optimizer, func_graph); | changes = ApplySubstitutionsToIR(optimizer, func_graph); | ||||
| } | } | ||||
| return changes; | return changes; | ||||
| @@ -63,8 +63,9 @@ enum OptTraverseSubstitutionsMode { kOptTraverseFromIRToSubstitutions = 0, kOptT | |||||
| class SubstitutionList { | class SubstitutionList { | ||||
| public: | public: | ||||
| explicit SubstitutionList(const std::vector<SubstitutionPtr> &patterns, bool is_once = false) | |||||
| : list_(patterns), is_once_(is_once) {} | |||||
| explicit SubstitutionList(const std::vector<SubstitutionPtr> &patterns, bool is_once = false, | |||||
| bool global_sensitive = false) | |||||
| : list_(patterns), is_once_(is_once), global_sensitive_(global_sensitive) {} | |||||
| ~SubstitutionList() = default; | ~SubstitutionList() = default; | ||||
| bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const; | bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const; | ||||
| @@ -77,6 +78,7 @@ class SubstitutionList { | |||||
| std::vector<SubstitutionPtr> list_; | std::vector<SubstitutionPtr> list_; | ||||
| // a flag to mark this list of Substitution can only be executed only once | // a flag to mark this list of Substitution can only be executed only once | ||||
| bool is_once_; | bool is_once_; | ||||
| bool global_sensitive_; | |||||
| }; | }; | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -43,10 +43,10 @@ using OptimizeGraphFunc = std::function<bool(const FuncGraphPtr &func_graph, con | |||||
| class OptPassConfig { | class OptPassConfig { | ||||
| public: | public: | ||||
| explicit OptPassConfig(const OptimizeGraphFunc &func) : func_(func) {} | explicit OptPassConfig(const OptimizeGraphFunc &func) : func_(func) {} | ||||
| explicit OptPassConfig(const std::vector<SubstitutionPtr> &list, bool is_once = false) | |||||
| : list_(list), is_once_(is_once) {} | |||||
| OptPassConfig(const std::initializer_list<SubstitutionPtr> &list, bool is_once = false) | |||||
| : list_(list), is_once_(is_once) {} | |||||
| explicit OptPassConfig(const std::vector<SubstitutionPtr> &list, bool is_once = false, bool global_sensitive = false) | |||||
| : list_(list), is_once_(is_once), global_sensitive_(global_sensitive) {} | |||||
| OptPassConfig(const std::initializer_list<SubstitutionPtr> &list, bool is_once = false, bool global_sensitive = false) | |||||
| : list_(list), is_once_(is_once), global_sensitive_(global_sensitive) {} | |||||
| ~OptPassConfig() = default; | ~OptPassConfig() = default; | ||||
| const std::vector<SubstitutionPtr> &list() const { return list_; } | const std::vector<SubstitutionPtr> &list() const { return list_; } | ||||
| @@ -57,6 +57,8 @@ class OptPassConfig { | |||||
| const bool is_once() const { return is_once_; } | const bool is_once() const { return is_once_; } | ||||
| const bool global_sensitive() const { return global_sensitive_; } | |||||
| private: | private: | ||||
| OptPassConfig() : is_renormalize_(true) {} | OptPassConfig() : is_renormalize_(true) {} | ||||
| @@ -64,6 +66,7 @@ class OptPassConfig { | |||||
| std::vector<SubstitutionPtr> list_; | std::vector<SubstitutionPtr> list_; | ||||
| bool is_renormalize_{false}; | bool is_renormalize_{false}; | ||||
| bool is_once_{false}; | bool is_once_{false}; | ||||
| bool global_sensitive_{false}; | |||||
| }; | }; | ||||
| class OptPass { | class OptPass { | ||||
| @@ -115,7 +118,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||||
| } | } | ||||
| if (config.list().size() > 0) { | if (config.list().size() > 0) { | ||||
| OptimizeGraphFunc func = SubstitutionList(config.list(), config.is_once()); | |||||
| OptimizeGraphFunc func = SubstitutionList(config.list(), config.is_once(), config.global_sensitive()); | |||||
| passes_.push_back(OptPass(func)); | passes_.push_back(OptPass(func)); | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -136,19 +136,21 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||||
| irpass.stopgrad_eliminater_, | irpass.stopgrad_eliminater_, | ||||
| irpass.sparse_tensor_eliminate_, | irpass.sparse_tensor_eliminate_, | ||||
| }); | }); | ||||
| opt::OptPassConfig a_2 = opt::OptPassConfig({ | |||||
| irpass.merge_addn_, | |||||
| irpass.float_tuple_getitem_switch_, | |||||
| irpass.float_env_getitem_switch_, | |||||
| irpass.incorporate_getitem_set_, | |||||
| irpass.incorporate_call_, | |||||
| irpass.incorporate_call_switch_, | |||||
| irpass.incorporate_env_getitem_bypass_recursive_, | |||||
| irpass.incorporate_env_getitem_switch_, | |||||
| irpass.new_env_get_item_, | |||||
| irpass.depend_value_elim_, | |||||
| irpass.all_reduce_const_elim_, | |||||
| }); | |||||
| opt::OptPassConfig a_2 = opt::OptPassConfig( | |||||
| { | |||||
| irpass.merge_addn_, | |||||
| irpass.float_tuple_getitem_switch_, | |||||
| irpass.float_env_getitem_switch_, | |||||
| irpass.incorporate_getitem_set_, | |||||
| irpass.incorporate_call_, | |||||
| irpass.incorporate_call_switch_, | |||||
| irpass.incorporate_env_getitem_bypass_recursive_, | |||||
| irpass.incorporate_env_getitem_switch_, | |||||
| irpass.new_env_get_item_, | |||||
| irpass.depend_value_elim_, | |||||
| irpass.all_reduce_const_elim_, | |||||
| }, | |||||
| false, true); | |||||
| opt::OptPassConfig a_after_grad = opt::OptPassConfig({ | opt::OptPassConfig a_after_grad = opt::OptPassConfig({ | ||||
| irpass.inline_without_move_, | irpass.inline_without_move_, | ||||
| }); | }); | ||||
| @@ -229,7 +231,8 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { | |||||
| irpass.reset_defer_inline_, irpass.inline_, irpass.updatestate_eliminater_, irpass.load_eliminater_, | irpass.reset_defer_inline_, irpass.inline_, irpass.updatestate_eliminater_, irpass.load_eliminater_, | ||||
| irpass.stopgrad_eliminater_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, | irpass.stopgrad_eliminater_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, | ||||
| irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, | irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, | ||||
| irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_}); | |||||
| irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_}, | |||||
| false, true); | |||||
| opt::OptPassConfig b_2 = opt::OptPassConfig({ | opt::OptPassConfig b_2 = opt::OptPassConfig({ | ||||
| irpass.replace_refkey_by_param_, | irpass.replace_refkey_by_param_, | ||||
| irpass.make_ref_eliminate_, | irpass.make_ref_eliminate_, | ||||