diff --git a/mindspore/ccsrc/ir/optimizer_caller.h b/mindspore/ccsrc/ir/optimizer_caller.h index bd30454147..036f4ab510 100644 --- a/mindspore/ccsrc/ir/optimizer_caller.h +++ b/mindspore/ccsrc/ir/optimizer_caller.h @@ -17,13 +17,23 @@ #ifndef MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ #define MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ +#include + #include "ir/anf.h" -#include "optimizer/opt.h" namespace mindspore { +namespace opt { +class Optimizer; +using OptimizerPtr = std::shared_ptr; +using OptimizerWeakPtr = std::weak_ptr; + +using PredicateFuncType = std::function; +} // namespace opt + class OptimizerCaller { public: virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &) { return nullptr; } }; +using OptimizerCallerPtr = std::shared_ptr; } // namespace mindspore #endif // MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ diff --git a/mindspore/ccsrc/optimizer/irpass.cc b/mindspore/ccsrc/optimizer/irpass.cc index 0033e386d8..0996abee2c 100644 --- a/mindspore/ccsrc/optimizer/irpass.cc +++ b/mindspore/ccsrc/optimizer/irpass.cc @@ -14,140 +14,154 @@ * limitations under the License. */ -#include "optimizer/irpass.h" - #include -#include "optimizer/irpass/symbol_resolver.h" +#include "optimizer/irpass.h" #include "optimizer/irpass/arithmetic_simplify.h" -#include "optimizer/irpass/special_op_eliminate.h" -#include "optimizer/irpass/item_tuple_eliminate.h" -#include "optimizer/irpass/env_item_eliminate.h" -#include "optimizer/irpass/tile_eliminate.h" -#include "optimizer/irpass/cast_eliminate.h" -#include "optimizer/irpass/reshape_eliminate.h" -#include "optimizer/irpass/transpose_eliminate.h" -#include "optimizer/irpass/reduce_eliminate.h" -#include "optimizer/irpass/partial_eliminate.h" -#include "optimizer/irpass/ref_eliminate.h" -#include "optimizer/irpass/merge_addn.h" #include "optimizer/irpass/branch_culling.h" +#include "optimizer/irpass/cast_eliminate.h" +#include "optimizer/irpass/convert.h" +#include "optimizer/irpass/env_item_eliminate.h" +#include "optimizer/irpass/grad_var_prepare.h" #include "optimizer/irpass/gradient_eliminate.h" -#include "optimizer/irpass/minmax_grad.h" #include "optimizer/irpass/inline.h" -#include "optimizer/irpass/convert.h" -#include "optimizer/irpass/specialize_transform.h" -#include "optimizer/irpass/incorporate_getitem.h" #include "optimizer/irpass/incorporate_call.h" -#include "optimizer/irpass/grad_var_prepare.h" -#include "optimizer/irpass/param_replace.h" +#include "optimizer/irpass/incorporate_getitem.h" +#include "optimizer/irpass/item_tuple_eliminate.h" #include "optimizer/irpass/mark_interface_fusion.h" +#include "optimizer/irpass/merge_addn.h" +#include "optimizer/irpass/minmax_grad.h" +#include "optimizer/irpass/param_replace.h" +#include "optimizer/irpass/partial_eliminate.h" +#include "optimizer/irpass/reduce_eliminate.h" +#include "optimizer/irpass/ref_eliminate.h" +#include "optimizer/irpass/reshape_eliminate.h" +#include "optimizer/irpass/special_op_eliminate.h" +#include "optimizer/irpass/specialize_transform.h" +#include "optimizer/irpass/symbol_resolver.h" +#include "optimizer/irpass/tile_eliminate.h" +#include "optimizer/irpass/transpose_eliminate.h" #include "optimizer/opt.h" namespace mindspore { namespace opt { namespace irpass { OptimizeIRPassLib::OptimizeIRPassLib() { - arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify", + arithmetic_simplify_ = MakeSubstitution(std::make_shared(), "arithmetic_simplify", {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow}); - arithmetic_simplify2_ = MakeSubstitution(ArithmeticSimplify2(), "arithmetic_simplify2", {prim::kPrimMul}); + arithmetic_simplify2_ = + MakeSubstitution(std::make_shared(), "arithmetic_simplify2", {prim::kPrimMul}); special_op_eliminate_ = - MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", + MakeSubstitution(std::make_shared(), "special_op_eliminate", {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward, prim::kPrimPrintShapeType, prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv}); - zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLike); - adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN); + zero_like_fill_zero_ = + MakeSubstitution(std::make_shared(), "zero_like_fill_zero", prim::kPrimZerosLike); + adjust_all_reduce_mul_add_ = + MakeSubstitution(std::make_shared(), "adjust_all_reduce_mul_add", prim::kPrimAddN); // ops eliminate - item_tuple_eliminate_ = - MakeSubstitution(ItemTupleEliminater(), "item_tuple_eliminate", {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem}); - tile_eliminate_ = MakeSubstitution(TileMultiplyByOne(), "tile_eliminate", prim::kPrimTile); - cast_eliminate_ = MakeSubstitution(CastEliminater(), "cast_eliminate", prim::kPrimCast); - reshape_eliminate_ = MakeSubstitution(ReshapeEliminater(), "reshape_eliminate", prim::kPrimReshape); - transpose_eliminate_ = MakeSubstitution(TransposeSameIOEliminater(), "transpose_eliminate", prim::kPrimTranspose); + item_tuple_eliminate_ = MakeSubstitution(std::make_shared(), "item_tuple_eliminate", + {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem}); + tile_eliminate_ = MakeSubstitution(std::make_shared(), "tile_eliminate", prim::kPrimTile); + cast_eliminate_ = MakeSubstitution(std::make_shared(), "cast_eliminate", prim::kPrimCast); + reshape_eliminate_ = MakeSubstitution(std::make_shared(), "reshape_eliminate", prim::kPrimReshape); + transpose_eliminate_ = + MakeSubstitution(std::make_shared(), "transpose_eliminate", prim::kPrimTranspose); reduce_eliminate_ = MakeSubstitution( - ReduceOneEliminater(), "reduce_eliminate", + std::make_shared(), "reduce_eliminate", {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin}); - partial_eliminate_ = MakeSubstitution(PartialEliminater(), "partial_eliminate", IsCNodeDup); - same_eliminate_ = MakeSubstitution(SameEliminater(), "same_eliminate", prim::kPrimSameTypeShape); - check_bprop_eliminate_ = MakeSubstitution(CheckBpropEliminater(), "check_bprop_eliminate", prim::kPrimCheckBprop); - reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode); - depend_value_elim_ = MakeSubstitution(DependValueElim(), "depend_value_elim", prim::kPrimDepend); + partial_eliminate_ = MakeSubstitution(std::make_shared(), "partial_eliminate", IsCNodeDup); + same_eliminate_ = MakeSubstitution(std::make_shared(), "same_eliminate", prim::kPrimSameTypeShape); + check_bprop_eliminate_ = + MakeSubstitution(std::make_shared(), "check_bprop_eliminate", prim::kPrimCheckBprop); + reset_defer_inline_ = + MakeSubstitution(std::make_shared(), "reset_defer_inline", IsValueNode); + depend_value_elim_ = MakeSubstitution(std::make_shared(), "depend_value_elim", prim::kPrimDepend); // Env Item Eliminate - env_get_item_eliminate_ = MakeSubstitution(EnvGetItemEliminater(), "env_get_item_eliminate", prim::kPrimEnvGetItem); - new_env_get_item_ = MakeSubstitution(NewEnvGetItem(), "new_env_get_item", prim::kPrimEnvGetItem); + env_get_item_eliminate_ = + MakeSubstitution(std::make_shared(), "env_get_item_eliminate", prim::kPrimEnvGetItem); + new_env_get_item_ = MakeSubstitution(std::make_shared(), "new_env_get_item", prim::kPrimEnvGetItem); incorporate_env_getitem_ = - MakeSubstitution(IncorporateEnvGetitem(), "incorporate_env_get_item", prim::kPrimEnvGetItem); - incorporate_env_getitem_switch_ = - MakeSubstitution(IncorporateEnvGetitemSwitch(), "incorporate_env_getitem_switch", prim::kPrimEnvGetItem); + MakeSubstitution(std::make_shared(), "incorporate_env_get_item", prim::kPrimEnvGetItem); + incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared(), + "incorporate_env_getitem_switch", prim::kPrimEnvGetItem); // Ref eliminate - make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef); - get_ref_param_eliminate_ = MakeSubstitution(GetRefParamEliminater(), "get_ref_param_eliminate", + make_ref_eliminate_ = + MakeSubstitution(std::make_shared(), "make_ref_eliminate", prim::kPrimMakeRef); + get_ref_param_eliminate_ = MakeSubstitution(std::make_shared(), "get_ref_param_eliminate", {prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); - get_make_ref_eliminate_ = MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate", + get_make_ref_eliminate_ = MakeSubstitution(std::make_shared(), "get_make_ref_eliminate", {prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); - replace_refkey_by_param_ = - MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode, opt::FORCE_RENORM); - replace_old_param_ = MakeSubstitution(ReplaceOldParam(), "replace_old_param", IsParam); + replace_refkey_by_param_ = MakeSubstitution(std::make_shared(), "replace_refkey_by_param", + IsValueNode, opt::FORCE_RENORM); + replace_old_param_ = MakeSubstitution(std::make_shared(), "replace_old_param", IsParam); // Gradient transforms - expand_jprim_ = MakeSubstitution(ExpandJPrim(), "expand_jprim", prim::kPrimJ); - minmaximum_grad_ = MakeSubstitution(MinMaximumGrad(), "minmaximum_grad", prim::kPrimTupleGetItem); + expand_jprim_ = MakeSubstitution(std::make_shared(), "expand_jprim", prim::kPrimJ); + minmaximum_grad_ = MakeSubstitution(std::make_shared(), "minmaximum_grad", prim::kPrimTupleGetItem); // branch culling - switch_simplify_ = MakeSubstitution(SwitchSimplify(), "switch_simplify", prim::kPrimSwitch); - float_tuple_getitem_switch_ = - MakeSubstitution(FloatTupleGetItemSwitch(), "float_tuple_getitem_switch", prim::kPrimTupleGetItem); + switch_simplify_ = MakeSubstitution(std::make_shared(), "switch_simplify", prim::kPrimSwitch); + float_tuple_getitem_switch_ = MakeSubstitution(std::make_shared(), + "float_tuple_getitem_switch", prim::kPrimTupleGetItem); float_env_getitem_switch_ = - MakeSubstitution(FloatEnvGetItemSwitch(), "float_env_getitem_switch", prim::kPrimEnvGetItem); - convert_switch_replacement_ = MakeSubstitution(ConvertSwitchReplacement(), "convert_switch_replacement", IsCNodeDup); + MakeSubstitution(std::make_shared(), "float_env_getitem_switch", prim::kPrimEnvGetItem); + convert_switch_replacement_ = + MakeSubstitution(std::make_shared(), "convert_switch_replacement", IsCNodeDup); // Addn - merge_addn_ = MakeSubstitution(MergeAddN(), "merge_addn", prim::kPrimAddN); - addn_zero_filter_ = MakeSubstitution(AddNZeroFilter(), "addn_zero_filter", prim::kPrimAddN); + merge_addn_ = MakeSubstitution(std::make_shared(), "merge_addn", prim::kPrimAddN); + addn_zero_filter_ = MakeSubstitution(std::make_shared(), "addn_zero_filter", prim::kPrimAddN); // inline - inline_ = MakeSubstitution(Inliner(), "inline", IsCNodeGraph); - replace_applicator_ = MakeSubstitution(ReplaceApplicator(), "replace_applicator", IsValueNode); - specialize_transform_ = MakeSubstitution(SpecializeOnGraphArguments(), "specialize_transform", IsCNodeGraph); + inline_ = MakeSubstitution(std::make_shared(), "inline", IsCNodeGraph); + replace_applicator_ = + MakeSubstitution(std::make_shared(), "replace_applicator", IsValueNode); + specialize_transform_ = + MakeSubstitution(std::make_shared(), "specialize_transform", IsCNodeGraph); // Incorporation incorporate_getitem_set_ = - MakeSubstitution(IncorporateGetitemSet(), "incorporate_getitem_set", prim::kPrimTupleGetItem); - incorporate_getitem_from_param_ = - MakeSubstitution(IncorporateGetitemFromParam(), "incorporate_getitem_from_param", IsCNodeGraphKernel); - incorporate_call_ = MakeSubstitution(IncorporateCall(), "incorporate_call", IsCNodeDup); - incorporate_call_switch_ = MakeSubstitution(IncorporateCallSwitch(), "incorporate_call_switch", IsCNodeDup); + MakeSubstitution(std::make_shared(), "incorporate_getitem_set", prim::kPrimTupleGetItem); + incorporate_getitem_from_param_ = MakeSubstitution(std::make_shared(), + "incorporate_getitem_from_param", IsCNodeGraphKernel); + incorporate_call_ = MakeSubstitution(std::make_shared(), "incorporate_call", IsCNodeDup); + incorporate_call_switch_ = + MakeSubstitution(std::make_shared(), "incorporate_call_switch", IsCNodeDup); // Virtual Dataset - virtual_dataset_eliminate_ = - MakeSubstitution(VirtualDatasetEliminater(), "virtual_dataset_eliminate", prim::kPrimVirtualDataset); + virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared(), + "virtual_dataset_eliminate", prim::kPrimVirtualDataset); // Convert - print_tuple_wrapper_ = MakeSubstitution(PrintTupleWrapper(), "print_tuple_wrapper", prim::kPrimPrint); + print_tuple_wrapper_ = + MakeSubstitution(std::make_shared(), "print_tuple_wrapper", prim::kPrimPrint); // Unused parameter eliminate unused_parameter_eliminate_ = - MakeSubstitution(UnusedParasEliminater(), "unused_parameter_eliminate", IsCNodeGraphKernel); - unused_output_eliminate_ = MakeSubstitution(UnusedOutputEliminater(), "unused_output_eliminate", IsCNodeGraphKernel); + MakeSubstitution(std::make_shared(), "unused_parameter_eliminate", IsCNodeGraphKernel); + unused_output_eliminate_ = + MakeSubstitution(std::make_shared(), "unused_output_eliminate", IsCNodeGraphKernel); // AddN eliminate - addn_eliminate_ = MakeSubstitution(AddNEliminater(), "addn_eliminate", IsCNodeGraphKernel); + addn_eliminate_ = MakeSubstitution(std::make_shared(), "addn_eliminate", IsCNodeGraphKernel); // Mark interface fusion - mark_interface_fusion_ = MakeSubstitution(MarkInterfaceFusion(), "mark_interface_fusion", prim::kPrimSelect); + mark_interface_fusion_ = + MakeSubstitution(std::make_shared(), "mark_interface_fusion", prim::kPrimSelect); } ResolveIRPassLib::ResolveIRPassLib() { - resolver_resolve_ = MakeSubstitution(ResolverResolve(), "resolver_resolve", prim::kPrimResolve); - resolver_getattr_ = MakeSubstitution(ResolverGetattr(), "resolver_getattr", prim::kPrimGetAttr); + resolver_resolve_ = MakeSubstitution(std::make_shared(), "resolver_resolve", prim::kPrimResolve); + resolver_getattr_ = MakeSubstitution(std::make_shared(), "resolver_getattr", prim::kPrimGetAttr); } InferenceOptPrepareLib::InferenceOptPrepareLib() { - grad_var_prepare_ = MakeSubstitution(GradVarPrepare(), "grad_var_prepare", IsCNode); + grad_var_prepare_ = MakeSubstitution(std::make_shared(), "grad_var_prepare", IsCNode); } } // namespace irpass } // namespace opt diff --git a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h index 270db8305f..a26b81e952 100644 --- a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h +++ b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h @@ -17,15 +17,16 @@ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ -#include -#include #include +#include +#include -#include "optimizer/optimizer.h" -#include "optimizer/irpass.h" -#include "optimizer/irpass/prim_eliminate.h" +#include "ir/optimizer_caller.h" #include "ir/visitor.h" #include "operator/ops.h" +#include "optimizer/irpass.h" +#include "optimizer/irpass/prim_eliminate.h" +#include "optimizer/optimizer.h" namespace mindspore { namespace opt { @@ -739,17 +740,17 @@ class AdjustAllReduceMulAdd : public AnfVisitor { FuncGraphPtr all_reduce_fg_{nullptr}; }; -class ArithmeticSimplify { +class ArithmeticSimplify : public OptimizerCaller { public: ArithmeticSimplify() - : multiply_by_zero_or_one_(), - tensor_multiply_by_one_(), - add_by_zero_(), - tensor_add_by_zero_(), - identity_(prim::kPrimIdentity), - opt_update_zero_tensor_(), - constant_duplicate_mul_(), - power_one_() { + : multiply_by_zero_or_one_(std::make_shared()), + tensor_multiply_by_one_(std::make_shared()), + add_by_zero_(std::make_shared()), + tensor_add_by_zero_(std::make_shared()), + identity_(std::make_shared(prim::kPrimIdentity)), + opt_update_zero_tensor_(std::make_shared()), + constant_duplicate_mul_(std::make_shared()), + power_one_(std::make_shared()) { eliminaters_.emplace_back(multiply_by_zero_or_one_); eliminaters_.emplace_back(tensor_multiply_by_one_); eliminaters_.emplace_back(add_by_zero_); @@ -761,10 +762,10 @@ class ArithmeticSimplify { } ~ArithmeticSimplify() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { AnfNodePtr new_node; for (auto &eliminater : eliminaters_) { - new_node = eliminater(optimizer, node); + new_node = (*eliminater)(optimizer, node); if (new_node != nullptr) { return new_node; } @@ -773,15 +774,9 @@ class ArithmeticSimplify { } private: - MultiplyByZeroOrOne multiply_by_zero_or_one_; - TensorMultiplyByOne tensor_multiply_by_one_; - AddByZero add_by_zero_; - TensorAddByZero tensor_add_by_zero_; - PrimEliminater identity_; - OptUpdateZeroTensor opt_update_zero_tensor_; - ConstantDuplicateMul constant_duplicate_mul_; - PowerOneEliminate power_one_; - std::vector eliminaters_{}; + OptimizerCallerPtr multiply_by_zero_or_one_, tensor_multiply_by_one_, add_by_zero_, tensor_add_by_zero_, identity_, + opt_update_zero_tensor_, constant_duplicate_mul_, power_one_; + std::vector eliminaters_{}; }; // Arithmetic Simplifications should be done after step_parallel. @@ -789,15 +784,17 @@ class ArithmeticSimplify { // with shape(weight), but after step_parallel, shape of weight may be changed, so the // shape of the constant tensor should also be changed. So this pass is seperated from // ArithmeticSimplify and deferred until step_parallel. -class ArithmeticSimplify2 { +class ArithmeticSimplify2 : public OptimizerCaller { public: - ArithmeticSimplify2() : tensor_multiply_by_zero_() { eliminaters_.emplace_back(tensor_multiply_by_zero_); } + ArithmeticSimplify2() : tensor_multiply_by_zero_(std::make_shared()) { + eliminaters_.emplace_back(tensor_multiply_by_zero_); + } ~ArithmeticSimplify2() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { AnfNodePtr new_node; for (auto &eliminater : eliminaters_) { - new_node = eliminater(optimizer, node); + new_node = (*eliminater)(optimizer, node); if (new_node != nullptr) { return new_node; } @@ -806,8 +803,8 @@ class ArithmeticSimplify2 { } private: - TensorMultiplyByZero tensor_multiply_by_zero_; - std::vector eliminaters_{}; + OptimizerCallerPtr tensor_multiply_by_zero_; + std::vector eliminaters_{}; }; } // namespace irpass } // namespace opt diff --git a/mindspore/ccsrc/optimizer/irpass/cast_eliminate.h b/mindspore/ccsrc/optimizer/irpass/cast_eliminate.h index 734d88cb10..d98d0b677b 100644 --- a/mindspore/ccsrc/optimizer/irpass/cast_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/cast_eliminate.h @@ -17,9 +17,9 @@ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ +#include "ir/visitor.h" #include "optimizer/irpass.h" #include "optimizer/optimizer.h" -#include "ir/visitor.h" namespace mindspore { namespace opt { @@ -52,12 +52,12 @@ class TwoCastEliminater : public AnfVisitor { AnfNodePtr x_{nullptr}, t_{nullptr}; }; -class CastEliminater { +class CastEliminater : public OptimizerCaller { public: CastEliminater() : cast_same_type_eliminater_(), two_cast_eliminater_() {} ~CastEliminater() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { auto new_node = cast_same_type_eliminater_(optimizer, node); if (new_node != nullptr) { return new_node; diff --git a/mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h b/mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h index 0f59c69fef..3f100dcaec 100644 --- a/mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h @@ -17,18 +17,19 @@ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ -#include -#include #include -#include #include +#include +#include +#include -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "ir/visitor.h" #include "ir/func_graph.h" #include "ir/func_graph_cloner.h" +#include "ir/optimizer_caller.h" +#include "ir/visitor.h" #include "operator/ops.h" +#include "optimizer/irpass.h" +#include "optimizer/optimizer.h" #include "utils/symbolic.h" namespace mindspore { @@ -225,19 +226,22 @@ class EnvGetSetItem : public AnfVisitor { bool is_match_{false}; }; -class EnvGetItemEliminater { +class EnvGetItemEliminater : public OptimizerCaller { public: - EnvGetItemEliminater() : new_env_get_item_(), add_env_get_item_(), env_get_set_item_() { + EnvGetItemEliminater() + : new_env_get_item_(std::make_shared()), + add_env_get_item_(std::make_shared()), + env_get_set_item_(std::make_shared()) { eliminaters_.emplace_back(new_env_get_item_); eliminaters_.emplace_back(add_env_get_item_); eliminaters_.emplace_back(env_get_set_item_); } ~EnvGetItemEliminater() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { AnfNodePtr new_node; for (auto &eliminater : eliminaters_) { - new_node = eliminater(optimizer, node); + new_node = (*eliminater)(optimizer, node); if (new_node != nullptr) { return new_node; } @@ -246,10 +250,8 @@ class EnvGetItemEliminater { } private: - NewEnvGetItem new_env_get_item_; - AddEnvGetItem add_env_get_item_; - EnvGetSetItem env_get_set_item_; - std::vector eliminaters_{}; + OptimizerCallerPtr new_env_get_item_, add_env_get_item_, env_get_set_item_; + std::vector eliminaters_{}; }; // {prim::kPrimEnvGetItem, {G, Xs}, C, Y} diff --git a/mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h b/mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h index 5afee45e95..b6c8fb0e18 100644 --- a/mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h +++ b/mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h @@ -17,18 +17,20 @@ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ -#include #include -#include #include +#include #include +#include -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "ir/visitor.h" #include "ir/func_graph.h" #include "ir/func_graph_cloner.h" +#include "ir/optimizer_caller.h" +#include "ir/visitor.h" #include "operator/ops.h" +#include "optimizer/irpass.h" +#include "optimizer/optimizer.h" + namespace mindspore { namespace opt { namespace irpass { @@ -383,18 +385,20 @@ class IncorporateGetitemSwitch : public AnfVisitor { internal::GetitemTransform getitem_transform_; }; -class IncorporateGetitemSet { +class IncorporateGetitemSet : public OptimizerCaller { public: - IncorporateGetitemSet() : incorporate_getitem_(), incorporate_getitem_switch_() { + IncorporateGetitemSet() + : incorporate_getitem_(std::make_shared()), + incorporate_getitem_switch_(std::make_shared()) { eliminaters_.emplace_back(incorporate_getitem_); eliminaters_.emplace_back(incorporate_getitem_switch_); } ~IncorporateGetitemSet() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { AnfNodePtr new_node; for (auto &eliminater : eliminaters_) { - new_node = eliminater(optimizer, node); + new_node = (*eliminater)(optimizer, node); if (new_node != nullptr) { return new_node; } @@ -403,9 +407,8 @@ class IncorporateGetitemSet { } private: - IncorporateGetitem incorporate_getitem_; - IncorporateGetitemSwitch incorporate_getitem_switch_; - std::vector eliminaters_{}; + OptimizerCallerPtr incorporate_getitem_, incorporate_getitem_switch_; + std::vector eliminaters_{}; }; } // namespace irpass } // namespace opt diff --git a/mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h b/mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h index 21cdff51ad..202951a254 100644 --- a/mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h @@ -17,13 +17,15 @@ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ -#include #include +#include +#include -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" +#include "ir/optimizer_caller.h" #include "ir/visitor.h" #include "operator/ops.h" +#include "optimizer/irpass.h" +#include "optimizer/optimizer.h" namespace mindspore { namespace opt { @@ -261,14 +263,14 @@ class GetitemDependReorder : public AnfVisitor { AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr}; }; -class ItemTupleEliminater { +class ItemTupleEliminater : public OptimizerCaller { public: ItemTupleEliminater() - : get_item_eliminater_(), - get_item_const_eliminater_(), - set_item_eliminater_(), - get_set_item_eliminater_(), - get_item_depend_reorder_() { + : get_item_eliminater_(std::make_shared()), + get_item_const_eliminater_(std::make_shared()), + set_item_eliminater_(std::make_shared()), + get_set_item_eliminater_(std::make_shared()), + get_item_depend_reorder_(std::make_shared()) { eliminaters_.emplace_back(get_item_eliminater_); eliminaters_.emplace_back(get_item_const_eliminater_); eliminaters_.emplace_back(set_item_eliminater_); @@ -277,10 +279,10 @@ class ItemTupleEliminater { } ~ItemTupleEliminater() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { AnfNodePtr new_node; for (auto &eliminater : eliminaters_) { - new_node = eliminater(optimizer, node); + new_node = (*eliminater)(optimizer, node); if (new_node != nullptr) { return new_node; } @@ -289,12 +291,9 @@ class ItemTupleEliminater { } private: - GetitemEliminater get_item_eliminater_; - GetitemConstEliminater get_item_const_eliminater_; - SetitemEliminater set_item_eliminater_; - GetSetitemEliminater get_set_item_eliminater_; - GetitemDependReorder get_item_depend_reorder_; - std::vector eliminaters_{}; + OptimizerCallerPtr get_item_eliminater_, get_item_const_eliminater_, set_item_eliminater_, get_set_item_eliminater_, + get_item_depend_reorder_; + std::vector eliminaters_{}; }; } // namespace irpass } // namespace opt diff --git a/mindspore/ccsrc/optimizer/irpass/ref_eliminate.h b/mindspore/ccsrc/optimizer/irpass/ref_eliminate.h index 41f379221c..6d81b401c3 100644 --- a/mindspore/ccsrc/optimizer/irpass/ref_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/ref_eliminate.h @@ -19,9 +19,9 @@ #include -#include "optimizer/optimizer.h" -#include "optimizer/irpass.h" #include "ir/pattern_matcher.h" +#include "optimizer/irpass.h" +#include "optimizer/optimizer.h" namespace mindspore { namespace opt { diff --git a/mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h b/mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h index fb43f6ffd8..cafc8b796c 100644 --- a/mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h @@ -19,11 +19,12 @@ #include -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "ir/visitor.h" #include "ir/func_graph.h" +#include "ir/optimizer_caller.h" +#include "ir/visitor.h" #include "operator/ops.h" +#include "optimizer/irpass.h" +#include "optimizer/optimizer.h" #include "pipeline/static_analysis/dshape.h" namespace mindspore { @@ -124,12 +125,12 @@ class TwoReshapeEliminater : public AnfVisitor { AnfNodePtr x_{nullptr}, shape_{nullptr}; }; -class ReshapeEliminater { +class ReshapeEliminater : public OptimizerCaller { public: ReshapeEliminater() : reshape_same_shape_eliminater_(), two_reshape_eliminater_() {} ~ReshapeEliminater() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { auto new_node = reshape_same_shape_eliminater_(optimizer, node); if (new_node != nullptr) { return new_node; diff --git a/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h b/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h index dcba80431a..b6a4e1c852 100644 --- a/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h @@ -18,31 +18,31 @@ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_ #include -#include -#include #include +#include +#include -#include "optimizer/optimizer.h" -#include "optimizer/irpass.h" #include "ir/optimizer_caller.h" -#include "optimizer/irpass/prim_eliminate.h" +#include "ir/pattern_matcher.h" #include "ir/visitor.h" #include "operator/ops.h" -#include "ir/pattern_matcher.h" +#include "optimizer/irpass.h" +#include "optimizer/irpass/prim_eliminate.h" +#include "optimizer/optimizer.h" namespace mindspore { namespace opt { namespace irpass { -class SpecialOpEliminater { +class SpecialOpEliminater : public OptimizerCaller { public: SpecialOpEliminater() - : insert_gradient_of_(prim::kPrimInsertGradientOf), - stop_gradient_(prim::kPrimStopGradient), - hook_backward_(prim::kPrimHookBackward), - print_shape_type_(prim::kPrimPrintShapeType), - get_ref_value_(prim::kPrimGetRefValue), - mirror_(prim::kPrimMirror), - virtual_div_(prim::kPrimVirtualDiv) { + : insert_gradient_of_(std::make_shared(prim::kPrimInsertGradientOf)), + stop_gradient_(std::make_shared(prim::kPrimStopGradient)), + hook_backward_(std::make_shared(prim::kPrimHookBackward)), + print_shape_type_(std::make_shared(prim::kPrimPrintShapeType)), + get_ref_value_(std::make_shared(prim::kPrimGetRefValue)), + mirror_(std::make_shared(prim::kPrimMirror)), + virtual_div_(std::make_shared(prim::kPrimVirtualDiv)) { eliminaters_.emplace_back(insert_gradient_of_); eliminaters_.emplace_back(stop_gradient_); eliminaters_.emplace_back(hook_backward_); @@ -53,10 +53,10 @@ class SpecialOpEliminater { } ~SpecialOpEliminater() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { AnfNodePtr new_node; for (auto &eliminater : eliminaters_) { - new_node = eliminater(optimizer, node); + new_node = (*eliminater)(optimizer, node); if (new_node != nullptr) { return new_node; } @@ -65,9 +65,9 @@ class SpecialOpEliminater { } private: - PrimEliminater insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, + OptimizerCallerPtr insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, virtual_div_; - std::vector eliminaters_{}; + std::vector eliminaters_{}; }; // {PrimVirtualDataset, X} -> X diff --git a/mindspore/ccsrc/optimizer/opt.cc b/mindspore/ccsrc/optimizer/opt.cc index 82fbcc2036..4c2e85157f 100644 --- a/mindspore/ccsrc/optimizer/opt.cc +++ b/mindspore/ccsrc/optimizer/opt.cc @@ -16,28 +16,27 @@ #include "optimizer/opt.h" +#include +#include #include #include -#include -#include #include "ir/anf.h" #include "ir/manager.h" -#include "utils/ordered_set.h" - -#include "utils/log_adapter.h" #include "optimizer/optimizer.h" +#include "utils/log_adapter.h" +#include "utils/ordered_set.h" namespace mindspore { /* namespace to support opt */ namespace opt { -SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim, +SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim, const RenormAction &renorm_action) { auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); }; return std::make_shared(transform, name, fn, renorm_action); } -SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, +SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const std::vector &prims, const RenormAction &renorm_action) { auto fn = [prims](const AnfNodePtr &node) -> bool { if (!node->isa()) { @@ -64,16 +63,16 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std:: return std::make_shared(transform, name, fn, renorm_action); } -SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, +SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate, const RenormAction &renorm_action) { return std::make_shared(transform, name, predicate, renorm_action); } -AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const { +AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { #ifdef ENABLE_PROFILE double t = GetTime(); #endif - AnfNodePtr result = transform_(optimizer, node); + AnfNodePtr result = (*transform_)(optimizer, node); #ifdef ENABLE_PROFILE if (optimizer != nullptr) { auto time = GetTime(); diff --git a/mindspore/ccsrc/optimizer/opt.h b/mindspore/ccsrc/optimizer/opt.h index fb0bdc58be..6601d969d2 100644 --- a/mindspore/ccsrc/optimizer/opt.h +++ b/mindspore/ccsrc/optimizer/opt.h @@ -17,24 +17,18 @@ #ifndef MINDSPORE_CCSRC_OPTIMIZER_OPT_H_ #define MINDSPORE_CCSRC_OPTIMIZER_OPT_H_ -#include -#include #include +#include +#include #include "ir/anf.h" #include "ir/func_graph.h" +#include "ir/optimizer_caller.h" #include "operator/ops.h" namespace mindspore { /* namespace to support opt */ namespace opt { -class Optimizer; - -using OptimizerPtr = std::shared_ptr; -using OptimizerWeakPtr = std::weak_ptr; - -using PredicateFuncType = std::function; -using TransformFuncType = std::function; // Define the interaction mode between an Optimize pass and Renormalize pass // FORCE_RENORM: if the pass modified the graph then the next Renormalize will be executed @@ -43,26 +37,26 @@ enum RenormAction : int { FORCE_RENORM = 0, CHECK_RENORM }; class Substitution { public: - TransformFuncType transform_{nullptr}; + OptimizerCallerPtr transform_; std::string name_; PredicateFuncType predicate_{nullptr}; // an enum to mark this Substitution relation to renormalize pass RenormAction renorm_action_; - Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate, + Substitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate, const RenormAction &renorm_action) : transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {} ~Substitution() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const; + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node); }; using SubstitutionPtr = std::shared_ptr; -SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim, +SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim, const RenormAction &action_renorm = CHECK_RENORM); -SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, +SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const std::vector &prims, const RenormAction &action_renorm = CHECK_RENORM); -SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, +SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM); class SubstitutionList { diff --git a/tests/ut/cpp/optimizer/opt_test.cc b/tests/ut/cpp/optimizer/opt_test.cc index 05e7e6b978..2428d0dddb 100644 --- a/tests/ut/cpp/optimizer/opt_test.cc +++ b/tests/ut/cpp/optimizer/opt_test.cc @@ -77,10 +77,10 @@ class TestOptOpt : public UT::Common { }; void SetUp() { - elim_Z = MakeSubstitution(irpass::AddByZero(), "elim_Z", prim::kPrimScalarAdd); - elim_R = MakeSubstitution(irpass::PrimEliminater(R), "elim_R", R); - idempotent_P = MakeSubstitution(IdempotentEliminater(), "idempotent_P", P); - Qct_to_P = MakeSubstitution(QctToP(), "Qct_to_P", Q); + elim_Z = MakeSubstitution(std::make_shared(), "elim_Z", prim::kPrimScalarAdd); + elim_R = MakeSubstitution(std::make_shared(R), "elim_R", R); + idempotent_P = MakeSubstitution(std::make_shared(), "idempotent_P", P); + Qct_to_P = MakeSubstitution(std::make_shared(), "Qct_to_P", Q); } bool CheckTransform(FuncGraphPtr gbefore, FuncGraphPtr gafter, const SubstitutionList &transform) {