| @@ -51,8 +51,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul}); | |||
| special_op_eliminate_ = | |||
| MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", | |||
| {prim::kPrimInsertGradientOf, prim::kPrimHookBackward, prim::kPrimPrintShapeType, | |||
| prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv}); | |||
| {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward, | |||
| prim::kPrimPrintShapeType, prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv}); | |||
| zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLike); | |||
| adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN); | |||
| @@ -72,9 +72,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode<FuncGraph>); | |||
| // Env Item Eliminate | |||
| env_get_item_eliminate_ = MakeSubstitution(EnvGetItemEliminater(), "env_get_item_eliminate", prim::kPrimEnvGetItem); | |||
| new_env_get_item_ = MakeSubstitution(NewEnvGetItem(), "new_env_get_item", prim::kPrimEnvGetItem); | |||
| add_env_get_item_ = MakeSubstitution(AddEnvGetItem(), "add_env_get_item", prim::kPrimEnvGetItem); | |||
| env_get_set_item_ = MakeSubstitution(EnvGetSetItem(), "env_get_set_item", prim::kPrimEnvGetItem); | |||
| incorporate_env_getitem_ = | |||
| MakeSubstitution(IncorporateEnvGetitem(), "incorporate_env_get_item", prim::kPrimEnvGetItem); | |||
| incorporate_env_getitem_switch_ = | |||
| @@ -91,8 +90,6 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| // Gradient transforms | |||
| expand_jprim_ = MakeSubstitution(ExpandJPrim(), "expand_jprim", prim::kPrimJ); | |||
| stop_gradient_eliminate_ = | |||
| MakeSubstitution(StopGradientEliminater(), "stop_gradient_eliminate", prim::kPrimStopGradient); | |||
| minmaximum_grad_ = MakeSubstitution(MinMaximumGrad(), "minmaximum_grad", prim::kPrimTupleGetItem); | |||
| // branch culling | |||
| @@ -113,9 +110,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| specialize_transform_ = MakeSubstitution(SpecializeOnGraphArguments(), "specialize_transform", IsCNodeGraph); | |||
| // Incorporation | |||
| incorporate_getitem_ = MakeSubstitution(IncorporateGetitem(), "incorporate_getitem", prim::kPrimTupleGetItem); | |||
| incorporate_getitem_switch_ = | |||
| MakeSubstitution(IncorporateGetitemSwitch(), "incorporate_getitem_switch", prim::kPrimTupleGetItem); | |||
| incorporate_getitem_set_ = | |||
| MakeSubstitution(IncorporateGetitemSet(), "incorporate_getitem_set", prim::kPrimTupleGetItem); | |||
| incorporate_call_ = MakeSubstitution(IncorporateCall(), "incorporate_call", IsCNodeDup); | |||
| incorporate_call_switch_ = MakeSubstitution(IncorporateCallSwitch(), "incorporate_call_switch", IsCNodeDup); | |||
| @@ -50,9 +50,8 @@ class OptimizeIRPassLib { | |||
| SubstitutionPtr reset_defer_inline_; | |||
| // Env Item Eliminate | |||
| SubstitutionPtr env_get_item_eliminate_; | |||
| SubstitutionPtr new_env_get_item_; | |||
| SubstitutionPtr add_env_get_item_; | |||
| SubstitutionPtr env_get_set_item_; | |||
| SubstitutionPtr incorporate_env_getitem_; | |||
| SubstitutionPtr incorporate_env_getitem_switch_; | |||
| @@ -74,7 +73,6 @@ class OptimizeIRPassLib { | |||
| // Gradient irpasses | |||
| SubstitutionPtr expand_jprim_; | |||
| SubstitutionPtr stop_gradient_eliminate_; | |||
| SubstitutionPtr minmaximum_grad_; | |||
| // inline | |||
| @@ -83,8 +81,7 @@ class OptimizeIRPassLib { | |||
| SubstitutionPtr specialize_transform_; | |||
| // Incorporation | |||
| SubstitutionPtr incorporate_getitem_; | |||
| SubstitutionPtr incorporate_getitem_switch_; | |||
| SubstitutionPtr incorporate_getitem_set_; | |||
| SubstitutionPtr incorporate_call_; | |||
| SubstitutionPtr incorporate_call_switch_; | |||
| @@ -115,51 +112,30 @@ class InferenceOptPrepareLib { | |||
| // predicate functions | |||
| inline bool IsNode(const AnfNodePtr &) { return true; } | |||
| inline bool IsCNode(const AnfNodePtr &node) { | |||
| if (node != nullptr) { | |||
| return node->isa<CNode>(); | |||
| } | |||
| return false; | |||
| } | |||
| inline bool IsCNode(const AnfNodePtr &node) { return node->isa<CNode>(); } | |||
| inline bool IsVNode(const AnfNodePtr &node) { | |||
| if (node != nullptr) { | |||
| return node->isa<ValueNode>(); | |||
| } | |||
| return false; | |||
| } | |||
| inline bool IsVNode(const AnfNodePtr &node) { return node->isa<ValueNode>(); } | |||
| inline bool IsParam(const AnfNodePtr &node) { | |||
| if (node != nullptr) { | |||
| return node->isa<Parameter>(); | |||
| } | |||
| return false; | |||
| } | |||
| inline bool IsParam(const AnfNodePtr &node) { return node->isa<Parameter>(); } | |||
| // Check if CNode Input 0 is Func Graph | |||
| inline bool IsCNodeGraph(const AnfNodePtr &node) { | |||
| if (node == nullptr || !node->isa<CNode>()) { | |||
| if (!node->isa<CNode>()) { | |||
| return false; | |||
| } | |||
| auto inp0 = node->cast<CNodePtr>()->input(0); | |||
| if (IsValueNode<FuncGraph>(inp0)) { | |||
| return true; | |||
| } | |||
| return false; | |||
| return IsValueNode<FuncGraph>(inp0); | |||
| } | |||
| // Check if CNode Input 0 is CNode | |||
| inline bool IsCNodeDup(const AnfNodePtr &node) { | |||
| if (node == nullptr || !node->isa<CNode>()) { | |||
| if (!node->isa<CNode>()) { | |||
| return false; | |||
| } | |||
| auto inp0 = node->cast<CNodePtr>()->input(0); | |||
| if (inp0 != nullptr && inp0->isa<CNode>()) { | |||
| return true; | |||
| } | |||
| return false; | |||
| return (inp0 != nullptr) && inp0->isa<CNode>(); | |||
| } | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| @@ -225,6 +225,33 @@ class EnvGetSetItem : public AnfVisitor { | |||
| bool is_match_{false}; | |||
| }; | |||
| class EnvGetItemEliminater { | |||
| public: | |||
| EnvGetItemEliminater() : new_env_get_item_(), add_env_get_item_(), env_get_set_item_() { | |||
| eliminaters_.emplace_back(new_env_get_item_); | |||
| eliminaters_.emplace_back(add_env_get_item_); | |||
| eliminaters_.emplace_back(env_get_set_item_); | |||
| } | |||
| ~EnvGetItemEliminater() = default; | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||
| AnfNodePtr new_node; | |||
| for (auto &eliminater : eliminaters_) { | |||
| new_node = eliminater(optimizer, node); | |||
| if (new_node != nullptr) { | |||
| return new_node; | |||
| } | |||
| } | |||
| return nullptr; | |||
| } | |||
| private: | |||
| NewEnvGetItem new_env_get_item_; | |||
| AddEnvGetItem add_env_get_item_; | |||
| EnvGetSetItem env_get_set_item_; | |||
| std::vector<TransformFuncType> eliminaters_{}; | |||
| }; | |||
| // {prim::kPrimEnvGetItem, {G, Xs}, C, Y} | |||
| class IncorporateEnvGetitem : public AnfVisitor { | |||
| public: | |||
| @@ -55,21 +55,6 @@ class ExpandJPrim : public AnfVisitor { | |||
| private: | |||
| ValueNodePtr x_{nullptr}; | |||
| }; | |||
| // stop_gradient(x) ==> x | |||
| class StopGradientEliminater : public AnfVisitor { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| x_ = nullptr; | |||
| AnfVisitor::Match(prim::kPrimStopGradient)(node); | |||
| return x_; | |||
| } | |||
| void Visit(const AnfNodePtr &node) override { x_ = node; } | |||
| private: | |||
| AnfNodePtr x_{nullptr}; | |||
| }; | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -197,6 +197,31 @@ class IncorporateGetitemSwitch : public AnfVisitor { | |||
| std::vector<AnfNodePtr> args_{}; | |||
| internal::GetitemTransform getitem_transform_; | |||
| }; | |||
| class IncorporateGetitemSet { | |||
| public: | |||
| IncorporateGetitemSet() : incorporate_getitem_(), incorporate_getitem_switch_() { | |||
| eliminaters_.emplace_back(incorporate_getitem_); | |||
| eliminaters_.emplace_back(incorporate_getitem_switch_); | |||
| } | |||
| ~IncorporateGetitemSet() = default; | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||
| AnfNodePtr new_node; | |||
| for (auto &eliminater : eliminaters_) { | |||
| new_node = eliminater(optimizer, node); | |||
| if (new_node != nullptr) { | |||
| return new_node; | |||
| } | |||
| } | |||
| return nullptr; | |||
| } | |||
| private: | |||
| IncorporateGetitem incorporate_getitem_; | |||
| IncorporateGetitemSwitch incorporate_getitem_switch_; | |||
| std::vector<TransformFuncType> eliminaters_{}; | |||
| }; | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -35,12 +35,14 @@ class SpecialOpEliminater { | |||
| public: | |||
| SpecialOpEliminater() | |||
| : insert_gradient_of_(prim::kPrimInsertGradientOf), | |||
| stop_gradient_(prim::kPrimStopGradient), | |||
| hook_backward_(prim::kPrimHookBackward), | |||
| print_shape_type_(prim::kPrimPrintShapeType), | |||
| get_ref_value_(prim::kPrimGetRefValue), | |||
| mirror_(prim::kPrimMirror), | |||
| virtual_div_(prim::kPrimVirtualDiv) { | |||
| eliminaters_.emplace_back(insert_gradient_of_); | |||
| eliminaters_.emplace_back(stop_gradient_); | |||
| eliminaters_.emplace_back(hook_backward_); | |||
| eliminaters_.emplace_back(print_shape_type_); | |||
| eliminaters_.emplace_back(get_ref_value_); | |||
| @@ -61,7 +63,8 @@ class SpecialOpEliminater { | |||
| } | |||
| private: | |||
| PrimEliminater insert_gradient_of_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, virtual_div_; | |||
| PrimEliminater insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, | |||
| virtual_div_; | |||
| std::vector<TransformFuncType> eliminaters_{}; | |||
| }; | |||
| @@ -44,8 +44,17 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std:: | |||
| return false; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto inp0 = cnode->input(0); | |||
| auto prim0 = GetValueNode<PrimitivePtr>(inp0); | |||
| if (prim0 == nullptr) { | |||
| return false; | |||
| } | |||
| auto hash = prim0->Hash(); | |||
| auto const &name = prim0->name(); | |||
| for (auto &prim : prims) { | |||
| if (IsPrimitiveCNode(node, prim)) { | |||
| if (hash == prim->Hash() && name == prim->name()) { | |||
| return true; | |||
| } | |||
| } | |||
| @@ -171,7 +180,7 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo | |||
| } | |||
| #ifdef ENABLE_PROFILE | |||
| MsProfile::StatTime("opt.transform", GetTime() - start); | |||
| MsProfile::StatTime("opt.transform." + optimizer->name(), GetTime() - start); | |||
| #endif | |||
| return changes; | |||
| } | |||
| @@ -79,16 +79,9 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| // Specialization | |||
| irpass.specialize_transform_, | |||
| // Arithmetic simplifications | |||
| irpass.arithmetic_simplify_, | |||
| irpass.addn_zero_filter_, | |||
| irpass.adjust_all_reduce_mul_add_, | |||
| // Miscellaneous | |||
| irpass.item_tuple_eliminate_, | |||
| irpass.env_get_set_item_, | |||
| irpass.new_env_get_item_, | |||
| irpass.add_env_get_item_, | |||
| irpass.env_get_item_eliminate_, | |||
| irpass.cast_eliminate_, | |||
| irpass.reshape_eliminate_, | |||
| irpass.reduce_eliminate_, | |||
| @@ -96,13 +89,20 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| irpass.transpose_eliminate_, | |||
| irpass.minmaximum_grad_, | |||
| irpass.get_make_ref_eliminate_, | |||
| // Arithmetic simplifications | |||
| irpass.arithmetic_simplify_, | |||
| irpass.addn_zero_filter_, | |||
| irpass.adjust_all_reduce_mul_add_, | |||
| // Safe inlining | |||
| irpass.inline_, | |||
| }); | |||
| opt::OptPassConfig a_2 = opt::OptPassConfig({ | |||
| irpass.merge_addn_, | |||
| irpass.float_tuple_getitem_switch_, | |||
| irpass.float_env_getitem_switch_, | |||
| irpass.incorporate_getitem_, | |||
| irpass.incorporate_getitem_switch_, | |||
| irpass.incorporate_getitem_set_, | |||
| irpass.incorporate_call_, | |||
| irpass.incorporate_call_switch_, | |||
| irpass.incorporate_env_getitem_, | |||
| @@ -145,7 +145,6 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| irpass.reset_defer_inline_, | |||
| irpass.inline_, | |||
| irpass.special_op_eliminate_, | |||
| irpass.stop_gradient_eliminate_, | |||
| irpass.get_make_ref_eliminate_, | |||
| }); | |||
| opt::OptPassConfig b_2 = opt::OptPassConfig({ | |||
| @@ -401,7 +401,7 @@ TEST_F(TestOptLib, test_incorporate_getitem) { | |||
| FuncGraphPtr after1 = getPyFun.CallAndParseRet("test_incorporate_getitem", "after1"); | |||
| FuncGraphPtr after2 = getPyFun.CallAndParseRet("test_incorporate_getitem", "after2"); | |||
| auto patterns = std::vector<SubstitutionPtr>({irpass.incorporate_getitem_}); | |||
| auto patterns = std::vector<SubstitutionPtr>({irpass.incorporate_getitem_set_}); | |||
| ASSERT_TRUE(CheckOpt(before1, after1, patterns)); | |||
| ASSERT_TRUE(CheckOpt(before2, after2, patterns)); | |||
| @@ -411,7 +411,7 @@ TEST_F(TestOptLib, test_incorporate_getitem_through_switch) { | |||
| FuncGraphPtr before = getPyFun.CallAndParseRet("test_incorporate_getitem_through_switch", "before"); | |||
| FuncGraphPtr after = getPyFun.CallAndParseRet("test_incorporate_getitem_through_switch", "after"); | |||
| auto patterns = std::vector<SubstitutionPtr>({irpass.incorporate_getitem_switch_}); | |||
| auto patterns = std::vector<SubstitutionPtr>({irpass.incorporate_getitem_set_}); | |||
| ASSERT_TRUE(CheckOpt(before, after, patterns)); | |||
| } | |||