Browse Source

Make a_2 and b_1 pass as global sensitive, and traverse from SUB to IR.

tags/v1.2.0-rc1
Zhang Qinghua 4 years ago
parent
commit
f6a2ff3384
4 changed files with 34 additions and 28 deletions
  1. +5
    -7
      mindspore/ccsrc/frontend/optimizer/opt.cc
  2. +4
    -2
      mindspore/ccsrc/frontend/optimizer/opt.h
  3. +8
    -5
      mindspore/ccsrc/frontend/optimizer/optimizer.h
  4. +17
    -14
      mindspore/ccsrc/pipeline/jit/pass.cc

+ 5
- 7
mindspore/ccsrc/frontend/optimizer/opt.cc View File

@@ -184,9 +184,6 @@ bool SubstitutionList::ApplyIRToSubstitutions(const OptimizerPtr &optimizer, con
for (auto &substitution : list_) { for (auto &substitution : list_) {
auto res = DoTransform(optimizer, node, substitution); auto res = DoTransform(optimizer, node, substitution);
if (res != nullptr) { if (res != nullptr) {
if (is_once_) {
return true;
}
change = true; change = true;
changes = true; changes = true;
node = res; node = res;
@@ -228,9 +225,6 @@ bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, cons
bool change = false; bool change = false;
auto res = DoTransform(optimizer, node, substitution); auto res = DoTransform(optimizer, node, substitution);
if (res != nullptr) { if (res != nullptr) {
if (is_once_) {
return true;
}
change = true; change = true;
changes = true; changes = true;
node = res; node = res;
@@ -316,9 +310,13 @@ bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const Optimize
: kOptTraverseFromSubstitutionsToIR); : kOptTraverseFromSubstitutionsToIR);
if (traverse_mode == kOptTraverseFromIRToSubstitutions && if (traverse_mode == kOptTraverseFromIRToSubstitutions &&
MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
optimizer->traverse_nodes_first()) {
optimizer->traverse_nodes_first() && !is_once_ && !global_sensitive_) {
MS_LOG(DEBUG) << "IR >> SUB, " << optimizer->name() << "(r" << optimizer->CurPass_.counter << ")_"
<< optimizer->CurPass_.name;
changes = ApplyIRToSubstitutions(optimizer, func_graph); changes = ApplyIRToSubstitutions(optimizer, func_graph);
} else { } else {
MS_LOG(DEBUG) << "SUB >> IR, " << optimizer->name() << "(r" << optimizer->CurPass_.counter << ")_"
<< optimizer->CurPass_.name;
changes = ApplySubstitutionsToIR(optimizer, func_graph); changes = ApplySubstitutionsToIR(optimizer, func_graph);
} }
return changes; return changes;


+ 4
- 2
mindspore/ccsrc/frontend/optimizer/opt.h View File

@@ -63,8 +63,9 @@ enum OptTraverseSubstitutionsMode { kOptTraverseFromIRToSubstitutions = 0, kOptT


class SubstitutionList { class SubstitutionList {
public: public:
explicit SubstitutionList(const std::vector<SubstitutionPtr> &patterns, bool is_once = false)
: list_(patterns), is_once_(is_once) {}
explicit SubstitutionList(const std::vector<SubstitutionPtr> &patterns, bool is_once = false,
bool global_sensitive = false)
: list_(patterns), is_once_(is_once), global_sensitive_(global_sensitive) {}
~SubstitutionList() = default; ~SubstitutionList() = default;


bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const; bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const;
@@ -77,6 +78,7 @@ class SubstitutionList {
std::vector<SubstitutionPtr> list_; std::vector<SubstitutionPtr> list_;
// a flag to mark this list of Substitution can only be executed only once // a flag to mark this list of Substitution can only be executed only once
bool is_once_; bool is_once_;
bool global_sensitive_;
}; };
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore


+ 8
- 5
mindspore/ccsrc/frontend/optimizer/optimizer.h View File

@@ -43,10 +43,10 @@ using OptimizeGraphFunc = std::function<bool(const FuncGraphPtr &func_graph, con
class OptPassConfig { class OptPassConfig {
public: public:
explicit OptPassConfig(const OptimizeGraphFunc &func) : func_(func) {} explicit OptPassConfig(const OptimizeGraphFunc &func) : func_(func) {}
explicit OptPassConfig(const std::vector<SubstitutionPtr> &list, bool is_once = false)
: list_(list), is_once_(is_once) {}
OptPassConfig(const std::initializer_list<SubstitutionPtr> &list, bool is_once = false)
: list_(list), is_once_(is_once) {}
explicit OptPassConfig(const std::vector<SubstitutionPtr> &list, bool is_once = false, bool global_sensitive = false)
: list_(list), is_once_(is_once), global_sensitive_(global_sensitive) {}
OptPassConfig(const std::initializer_list<SubstitutionPtr> &list, bool is_once = false, bool global_sensitive = false)
: list_(list), is_once_(is_once), global_sensitive_(global_sensitive) {}
~OptPassConfig() = default; ~OptPassConfig() = default;


const std::vector<SubstitutionPtr> &list() const { return list_; } const std::vector<SubstitutionPtr> &list() const { return list_; }
@@ -57,6 +57,8 @@ class OptPassConfig {


const bool is_once() const { return is_once_; } const bool is_once() const { return is_once_; }


const bool global_sensitive() const { return global_sensitive_; }

private: private:
OptPassConfig() : is_renormalize_(true) {} OptPassConfig() : is_renormalize_(true) {}


@@ -64,6 +66,7 @@ class OptPassConfig {
std::vector<SubstitutionPtr> list_; std::vector<SubstitutionPtr> list_;
bool is_renormalize_{false}; bool is_renormalize_{false};
bool is_once_{false}; bool is_once_{false};
bool global_sensitive_{false};
}; };


class OptPass { class OptPass {
@@ -115,7 +118,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
} }


if (config.list().size() > 0) { if (config.list().size() > 0) {
OptimizeGraphFunc func = SubstitutionList(config.list(), config.is_once());
OptimizeGraphFunc func = SubstitutionList(config.list(), config.is_once(), config.global_sensitive());
passes_.push_back(OptPass(func)); passes_.push_back(OptPass(func));
continue; continue;
} }


+ 17
- 14
mindspore/ccsrc/pipeline/jit/pass.cc View File

@@ -136,19 +136,21 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.stopgrad_eliminater_, irpass.stopgrad_eliminater_,
irpass.sparse_tensor_eliminate_, irpass.sparse_tensor_eliminate_,
}); });
opt::OptPassConfig a_2 = opt::OptPassConfig({
irpass.merge_addn_,
irpass.float_tuple_getitem_switch_,
irpass.float_env_getitem_switch_,
irpass.incorporate_getitem_set_,
irpass.incorporate_call_,
irpass.incorporate_call_switch_,
irpass.incorporate_env_getitem_bypass_recursive_,
irpass.incorporate_env_getitem_switch_,
irpass.new_env_get_item_,
irpass.depend_value_elim_,
irpass.all_reduce_const_elim_,
});
opt::OptPassConfig a_2 = opt::OptPassConfig(
{
irpass.merge_addn_,
irpass.float_tuple_getitem_switch_,
irpass.float_env_getitem_switch_,
irpass.incorporate_getitem_set_,
irpass.incorporate_call_,
irpass.incorporate_call_switch_,
irpass.incorporate_env_getitem_bypass_recursive_,
irpass.incorporate_env_getitem_switch_,
irpass.new_env_get_item_,
irpass.depend_value_elim_,
irpass.all_reduce_const_elim_,
},
false, true);
opt::OptPassConfig a_after_grad = opt::OptPassConfig({ opt::OptPassConfig a_after_grad = opt::OptPassConfig({
irpass.inline_without_move_, irpass.inline_without_move_,
}); });
@@ -229,7 +231,8 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.reset_defer_inline_, irpass.inline_, irpass.updatestate_eliminater_, irpass.load_eliminater_, irpass.reset_defer_inline_, irpass.inline_, irpass.updatestate_eliminater_, irpass.load_eliminater_,
irpass.stopgrad_eliminater_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, irpass.stopgrad_eliminater_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_,
irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_,
irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_});
irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_},
false, true);
opt::OptPassConfig b_2 = opt::OptPassConfig({ opt::OptPassConfig b_2 = opt::OptPassConfig({
irpass.replace_refkey_by_param_, irpass.replace_refkey_by_param_,
irpass.make_ref_eliminate_, irpass.make_ref_eliminate_,


Loading…
Cancel
Save