| @@ -17,13 +17,23 @@ | |||||
| #ifndef MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ | #ifndef MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ | ||||
| #define MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ | #define MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ | ||||
| #include <memory> | |||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "optimizer/opt.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | |||||
| class Optimizer; | |||||
| using OptimizerPtr = std::shared_ptr<Optimizer>; | |||||
| using OptimizerWeakPtr = std::weak_ptr<Optimizer>; | |||||
| using PredicateFuncType = std::function<bool(const AnfNodePtr &)>; | |||||
| } // namespace opt | |||||
| class OptimizerCaller { | class OptimizerCaller { | ||||
| public: | public: | ||||
| virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &) { return nullptr; } | virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &) { return nullptr; } | ||||
| }; | }; | ||||
| using OptimizerCallerPtr = std::shared_ptr<OptimizerCaller>; | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ | #endif // MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ | ||||
| @@ -14,140 +14,154 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "optimizer/irpass.h" | |||||
| #include <string> | #include <string> | ||||
| #include "optimizer/irpass/symbol_resolver.h" | |||||
| #include "optimizer/irpass.h" | |||||
| #include "optimizer/irpass/arithmetic_simplify.h" | #include "optimizer/irpass/arithmetic_simplify.h" | ||||
| #include "optimizer/irpass/special_op_eliminate.h" | |||||
| #include "optimizer/irpass/item_tuple_eliminate.h" | |||||
| #include "optimizer/irpass/env_item_eliminate.h" | |||||
| #include "optimizer/irpass/tile_eliminate.h" | |||||
| #include "optimizer/irpass/cast_eliminate.h" | |||||
| #include "optimizer/irpass/reshape_eliminate.h" | |||||
| #include "optimizer/irpass/transpose_eliminate.h" | |||||
| #include "optimizer/irpass/reduce_eliminate.h" | |||||
| #include "optimizer/irpass/partial_eliminate.h" | |||||
| #include "optimizer/irpass/ref_eliminate.h" | |||||
| #include "optimizer/irpass/merge_addn.h" | |||||
| #include "optimizer/irpass/branch_culling.h" | #include "optimizer/irpass/branch_culling.h" | ||||
| #include "optimizer/irpass/cast_eliminate.h" | |||||
| #include "optimizer/irpass/convert.h" | |||||
| #include "optimizer/irpass/env_item_eliminate.h" | |||||
| #include "optimizer/irpass/grad_var_prepare.h" | |||||
| #include "optimizer/irpass/gradient_eliminate.h" | #include "optimizer/irpass/gradient_eliminate.h" | ||||
| #include "optimizer/irpass/minmax_grad.h" | |||||
| #include "optimizer/irpass/inline.h" | #include "optimizer/irpass/inline.h" | ||||
| #include "optimizer/irpass/convert.h" | |||||
| #include "optimizer/irpass/specialize_transform.h" | |||||
| #include "optimizer/irpass/incorporate_getitem.h" | |||||
| #include "optimizer/irpass/incorporate_call.h" | #include "optimizer/irpass/incorporate_call.h" | ||||
| #include "optimizer/irpass/grad_var_prepare.h" | |||||
| #include "optimizer/irpass/param_replace.h" | |||||
| #include "optimizer/irpass/incorporate_getitem.h" | |||||
| #include "optimizer/irpass/item_tuple_eliminate.h" | |||||
| #include "optimizer/irpass/mark_interface_fusion.h" | #include "optimizer/irpass/mark_interface_fusion.h" | ||||
| #include "optimizer/irpass/merge_addn.h" | |||||
| #include "optimizer/irpass/minmax_grad.h" | |||||
| #include "optimizer/irpass/param_replace.h" | |||||
| #include "optimizer/irpass/partial_eliminate.h" | |||||
| #include "optimizer/irpass/reduce_eliminate.h" | |||||
| #include "optimizer/irpass/ref_eliminate.h" | |||||
| #include "optimizer/irpass/reshape_eliminate.h" | |||||
| #include "optimizer/irpass/special_op_eliminate.h" | |||||
| #include "optimizer/irpass/specialize_transform.h" | |||||
| #include "optimizer/irpass/symbol_resolver.h" | |||||
| #include "optimizer/irpass/tile_eliminate.h" | |||||
| #include "optimizer/irpass/transpose_eliminate.h" | |||||
| #include "optimizer/opt.h" | #include "optimizer/opt.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace irpass { | namespace irpass { | ||||
| OptimizeIRPassLib::OptimizeIRPassLib() { | OptimizeIRPassLib::OptimizeIRPassLib() { | ||||
| arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify", | |||||
| arithmetic_simplify_ = MakeSubstitution(std::make_shared<ArithmeticSimplify>(), "arithmetic_simplify", | |||||
| {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, | {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, | ||||
| prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow}); | prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow}); | ||||
| arithmetic_simplify2_ = MakeSubstitution(ArithmeticSimplify2(), "arithmetic_simplify2", {prim::kPrimMul}); | |||||
| arithmetic_simplify2_ = | |||||
| MakeSubstitution(std::make_shared<ArithmeticSimplify2>(), "arithmetic_simplify2", {prim::kPrimMul}); | |||||
| special_op_eliminate_ = | special_op_eliminate_ = | ||||
| MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", | |||||
| MakeSubstitution(std::make_shared<SpecialOpEliminater>(), "special_op_eliminate", | |||||
| {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward, | {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward, | ||||
| prim::kPrimPrintShapeType, prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv}); | prim::kPrimPrintShapeType, prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv}); | ||||
| zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLike); | |||||
| adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN); | |||||
| zero_like_fill_zero_ = | |||||
| MakeSubstitution(std::make_shared<ZeroLikeFillZero>(), "zero_like_fill_zero", prim::kPrimZerosLike); | |||||
| adjust_all_reduce_mul_add_ = | |||||
| MakeSubstitution(std::make_shared<AdjustAllReduceMulAdd>(), "adjust_all_reduce_mul_add", prim::kPrimAddN); | |||||
| // ops eliminate | // ops eliminate | ||||
| item_tuple_eliminate_ = | |||||
| MakeSubstitution(ItemTupleEliminater(), "item_tuple_eliminate", {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem}); | |||||
| tile_eliminate_ = MakeSubstitution(TileMultiplyByOne(), "tile_eliminate", prim::kPrimTile); | |||||
| cast_eliminate_ = MakeSubstitution(CastEliminater(), "cast_eliminate", prim::kPrimCast); | |||||
| reshape_eliminate_ = MakeSubstitution(ReshapeEliminater(), "reshape_eliminate", prim::kPrimReshape); | |||||
| transpose_eliminate_ = MakeSubstitution(TransposeSameIOEliminater(), "transpose_eliminate", prim::kPrimTranspose); | |||||
| item_tuple_eliminate_ = MakeSubstitution(std::make_shared<ItemTupleEliminater>(), "item_tuple_eliminate", | |||||
| {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem}); | |||||
| tile_eliminate_ = MakeSubstitution(std::make_shared<TileMultiplyByOne>(), "tile_eliminate", prim::kPrimTile); | |||||
| cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast); | |||||
| reshape_eliminate_ = MakeSubstitution(std::make_shared<ReshapeEliminater>(), "reshape_eliminate", prim::kPrimReshape); | |||||
| transpose_eliminate_ = | |||||
| MakeSubstitution(std::make_shared<TransposeSameIOEliminater>(), "transpose_eliminate", prim::kPrimTranspose); | |||||
| reduce_eliminate_ = MakeSubstitution( | reduce_eliminate_ = MakeSubstitution( | ||||
| ReduceOneEliminater(), "reduce_eliminate", | |||||
| std::make_shared<ReduceOneEliminater>(), "reduce_eliminate", | |||||
| {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin}); | {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin}); | ||||
| partial_eliminate_ = MakeSubstitution(PartialEliminater(), "partial_eliminate", IsCNodeDup); | |||||
| same_eliminate_ = MakeSubstitution(SameEliminater(), "same_eliminate", prim::kPrimSameTypeShape); | |||||
| check_bprop_eliminate_ = MakeSubstitution(CheckBpropEliminater(), "check_bprop_eliminate", prim::kPrimCheckBprop); | |||||
| reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode<FuncGraph>); | |||||
| depend_value_elim_ = MakeSubstitution(DependValueElim(), "depend_value_elim", prim::kPrimDepend); | |||||
| partial_eliminate_ = MakeSubstitution(std::make_shared<PartialEliminater>(), "partial_eliminate", IsCNodeDup); | |||||
| same_eliminate_ = MakeSubstitution(std::make_shared<SameEliminater>(), "same_eliminate", prim::kPrimSameTypeShape); | |||||
| check_bprop_eliminate_ = | |||||
| MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop); | |||||
| reset_defer_inline_ = | |||||
| MakeSubstitution(std::make_shared<ResetDeferInline>(), "reset_defer_inline", IsValueNode<FuncGraph>); | |||||
| depend_value_elim_ = MakeSubstitution(std::make_shared<DependValueElim>(), "depend_value_elim", prim::kPrimDepend); | |||||
| // Env Item Eliminate | // Env Item Eliminate | ||||
| env_get_item_eliminate_ = MakeSubstitution(EnvGetItemEliminater(), "env_get_item_eliminate", prim::kPrimEnvGetItem); | |||||
| new_env_get_item_ = MakeSubstitution(NewEnvGetItem(), "new_env_get_item", prim::kPrimEnvGetItem); | |||||
| env_get_item_eliminate_ = | |||||
| MakeSubstitution(std::make_shared<EnvGetItemEliminater>(), "env_get_item_eliminate", prim::kPrimEnvGetItem); | |||||
| new_env_get_item_ = MakeSubstitution(std::make_shared<NewEnvGetItem>(), "new_env_get_item", prim::kPrimEnvGetItem); | |||||
| incorporate_env_getitem_ = | incorporate_env_getitem_ = | ||||
| MakeSubstitution(IncorporateEnvGetitem(), "incorporate_env_get_item", prim::kPrimEnvGetItem); | |||||
| incorporate_env_getitem_switch_ = | |||||
| MakeSubstitution(IncorporateEnvGetitemSwitch(), "incorporate_env_getitem_switch", prim::kPrimEnvGetItem); | |||||
| MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(), "incorporate_env_get_item", prim::kPrimEnvGetItem); | |||||
| incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared<IncorporateEnvGetitemSwitch>(), | |||||
| "incorporate_env_getitem_switch", prim::kPrimEnvGetItem); | |||||
| // Ref eliminate | // Ref eliminate | ||||
| make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef); | |||||
| get_ref_param_eliminate_ = MakeSubstitution(GetRefParamEliminater(), "get_ref_param_eliminate", | |||||
| make_ref_eliminate_ = | |||||
| MakeSubstitution(std::make_shared<MakeRefEliminater>(), "make_ref_eliminate", prim::kPrimMakeRef); | |||||
| get_ref_param_eliminate_ = MakeSubstitution(std::make_shared<GetRefParamEliminater>(), "get_ref_param_eliminate", | |||||
| {prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); | {prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); | ||||
| get_make_ref_eliminate_ = MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate", | |||||
| get_make_ref_eliminate_ = MakeSubstitution(std::make_shared<GetMakeRefEliminater>(), "get_make_ref_eliminate", | |||||
| {prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); | {prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); | ||||
| replace_refkey_by_param_ = | |||||
| MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM); | |||||
| replace_old_param_ = MakeSubstitution(ReplaceOldParam(), "replace_old_param", IsParam); | |||||
| replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param", | |||||
| IsValueNode<RefKey>, opt::FORCE_RENORM); | |||||
| replace_old_param_ = MakeSubstitution(std::make_shared<ReplaceOldParam>(), "replace_old_param", IsParam); | |||||
| // Gradient transforms | // Gradient transforms | ||||
| expand_jprim_ = MakeSubstitution(ExpandJPrim(), "expand_jprim", prim::kPrimJ); | |||||
| minmaximum_grad_ = MakeSubstitution(MinMaximumGrad(), "minmaximum_grad", prim::kPrimTupleGetItem); | |||||
| expand_jprim_ = MakeSubstitution(std::make_shared<ExpandJPrim>(), "expand_jprim", prim::kPrimJ); | |||||
| minmaximum_grad_ = MakeSubstitution(std::make_shared<MinMaximumGrad>(), "minmaximum_grad", prim::kPrimTupleGetItem); | |||||
| // branch culling | // branch culling | ||||
| switch_simplify_ = MakeSubstitution(SwitchSimplify(), "switch_simplify", prim::kPrimSwitch); | |||||
| float_tuple_getitem_switch_ = | |||||
| MakeSubstitution(FloatTupleGetItemSwitch(), "float_tuple_getitem_switch", prim::kPrimTupleGetItem); | |||||
| switch_simplify_ = MakeSubstitution(std::make_shared<SwitchSimplify>(), "switch_simplify", prim::kPrimSwitch); | |||||
| float_tuple_getitem_switch_ = MakeSubstitution(std::make_shared<FloatTupleGetItemSwitch>(), | |||||
| "float_tuple_getitem_switch", prim::kPrimTupleGetItem); | |||||
| float_env_getitem_switch_ = | float_env_getitem_switch_ = | ||||
| MakeSubstitution(FloatEnvGetItemSwitch(), "float_env_getitem_switch", prim::kPrimEnvGetItem); | |||||
| convert_switch_replacement_ = MakeSubstitution(ConvertSwitchReplacement(), "convert_switch_replacement", IsCNodeDup); | |||||
| MakeSubstitution(std::make_shared<FloatEnvGetItemSwitch>(), "float_env_getitem_switch", prim::kPrimEnvGetItem); | |||||
| convert_switch_replacement_ = | |||||
| MakeSubstitution(std::make_shared<ConvertSwitchReplacement>(), "convert_switch_replacement", IsCNodeDup); | |||||
| // Addn | // Addn | ||||
| merge_addn_ = MakeSubstitution(MergeAddN(), "merge_addn", prim::kPrimAddN); | |||||
| addn_zero_filter_ = MakeSubstitution(AddNZeroFilter(), "addn_zero_filter", prim::kPrimAddN); | |||||
| merge_addn_ = MakeSubstitution(std::make_shared<MergeAddN>(), "merge_addn", prim::kPrimAddN); | |||||
| addn_zero_filter_ = MakeSubstitution(std::make_shared<AddNZeroFilter>(), "addn_zero_filter", prim::kPrimAddN); | |||||
| // inline | // inline | ||||
| inline_ = MakeSubstitution(Inliner(), "inline", IsCNodeGraph); | |||||
| replace_applicator_ = MakeSubstitution(ReplaceApplicator(), "replace_applicator", IsValueNode<FuncGraph>); | |||||
| specialize_transform_ = MakeSubstitution(SpecializeOnGraphArguments(), "specialize_transform", IsCNodeGraph); | |||||
| inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph); | |||||
| replace_applicator_ = | |||||
| MakeSubstitution(std::make_shared<ReplaceApplicator>(), "replace_applicator", IsValueNode<FuncGraph>); | |||||
| specialize_transform_ = | |||||
| MakeSubstitution(std::make_shared<SpecializeOnGraphArguments>(), "specialize_transform", IsCNodeGraph); | |||||
| // Incorporation | // Incorporation | ||||
| incorporate_getitem_set_ = | incorporate_getitem_set_ = | ||||
| MakeSubstitution(IncorporateGetitemSet(), "incorporate_getitem_set", prim::kPrimTupleGetItem); | |||||
| incorporate_getitem_from_param_ = | |||||
| MakeSubstitution(IncorporateGetitemFromParam(), "incorporate_getitem_from_param", IsCNodeGraphKernel); | |||||
| incorporate_call_ = MakeSubstitution(IncorporateCall(), "incorporate_call", IsCNodeDup); | |||||
| incorporate_call_switch_ = MakeSubstitution(IncorporateCallSwitch(), "incorporate_call_switch", IsCNodeDup); | |||||
| MakeSubstitution(std::make_shared<IncorporateGetitemSet>(), "incorporate_getitem_set", prim::kPrimTupleGetItem); | |||||
| incorporate_getitem_from_param_ = MakeSubstitution(std::make_shared<IncorporateGetitemFromParam>(), | |||||
| "incorporate_getitem_from_param", IsCNodeGraphKernel); | |||||
| incorporate_call_ = MakeSubstitution(std::make_shared<IncorporateCall>(), "incorporate_call", IsCNodeDup); | |||||
| incorporate_call_switch_ = | |||||
| MakeSubstitution(std::make_shared<IncorporateCallSwitch>(), "incorporate_call_switch", IsCNodeDup); | |||||
| // Virtual Dataset | // Virtual Dataset | ||||
| virtual_dataset_eliminate_ = | |||||
| MakeSubstitution(VirtualDatasetEliminater(), "virtual_dataset_eliminate", prim::kPrimVirtualDataset); | |||||
| virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared<VirtualDatasetEliminater>(), | |||||
| "virtual_dataset_eliminate", prim::kPrimVirtualDataset); | |||||
| // Convert | // Convert | ||||
| print_tuple_wrapper_ = MakeSubstitution(PrintTupleWrapper(), "print_tuple_wrapper", prim::kPrimPrint); | |||||
| print_tuple_wrapper_ = | |||||
| MakeSubstitution(std::make_shared<PrintTupleWrapper>(), "print_tuple_wrapper", prim::kPrimPrint); | |||||
| // Unused parameter eliminate | // Unused parameter eliminate | ||||
| unused_parameter_eliminate_ = | unused_parameter_eliminate_ = | ||||
| MakeSubstitution(UnusedParasEliminater(), "unused_parameter_eliminate", IsCNodeGraphKernel); | |||||
| unused_output_eliminate_ = MakeSubstitution(UnusedOutputEliminater(), "unused_output_eliminate", IsCNodeGraphKernel); | |||||
| MakeSubstitution(std::make_shared<UnusedParasEliminater>(), "unused_parameter_eliminate", IsCNodeGraphKernel); | |||||
| unused_output_eliminate_ = | |||||
| MakeSubstitution(std::make_shared<UnusedOutputEliminater>(), "unused_output_eliminate", IsCNodeGraphKernel); | |||||
| // AddN eliminate | // AddN eliminate | ||||
| addn_eliminate_ = MakeSubstitution(AddNEliminater(), "addn_eliminate", IsCNodeGraphKernel); | |||||
| addn_eliminate_ = MakeSubstitution(std::make_shared<AddNEliminater>(), "addn_eliminate", IsCNodeGraphKernel); | |||||
| // Mark interface fusion | // Mark interface fusion | ||||
| mark_interface_fusion_ = MakeSubstitution(MarkInterfaceFusion(), "mark_interface_fusion", prim::kPrimSelect); | |||||
| mark_interface_fusion_ = | |||||
| MakeSubstitution(std::make_shared<MarkInterfaceFusion>(), "mark_interface_fusion", prim::kPrimSelect); | |||||
| } | } | ||||
| ResolveIRPassLib::ResolveIRPassLib() { | ResolveIRPassLib::ResolveIRPassLib() { | ||||
| resolver_resolve_ = MakeSubstitution(ResolverResolve(), "resolver_resolve", prim::kPrimResolve); | |||||
| resolver_getattr_ = MakeSubstitution(ResolverGetattr(), "resolver_getattr", prim::kPrimGetAttr); | |||||
| resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve); | |||||
| resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetattr>(), "resolver_getattr", prim::kPrimGetAttr); | |||||
| } | } | ||||
| InferenceOptPrepareLib::InferenceOptPrepareLib() { | InferenceOptPrepareLib::InferenceOptPrepareLib() { | ||||
| grad_var_prepare_ = MakeSubstitution(GradVarPrepare(), "grad_var_prepare", IsCNode); | |||||
| grad_var_prepare_ = MakeSubstitution(std::make_shared<GradVarPrepare>(), "grad_var_prepare", IsCNode); | |||||
| } | } | ||||
| } // namespace irpass | } // namespace irpass | ||||
| } // namespace opt | } // namespace opt | ||||
| @@ -17,15 +17,16 @@ | |||||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ | #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ | ||||
| #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ | #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ | ||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "optimizer/optimizer.h" | |||||
| #include "optimizer/irpass.h" | |||||
| #include "optimizer/irpass/prim_eliminate.h" | |||||
| #include "ir/optimizer_caller.h" | |||||
| #include "ir/visitor.h" | #include "ir/visitor.h" | ||||
| #include "operator/ops.h" | #include "operator/ops.h" | ||||
| #include "optimizer/irpass.h" | |||||
| #include "optimizer/irpass/prim_eliminate.h" | |||||
| #include "optimizer/optimizer.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| @@ -739,17 +740,17 @@ class AdjustAllReduceMulAdd : public AnfVisitor { | |||||
| FuncGraphPtr all_reduce_fg_{nullptr}; | FuncGraphPtr all_reduce_fg_{nullptr}; | ||||
| }; | }; | ||||
| class ArithmeticSimplify { | |||||
| class ArithmeticSimplify : public OptimizerCaller { | |||||
| public: | public: | ||||
| ArithmeticSimplify() | ArithmeticSimplify() | ||||
| : multiply_by_zero_or_one_(), | |||||
| tensor_multiply_by_one_(), | |||||
| add_by_zero_(), | |||||
| tensor_add_by_zero_(), | |||||
| identity_(prim::kPrimIdentity), | |||||
| opt_update_zero_tensor_(), | |||||
| constant_duplicate_mul_(), | |||||
| power_one_() { | |||||
| : multiply_by_zero_or_one_(std::make_shared<MultiplyByZeroOrOne>()), | |||||
| tensor_multiply_by_one_(std::make_shared<TensorMultiplyByOne>()), | |||||
| add_by_zero_(std::make_shared<AddByZero>()), | |||||
| tensor_add_by_zero_(std::make_shared<TensorAddByZero>()), | |||||
| identity_(std::make_shared<PrimEliminater>(prim::kPrimIdentity)), | |||||
| opt_update_zero_tensor_(std::make_shared<OptUpdateZeroTensor>()), | |||||
| constant_duplicate_mul_(std::make_shared<ConstantDuplicateMul>()), | |||||
| power_one_(std::make_shared<PowerOneEliminate>()) { | |||||
| eliminaters_.emplace_back(multiply_by_zero_or_one_); | eliminaters_.emplace_back(multiply_by_zero_or_one_); | ||||
| eliminaters_.emplace_back(tensor_multiply_by_one_); | eliminaters_.emplace_back(tensor_multiply_by_one_); | ||||
| eliminaters_.emplace_back(add_by_zero_); | eliminaters_.emplace_back(add_by_zero_); | ||||
| @@ -761,10 +762,10 @@ class ArithmeticSimplify { | |||||
| } | } | ||||
| ~ArithmeticSimplify() = default; | ~ArithmeticSimplify() = default; | ||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||||
| AnfNodePtr new_node; | AnfNodePtr new_node; | ||||
| for (auto &eliminater : eliminaters_) { | for (auto &eliminater : eliminaters_) { | ||||
| new_node = eliminater(optimizer, node); | |||||
| new_node = (*eliminater)(optimizer, node); | |||||
| if (new_node != nullptr) { | if (new_node != nullptr) { | ||||
| return new_node; | return new_node; | ||||
| } | } | ||||
| @@ -773,15 +774,9 @@ class ArithmeticSimplify { | |||||
| } | } | ||||
| private: | private: | ||||
| MultiplyByZeroOrOne multiply_by_zero_or_one_; | |||||
| TensorMultiplyByOne tensor_multiply_by_one_; | |||||
| AddByZero add_by_zero_; | |||||
| TensorAddByZero tensor_add_by_zero_; | |||||
| PrimEliminater identity_; | |||||
| OptUpdateZeroTensor opt_update_zero_tensor_; | |||||
| ConstantDuplicateMul constant_duplicate_mul_; | |||||
| PowerOneEliminate power_one_; | |||||
| std::vector<TransformFuncType> eliminaters_{}; | |||||
| OptimizerCallerPtr multiply_by_zero_or_one_, tensor_multiply_by_one_, add_by_zero_, tensor_add_by_zero_, identity_, | |||||
| opt_update_zero_tensor_, constant_duplicate_mul_, power_one_; | |||||
| std::vector<OptimizerCallerPtr> eliminaters_{}; | |||||
| }; | }; | ||||
| // Arithmetic Simplifications should be done after step_parallel. | // Arithmetic Simplifications should be done after step_parallel. | ||||
| @@ -789,15 +784,17 @@ class ArithmeticSimplify { | |||||
| // with shape(weight), but after step_parallel, shape of weight may be changed, so the | // with shape(weight), but after step_parallel, shape of weight may be changed, so the | ||||
| // shape of the constant tensor should also be changed. So this pass is seperated from | // shape of the constant tensor should also be changed. So this pass is seperated from | ||||
| // ArithmeticSimplify and deferred until step_parallel. | // ArithmeticSimplify and deferred until step_parallel. | ||||
| class ArithmeticSimplify2 { | |||||
| class ArithmeticSimplify2 : public OptimizerCaller { | |||||
| public: | public: | ||||
| ArithmeticSimplify2() : tensor_multiply_by_zero_() { eliminaters_.emplace_back(tensor_multiply_by_zero_); } | |||||
| ArithmeticSimplify2() : tensor_multiply_by_zero_(std::make_shared<TensorMultiplyByZero>()) { | |||||
| eliminaters_.emplace_back(tensor_multiply_by_zero_); | |||||
| } | |||||
| ~ArithmeticSimplify2() = default; | ~ArithmeticSimplify2() = default; | ||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||||
| AnfNodePtr new_node; | AnfNodePtr new_node; | ||||
| for (auto &eliminater : eliminaters_) { | for (auto &eliminater : eliminaters_) { | ||||
| new_node = eliminater(optimizer, node); | |||||
| new_node = (*eliminater)(optimizer, node); | |||||
| if (new_node != nullptr) { | if (new_node != nullptr) { | ||||
| return new_node; | return new_node; | ||||
| } | } | ||||
| @@ -806,8 +803,8 @@ class ArithmeticSimplify2 { | |||||
| } | } | ||||
| private: | private: | ||||
| TensorMultiplyByZero tensor_multiply_by_zero_; | |||||
| std::vector<TransformFuncType> eliminaters_{}; | |||||
| OptimizerCallerPtr tensor_multiply_by_zero_; | |||||
| std::vector<OptimizerCallerPtr> eliminaters_{}; | |||||
| }; | }; | ||||
| } // namespace irpass | } // namespace irpass | ||||
| } // namespace opt | } // namespace opt | ||||
| @@ -17,9 +17,9 @@ | |||||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ | #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ | ||||
| #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ | #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ | ||||
| #include "ir/visitor.h" | |||||
| #include "optimizer/irpass.h" | #include "optimizer/irpass.h" | ||||
| #include "optimizer/optimizer.h" | #include "optimizer/optimizer.h" | ||||
| #include "ir/visitor.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| @@ -52,12 +52,12 @@ class TwoCastEliminater : public AnfVisitor { | |||||
| AnfNodePtr x_{nullptr}, t_{nullptr}; | AnfNodePtr x_{nullptr}, t_{nullptr}; | ||||
| }; | }; | ||||
| class CastEliminater { | |||||
| class CastEliminater : public OptimizerCaller { | |||||
| public: | public: | ||||
| CastEliminater() : cast_same_type_eliminater_(), two_cast_eliminater_() {} | CastEliminater() : cast_same_type_eliminater_(), two_cast_eliminater_() {} | ||||
| ~CastEliminater() = default; | ~CastEliminater() = default; | ||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||||
| auto new_node = cast_same_type_eliminater_(optimizer, node); | auto new_node = cast_same_type_eliminater_(optimizer, node); | ||||
| if (new_node != nullptr) { | if (new_node != nullptr) { | ||||
| return new_node; | return new_node; | ||||
| @@ -17,18 +17,19 @@ | |||||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ | #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ | ||||
| #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ | #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ | ||||
| #include <vector> | |||||
| #include <utility> | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <unordered_map> | |||||
| #include <memory> | #include <memory> | ||||
| #include <unordered_map> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "optimizer/irpass.h" | |||||
| #include "optimizer/optimizer.h" | |||||
| #include "ir/visitor.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/func_graph_cloner.h" | #include "ir/func_graph_cloner.h" | ||||
| #include "ir/optimizer_caller.h" | |||||
| #include "ir/visitor.h" | |||||
| #include "operator/ops.h" | #include "operator/ops.h" | ||||
| #include "optimizer/irpass.h" | |||||
| #include "optimizer/optimizer.h" | |||||
| #include "utils/symbolic.h" | #include "utils/symbolic.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -225,19 +226,22 @@ class EnvGetSetItem : public AnfVisitor { | |||||
| bool is_match_{false}; | bool is_match_{false}; | ||||
| }; | }; | ||||
| class EnvGetItemEliminater { | |||||
| class EnvGetItemEliminater : public OptimizerCaller { | |||||
| public: | public: | ||||
| EnvGetItemEliminater() : new_env_get_item_(), add_env_get_item_(), env_get_set_item_() { | |||||
| EnvGetItemEliminater() | |||||
| : new_env_get_item_(std::make_shared<NewEnvGetItem>()), | |||||
| add_env_get_item_(std::make_shared<AddEnvGetItem>()), | |||||
| env_get_set_item_(std::make_shared<EnvGetSetItem>()) { | |||||
| eliminaters_.emplace_back(new_env_get_item_); | eliminaters_.emplace_back(new_env_get_item_); | ||||
| eliminaters_.emplace_back(add_env_get_item_); | eliminaters_.emplace_back(add_env_get_item_); | ||||
| eliminaters_.emplace_back(env_get_set_item_); | eliminaters_.emplace_back(env_get_set_item_); | ||||
| } | } | ||||
| ~EnvGetItemEliminater() = default; | ~EnvGetItemEliminater() = default; | ||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||||
| AnfNodePtr new_node; | AnfNodePtr new_node; | ||||
| for (auto &eliminater : eliminaters_) { | for (auto &eliminater : eliminaters_) { | ||||
| new_node = eliminater(optimizer, node); | |||||
| new_node = (*eliminater)(optimizer, node); | |||||
| if (new_node != nullptr) { | if (new_node != nullptr) { | ||||
| return new_node; | return new_node; | ||||
| } | } | ||||
| @@ -246,10 +250,8 @@ class EnvGetItemEliminater { | |||||
| } | } | ||||
| private: | private: | ||||
| NewEnvGetItem new_env_get_item_; | |||||
| AddEnvGetItem add_env_get_item_; | |||||
| EnvGetSetItem env_get_set_item_; | |||||
| std::vector<TransformFuncType> eliminaters_{}; | |||||
| OptimizerCallerPtr new_env_get_item_, add_env_get_item_, env_get_set_item_; | |||||
| std::vector<OptimizerCallerPtr> eliminaters_{}; | |||||
| }; | }; | ||||
| // {prim::kPrimEnvGetItem, {G, Xs}, C, Y} | // {prim::kPrimEnvGetItem, {G, Xs}, C, Y} | ||||
| @@ -17,18 +17,20 @@ | |||||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ | #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ | ||||
| #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ | #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ | ||||
| #include <vector> | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <unordered_map> | |||||
| #include <memory> | #include <memory> | ||||
| #include <unordered_map> | |||||
| #include <unordered_set> | #include <unordered_set> | ||||
| #include <vector> | |||||
| #include "optimizer/irpass.h" | |||||
| #include "optimizer/optimizer.h" | |||||
| #include "ir/visitor.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/func_graph_cloner.h" | #include "ir/func_graph_cloner.h" | ||||
| #include "ir/optimizer_caller.h" | |||||
| #include "ir/visitor.h" | |||||
| #include "operator/ops.h" | #include "operator/ops.h" | ||||
| #include "optimizer/irpass.h" | |||||
| #include "optimizer/optimizer.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace irpass { | namespace irpass { | ||||
| @@ -383,18 +385,20 @@ class IncorporateGetitemSwitch : public AnfVisitor { | |||||
| internal::GetitemTransform getitem_transform_; | internal::GetitemTransform getitem_transform_; | ||||
| }; | }; | ||||
| class IncorporateGetitemSet { | |||||
| class IncorporateGetitemSet : public OptimizerCaller { | |||||
| public: | public: | ||||
| IncorporateGetitemSet() : incorporate_getitem_(), incorporate_getitem_switch_() { | |||||
| IncorporateGetitemSet() | |||||
| : incorporate_getitem_(std::make_shared<IncorporateGetitem>()), | |||||
| incorporate_getitem_switch_(std::make_shared<IncorporateGetitemSwitch>()) { | |||||
| eliminaters_.emplace_back(incorporate_getitem_); | eliminaters_.emplace_back(incorporate_getitem_); | ||||
| eliminaters_.emplace_back(incorporate_getitem_switch_); | eliminaters_.emplace_back(incorporate_getitem_switch_); | ||||
| } | } | ||||
| ~IncorporateGetitemSet() = default; | ~IncorporateGetitemSet() = default; | ||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||||
| AnfNodePtr new_node; | AnfNodePtr new_node; | ||||
| for (auto &eliminater : eliminaters_) { | for (auto &eliminater : eliminaters_) { | ||||
| new_node = eliminater(optimizer, node); | |||||
| new_node = (*eliminater)(optimizer, node); | |||||
| if (new_node != nullptr) { | if (new_node != nullptr) { | ||||
| return new_node; | return new_node; | ||||
| } | } | ||||
| @@ -403,9 +407,8 @@ class IncorporateGetitemSet { | |||||
| } | } | ||||
| private: | private: | ||||
| IncorporateGetitem incorporate_getitem_; | |||||
| IncorporateGetitemSwitch incorporate_getitem_switch_; | |||||
| std::vector<TransformFuncType> eliminaters_{}; | |||||
| OptimizerCallerPtr incorporate_getitem_, incorporate_getitem_switch_; | |||||
| std::vector<OptimizerCallerPtr> eliminaters_{}; | |||||
| }; | }; | ||||
| } // namespace irpass | } // namespace irpass | ||||
| } // namespace opt | } // namespace opt | ||||
| @@ -17,13 +17,15 @@ | |||||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ | #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ | ||||
| #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ | #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ | ||||
| #include <vector> | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "optimizer/irpass.h" | |||||
| #include "optimizer/optimizer.h" | |||||
| #include "ir/optimizer_caller.h" | |||||
| #include "ir/visitor.h" | #include "ir/visitor.h" | ||||
| #include "operator/ops.h" | #include "operator/ops.h" | ||||
| #include "optimizer/irpass.h" | |||||
| #include "optimizer/optimizer.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| @@ -261,14 +263,14 @@ class GetitemDependReorder : public AnfVisitor { | |||||
| AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr}; | AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr}; | ||||
| }; | }; | ||||
| class ItemTupleEliminater { | |||||
| class ItemTupleEliminater : public OptimizerCaller { | |||||
| public: | public: | ||||
| ItemTupleEliminater() | ItemTupleEliminater() | ||||
| : get_item_eliminater_(), | |||||
| get_item_const_eliminater_(), | |||||
| set_item_eliminater_(), | |||||
| get_set_item_eliminater_(), | |||||
| get_item_depend_reorder_() { | |||||
| : get_item_eliminater_(std::make_shared<GetitemEliminater>()), | |||||
| get_item_const_eliminater_(std::make_shared<GetitemConstEliminater>()), | |||||
| set_item_eliminater_(std::make_shared<SetitemEliminater>()), | |||||
| get_set_item_eliminater_(std::make_shared<GetSetitemEliminater>()), | |||||
| get_item_depend_reorder_(std::make_shared<GetitemDependReorder>()) { | |||||
| eliminaters_.emplace_back(get_item_eliminater_); | eliminaters_.emplace_back(get_item_eliminater_); | ||||
| eliminaters_.emplace_back(get_item_const_eliminater_); | eliminaters_.emplace_back(get_item_const_eliminater_); | ||||
| eliminaters_.emplace_back(set_item_eliminater_); | eliminaters_.emplace_back(set_item_eliminater_); | ||||
| @@ -277,10 +279,10 @@ class ItemTupleEliminater { | |||||
| } | } | ||||
| ~ItemTupleEliminater() = default; | ~ItemTupleEliminater() = default; | ||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||||
| AnfNodePtr new_node; | AnfNodePtr new_node; | ||||
| for (auto &eliminater : eliminaters_) { | for (auto &eliminater : eliminaters_) { | ||||
| new_node = eliminater(optimizer, node); | |||||
| new_node = (*eliminater)(optimizer, node); | |||||
| if (new_node != nullptr) { | if (new_node != nullptr) { | ||||
| return new_node; | return new_node; | ||||
| } | } | ||||
| @@ -289,12 +291,9 @@ class ItemTupleEliminater { | |||||
| } | } | ||||
| private: | private: | ||||
| GetitemEliminater get_item_eliminater_; | |||||
| GetitemConstEliminater get_item_const_eliminater_; | |||||
| SetitemEliminater set_item_eliminater_; | |||||
| GetSetitemEliminater get_set_item_eliminater_; | |||||
| GetitemDependReorder get_item_depend_reorder_; | |||||
| std::vector<TransformFuncType> eliminaters_{}; | |||||
| OptimizerCallerPtr get_item_eliminater_, get_item_const_eliminater_, set_item_eliminater_, get_set_item_eliminater_, | |||||
| get_item_depend_reorder_; | |||||
| std::vector<OptimizerCallerPtr> eliminaters_{}; | |||||
| }; | }; | ||||
| } // namespace irpass | } // namespace irpass | ||||
| } // namespace opt | } // namespace opt | ||||
| @@ -19,9 +19,9 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include "optimizer/optimizer.h" | |||||
| #include "optimizer/irpass.h" | |||||
| #include "ir/pattern_matcher.h" | #include "ir/pattern_matcher.h" | ||||
| #include "optimizer/irpass.h" | |||||
| #include "optimizer/optimizer.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| @@ -19,11 +19,12 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "optimizer/irpass.h" | |||||
| #include "optimizer/optimizer.h" | |||||
| #include "ir/visitor.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/optimizer_caller.h" | |||||
| #include "ir/visitor.h" | |||||
| #include "operator/ops.h" | #include "operator/ops.h" | ||||
| #include "optimizer/irpass.h" | |||||
| #include "optimizer/optimizer.h" | |||||
| #include "pipeline/static_analysis/dshape.h" | #include "pipeline/static_analysis/dshape.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -124,12 +125,12 @@ class TwoReshapeEliminater : public AnfVisitor { | |||||
| AnfNodePtr x_{nullptr}, shape_{nullptr}; | AnfNodePtr x_{nullptr}, shape_{nullptr}; | ||||
| }; | }; | ||||
| class ReshapeEliminater { | |||||
| class ReshapeEliminater : public OptimizerCaller { | |||||
| public: | public: | ||||
| ReshapeEliminater() : reshape_same_shape_eliminater_(), two_reshape_eliminater_() {} | ReshapeEliminater() : reshape_same_shape_eliminater_(), two_reshape_eliminater_() {} | ||||
| ~ReshapeEliminater() = default; | ~ReshapeEliminater() = default; | ||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||||
| auto new_node = reshape_same_shape_eliminater_(optimizer, node); | auto new_node = reshape_same_shape_eliminater_(optimizer, node); | ||||
| if (new_node != nullptr) { | if (new_node != nullptr) { | ||||
| return new_node; | return new_node; | ||||
| @@ -18,31 +18,31 @@ | |||||
| #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_ | #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_ | ||||
| #include <securec.h> | #include <securec.h> | ||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "optimizer/optimizer.h" | |||||
| #include "optimizer/irpass.h" | |||||
| #include "ir/optimizer_caller.h" | #include "ir/optimizer_caller.h" | ||||
| #include "optimizer/irpass/prim_eliminate.h" | |||||
| #include "ir/pattern_matcher.h" | |||||
| #include "ir/visitor.h" | #include "ir/visitor.h" | ||||
| #include "operator/ops.h" | #include "operator/ops.h" | ||||
| #include "ir/pattern_matcher.h" | |||||
| #include "optimizer/irpass.h" | |||||
| #include "optimizer/irpass/prim_eliminate.h" | |||||
| #include "optimizer/optimizer.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace irpass { | namespace irpass { | ||||
| class SpecialOpEliminater { | |||||
| class SpecialOpEliminater : public OptimizerCaller { | |||||
| public: | public: | ||||
| SpecialOpEliminater() | SpecialOpEliminater() | ||||
| : insert_gradient_of_(prim::kPrimInsertGradientOf), | |||||
| stop_gradient_(prim::kPrimStopGradient), | |||||
| hook_backward_(prim::kPrimHookBackward), | |||||
| print_shape_type_(prim::kPrimPrintShapeType), | |||||
| get_ref_value_(prim::kPrimGetRefValue), | |||||
| mirror_(prim::kPrimMirror), | |||||
| virtual_div_(prim::kPrimVirtualDiv) { | |||||
| : insert_gradient_of_(std::make_shared<PrimEliminater>(prim::kPrimInsertGradientOf)), | |||||
| stop_gradient_(std::make_shared<PrimEliminater>(prim::kPrimStopGradient)), | |||||
| hook_backward_(std::make_shared<PrimEliminater>(prim::kPrimHookBackward)), | |||||
| print_shape_type_(std::make_shared<PrimEliminater>(prim::kPrimPrintShapeType)), | |||||
| get_ref_value_(std::make_shared<PrimEliminater>(prim::kPrimGetRefValue)), | |||||
| mirror_(std::make_shared<PrimEliminater>(prim::kPrimMirror)), | |||||
| virtual_div_(std::make_shared<PrimEliminater>(prim::kPrimVirtualDiv)) { | |||||
| eliminaters_.emplace_back(insert_gradient_of_); | eliminaters_.emplace_back(insert_gradient_of_); | ||||
| eliminaters_.emplace_back(stop_gradient_); | eliminaters_.emplace_back(stop_gradient_); | ||||
| eliminaters_.emplace_back(hook_backward_); | eliminaters_.emplace_back(hook_backward_); | ||||
| @@ -53,10 +53,10 @@ class SpecialOpEliminater { | |||||
| } | } | ||||
| ~SpecialOpEliminater() = default; | ~SpecialOpEliminater() = default; | ||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||||
| AnfNodePtr new_node; | AnfNodePtr new_node; | ||||
| for (auto &eliminater : eliminaters_) { | for (auto &eliminater : eliminaters_) { | ||||
| new_node = eliminater(optimizer, node); | |||||
| new_node = (*eliminater)(optimizer, node); | |||||
| if (new_node != nullptr) { | if (new_node != nullptr) { | ||||
| return new_node; | return new_node; | ||||
| } | } | ||||
| @@ -65,9 +65,9 @@ class SpecialOpEliminater { | |||||
| } | } | ||||
| private: | private: | ||||
| PrimEliminater insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, | |||||
| OptimizerCallerPtr insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, | |||||
| virtual_div_; | virtual_div_; | ||||
| std::vector<TransformFuncType> eliminaters_{}; | |||||
| std::vector<OptimizerCallerPtr> eliminaters_{}; | |||||
| }; | }; | ||||
| // {PrimVirtualDataset, X} -> X | // {PrimVirtualDataset, X} -> X | ||||
| @@ -16,28 +16,27 @@ | |||||
| #include "optimizer/opt.h" | #include "optimizer/opt.h" | ||||
| #include <algorithm> | |||||
| #include <deque> | |||||
| #include <memory> | #include <memory> | ||||
| #include <unordered_set> | #include <unordered_set> | ||||
| #include <deque> | |||||
| #include <algorithm> | |||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "ir/manager.h" | #include "ir/manager.h" | ||||
| #include "utils/ordered_set.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "optimizer/optimizer.h" | #include "optimizer/optimizer.h" | ||||
| #include "utils/log_adapter.h" | |||||
| #include "utils/ordered_set.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| /* namespace to support opt */ | /* namespace to support opt */ | ||||
| namespace opt { | namespace opt { | ||||
| SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim, | |||||
| SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim, | |||||
| const RenormAction &renorm_action) { | const RenormAction &renorm_action) { | ||||
| auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); }; | auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); }; | ||||
| return std::make_shared<Substitution>(transform, name, fn, renorm_action); | return std::make_shared<Substitution>(transform, name, fn, renorm_action); | ||||
| } | } | ||||
| SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, | |||||
| SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, | |||||
| const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action) { | const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action) { | ||||
| auto fn = [prims](const AnfNodePtr &node) -> bool { | auto fn = [prims](const AnfNodePtr &node) -> bool { | ||||
| if (!node->isa<CNode>()) { | if (!node->isa<CNode>()) { | ||||
| @@ -64,16 +63,16 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std:: | |||||
| return std::make_shared<Substitution>(transform, name, fn, renorm_action); | return std::make_shared<Substitution>(transform, name, fn, renorm_action); | ||||
| } | } | ||||
| SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, | |||||
| SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, | |||||
| const PredicateFuncType &predicate, const RenormAction &renorm_action) { | const PredicateFuncType &predicate, const RenormAction &renorm_action) { | ||||
| return std::make_shared<Substitution>(transform, name, predicate, renorm_action); | return std::make_shared<Substitution>(transform, name, predicate, renorm_action); | ||||
| } | } | ||||
| AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const { | |||||
| AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||||
| #ifdef ENABLE_PROFILE | #ifdef ENABLE_PROFILE | ||||
| double t = GetTime(); | double t = GetTime(); | ||||
| #endif | #endif | ||||
| AnfNodePtr result = transform_(optimizer, node); | |||||
| AnfNodePtr result = (*transform_)(optimizer, node); | |||||
| #ifdef ENABLE_PROFILE | #ifdef ENABLE_PROFILE | ||||
| if (optimizer != nullptr) { | if (optimizer != nullptr) { | ||||
| auto time = GetTime(); | auto time = GetTime(); | ||||
| @@ -17,24 +17,18 @@ | |||||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_OPT_H_ | #ifndef MINDSPORE_CCSRC_OPTIMIZER_OPT_H_ | ||||
| #define MINDSPORE_CCSRC_OPTIMIZER_OPT_H_ | #define MINDSPORE_CCSRC_OPTIMIZER_OPT_H_ | ||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/optimizer_caller.h" | |||||
| #include "operator/ops.h" | #include "operator/ops.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| /* namespace to support opt */ | /* namespace to support opt */ | ||||
| namespace opt { | namespace opt { | ||||
| class Optimizer; | |||||
| using OptimizerPtr = std::shared_ptr<Optimizer>; | |||||
| using OptimizerWeakPtr = std::weak_ptr<Optimizer>; | |||||
| using PredicateFuncType = std::function<bool(const AnfNodePtr &)>; | |||||
| using TransformFuncType = std::function<AnfNodePtr(const OptimizerPtr &, const AnfNodePtr &)>; | |||||
| // Define the interaction mode between an Optimize pass and Renormalize pass | // Define the interaction mode between an Optimize pass and Renormalize pass | ||||
| // FORCE_RENORM: if the pass modified the graph then the next Renormalize will be executed | // FORCE_RENORM: if the pass modified the graph then the next Renormalize will be executed | ||||
| @@ -43,26 +37,26 @@ enum RenormAction : int { FORCE_RENORM = 0, CHECK_RENORM }; | |||||
| class Substitution { | class Substitution { | ||||
| public: | public: | ||||
| TransformFuncType transform_{nullptr}; | |||||
| OptimizerCallerPtr transform_; | |||||
| std::string name_; | std::string name_; | ||||
| PredicateFuncType predicate_{nullptr}; | PredicateFuncType predicate_{nullptr}; | ||||
| // an enum to mark this Substitution relation to renormalize pass | // an enum to mark this Substitution relation to renormalize pass | ||||
| RenormAction renorm_action_; | RenormAction renorm_action_; | ||||
| Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate, | |||||
| Substitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate, | |||||
| const RenormAction &renorm_action) | const RenormAction &renorm_action) | ||||
| : transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {} | : transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {} | ||||
| ~Substitution() = default; | ~Substitution() = default; | ||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const; | |||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node); | |||||
| }; | }; | ||||
| using SubstitutionPtr = std::shared_ptr<Substitution>; | using SubstitutionPtr = std::shared_ptr<Substitution>; | ||||
| SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim, | |||||
| SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim, | |||||
| const RenormAction &action_renorm = CHECK_RENORM); | const RenormAction &action_renorm = CHECK_RENORM); | ||||
| SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, | |||||
| SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, | |||||
| const std::vector<PrimitivePtr> &prims, | const std::vector<PrimitivePtr> &prims, | ||||
| const RenormAction &action_renorm = CHECK_RENORM); | const RenormAction &action_renorm = CHECK_RENORM); | ||||
| SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, | |||||
| SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, | |||||
| const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM); | const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM); | ||||
| class SubstitutionList { | class SubstitutionList { | ||||
| @@ -77,10 +77,10 @@ class TestOptOpt : public UT::Common { | |||||
| }; | }; | ||||
| void SetUp() { | void SetUp() { | ||||
| elim_Z = MakeSubstitution(irpass::AddByZero(), "elim_Z", prim::kPrimScalarAdd); | |||||
| elim_R = MakeSubstitution(irpass::PrimEliminater(R), "elim_R", R); | |||||
| idempotent_P = MakeSubstitution(IdempotentEliminater(), "idempotent_P", P); | |||||
| Qct_to_P = MakeSubstitution(QctToP(), "Qct_to_P", Q); | |||||
| elim_Z = MakeSubstitution(std::make_shared<irpass::AddByZero>(), "elim_Z", prim::kPrimScalarAdd); | |||||
| elim_R = MakeSubstitution(std::make_shared<irpass::PrimEliminater>(R), "elim_R", R); | |||||
| idempotent_P = MakeSubstitution(std::make_shared<IdempotentEliminater>(), "idempotent_P", P); | |||||
| Qct_to_P = MakeSubstitution(std::make_shared<QctToP>(), "Qct_to_P", Q); | |||||
| } | } | ||||
| bool CheckTransform(FuncGraphPtr gbefore, FuncGraphPtr gafter, const SubstitutionList &transform) { | bool CheckTransform(FuncGraphPtr gbefore, FuncGraphPtr gafter, const SubstitutionList &transform) { | ||||