| @@ -17,13 +17,23 @@ | |||
| #ifndef MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ | |||
| #define MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ | |||
| #include <memory> | |||
| #include "ir/anf.h" | |||
| #include "optimizer/opt.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class Optimizer; | |||
| using OptimizerPtr = std::shared_ptr<Optimizer>; | |||
| using OptimizerWeakPtr = std::weak_ptr<Optimizer>; | |||
| using PredicateFuncType = std::function<bool(const AnfNodePtr &)>; | |||
| } // namespace opt | |||
| class OptimizerCaller { | |||
| public: | |||
| virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &) { return nullptr; } | |||
| }; | |||
| using OptimizerCallerPtr = std::shared_ptr<OptimizerCaller>; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ | |||
| @@ -14,140 +14,154 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "optimizer/irpass.h" | |||
| #include <string> | |||
| #include "optimizer/irpass/symbol_resolver.h" | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/irpass/arithmetic_simplify.h" | |||
| #include "optimizer/irpass/special_op_eliminate.h" | |||
| #include "optimizer/irpass/item_tuple_eliminate.h" | |||
| #include "optimizer/irpass/env_item_eliminate.h" | |||
| #include "optimizer/irpass/tile_eliminate.h" | |||
| #include "optimizer/irpass/cast_eliminate.h" | |||
| #include "optimizer/irpass/reshape_eliminate.h" | |||
| #include "optimizer/irpass/transpose_eliminate.h" | |||
| #include "optimizer/irpass/reduce_eliminate.h" | |||
| #include "optimizer/irpass/partial_eliminate.h" | |||
| #include "optimizer/irpass/ref_eliminate.h" | |||
| #include "optimizer/irpass/merge_addn.h" | |||
| #include "optimizer/irpass/branch_culling.h" | |||
| #include "optimizer/irpass/cast_eliminate.h" | |||
| #include "optimizer/irpass/convert.h" | |||
| #include "optimizer/irpass/env_item_eliminate.h" | |||
| #include "optimizer/irpass/grad_var_prepare.h" | |||
| #include "optimizer/irpass/gradient_eliminate.h" | |||
| #include "optimizer/irpass/minmax_grad.h" | |||
| #include "optimizer/irpass/inline.h" | |||
| #include "optimizer/irpass/convert.h" | |||
| #include "optimizer/irpass/specialize_transform.h" | |||
| #include "optimizer/irpass/incorporate_getitem.h" | |||
| #include "optimizer/irpass/incorporate_call.h" | |||
| #include "optimizer/irpass/grad_var_prepare.h" | |||
| #include "optimizer/irpass/param_replace.h" | |||
| #include "optimizer/irpass/incorporate_getitem.h" | |||
| #include "optimizer/irpass/item_tuple_eliminate.h" | |||
| #include "optimizer/irpass/mark_interface_fusion.h" | |||
| #include "optimizer/irpass/merge_addn.h" | |||
| #include "optimizer/irpass/minmax_grad.h" | |||
| #include "optimizer/irpass/param_replace.h" | |||
| #include "optimizer/irpass/partial_eliminate.h" | |||
| #include "optimizer/irpass/reduce_eliminate.h" | |||
| #include "optimizer/irpass/ref_eliminate.h" | |||
| #include "optimizer/irpass/reshape_eliminate.h" | |||
| #include "optimizer/irpass/special_op_eliminate.h" | |||
| #include "optimizer/irpass/specialize_transform.h" | |||
| #include "optimizer/irpass/symbol_resolver.h" | |||
| #include "optimizer/irpass/tile_eliminate.h" | |||
| #include "optimizer/irpass/transpose_eliminate.h" | |||
| #include "optimizer/opt.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace irpass { | |||
| OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify", | |||
| arithmetic_simplify_ = MakeSubstitution(std::make_shared<ArithmeticSimplify>(), "arithmetic_simplify", | |||
| {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, | |||
| prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow}); | |||
| arithmetic_simplify2_ = MakeSubstitution(ArithmeticSimplify2(), "arithmetic_simplify2", {prim::kPrimMul}); | |||
| arithmetic_simplify2_ = | |||
| MakeSubstitution(std::make_shared<ArithmeticSimplify2>(), "arithmetic_simplify2", {prim::kPrimMul}); | |||
| special_op_eliminate_ = | |||
| MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", | |||
| MakeSubstitution(std::make_shared<SpecialOpEliminater>(), "special_op_eliminate", | |||
| {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward, | |||
| prim::kPrimPrintShapeType, prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv}); | |||
| zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLike); | |||
| adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN); | |||
| zero_like_fill_zero_ = | |||
| MakeSubstitution(std::make_shared<ZeroLikeFillZero>(), "zero_like_fill_zero", prim::kPrimZerosLike); | |||
| adjust_all_reduce_mul_add_ = | |||
| MakeSubstitution(std::make_shared<AdjustAllReduceMulAdd>(), "adjust_all_reduce_mul_add", prim::kPrimAddN); | |||
| // ops eliminate | |||
| item_tuple_eliminate_ = | |||
| MakeSubstitution(ItemTupleEliminater(), "item_tuple_eliminate", {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem}); | |||
| tile_eliminate_ = MakeSubstitution(TileMultiplyByOne(), "tile_eliminate", prim::kPrimTile); | |||
| cast_eliminate_ = MakeSubstitution(CastEliminater(), "cast_eliminate", prim::kPrimCast); | |||
| reshape_eliminate_ = MakeSubstitution(ReshapeEliminater(), "reshape_eliminate", prim::kPrimReshape); | |||
| transpose_eliminate_ = MakeSubstitution(TransposeSameIOEliminater(), "transpose_eliminate", prim::kPrimTranspose); | |||
| item_tuple_eliminate_ = MakeSubstitution(std::make_shared<ItemTupleEliminater>(), "item_tuple_eliminate", | |||
| {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem}); | |||
| tile_eliminate_ = MakeSubstitution(std::make_shared<TileMultiplyByOne>(), "tile_eliminate", prim::kPrimTile); | |||
| cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast); | |||
| reshape_eliminate_ = MakeSubstitution(std::make_shared<ReshapeEliminater>(), "reshape_eliminate", prim::kPrimReshape); | |||
| transpose_eliminate_ = | |||
| MakeSubstitution(std::make_shared<TransposeSameIOEliminater>(), "transpose_eliminate", prim::kPrimTranspose); | |||
| reduce_eliminate_ = MakeSubstitution( | |||
| ReduceOneEliminater(), "reduce_eliminate", | |||
| std::make_shared<ReduceOneEliminater>(), "reduce_eliminate", | |||
| {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin}); | |||
| partial_eliminate_ = MakeSubstitution(PartialEliminater(), "partial_eliminate", IsCNodeDup); | |||
| same_eliminate_ = MakeSubstitution(SameEliminater(), "same_eliminate", prim::kPrimSameTypeShape); | |||
| check_bprop_eliminate_ = MakeSubstitution(CheckBpropEliminater(), "check_bprop_eliminate", prim::kPrimCheckBprop); | |||
| reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode<FuncGraph>); | |||
| depend_value_elim_ = MakeSubstitution(DependValueElim(), "depend_value_elim", prim::kPrimDepend); | |||
| partial_eliminate_ = MakeSubstitution(std::make_shared<PartialEliminater>(), "partial_eliminate", IsCNodeDup); | |||
| same_eliminate_ = MakeSubstitution(std::make_shared<SameEliminater>(), "same_eliminate", prim::kPrimSameTypeShape); | |||
| check_bprop_eliminate_ = | |||
| MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop); | |||
| reset_defer_inline_ = | |||
| MakeSubstitution(std::make_shared<ResetDeferInline>(), "reset_defer_inline", IsValueNode<FuncGraph>); | |||
| depend_value_elim_ = MakeSubstitution(std::make_shared<DependValueElim>(), "depend_value_elim", prim::kPrimDepend); | |||
| // Env Item Eliminate | |||
| env_get_item_eliminate_ = MakeSubstitution(EnvGetItemEliminater(), "env_get_item_eliminate", prim::kPrimEnvGetItem); | |||
| new_env_get_item_ = MakeSubstitution(NewEnvGetItem(), "new_env_get_item", prim::kPrimEnvGetItem); | |||
| env_get_item_eliminate_ = | |||
| MakeSubstitution(std::make_shared<EnvGetItemEliminater>(), "env_get_item_eliminate", prim::kPrimEnvGetItem); | |||
| new_env_get_item_ = MakeSubstitution(std::make_shared<NewEnvGetItem>(), "new_env_get_item", prim::kPrimEnvGetItem); | |||
| incorporate_env_getitem_ = | |||
| MakeSubstitution(IncorporateEnvGetitem(), "incorporate_env_get_item", prim::kPrimEnvGetItem); | |||
| incorporate_env_getitem_switch_ = | |||
| MakeSubstitution(IncorporateEnvGetitemSwitch(), "incorporate_env_getitem_switch", prim::kPrimEnvGetItem); | |||
| MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(), "incorporate_env_get_item", prim::kPrimEnvGetItem); | |||
| incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared<IncorporateEnvGetitemSwitch>(), | |||
| "incorporate_env_getitem_switch", prim::kPrimEnvGetItem); | |||
| // Ref eliminate | |||
| make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef); | |||
| get_ref_param_eliminate_ = MakeSubstitution(GetRefParamEliminater(), "get_ref_param_eliminate", | |||
| make_ref_eliminate_ = | |||
| MakeSubstitution(std::make_shared<MakeRefEliminater>(), "make_ref_eliminate", prim::kPrimMakeRef); | |||
| get_ref_param_eliminate_ = MakeSubstitution(std::make_shared<GetRefParamEliminater>(), "get_ref_param_eliminate", | |||
| {prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); | |||
| get_make_ref_eliminate_ = MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate", | |||
| get_make_ref_eliminate_ = MakeSubstitution(std::make_shared<GetMakeRefEliminater>(), "get_make_ref_eliminate", | |||
| {prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); | |||
| replace_refkey_by_param_ = | |||
| MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM); | |||
| replace_old_param_ = MakeSubstitution(ReplaceOldParam(), "replace_old_param", IsParam); | |||
| replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param", | |||
| IsValueNode<RefKey>, opt::FORCE_RENORM); | |||
| replace_old_param_ = MakeSubstitution(std::make_shared<ReplaceOldParam>(), "replace_old_param", IsParam); | |||
| // Gradient transforms | |||
| expand_jprim_ = MakeSubstitution(ExpandJPrim(), "expand_jprim", prim::kPrimJ); | |||
| minmaximum_grad_ = MakeSubstitution(MinMaximumGrad(), "minmaximum_grad", prim::kPrimTupleGetItem); | |||
| expand_jprim_ = MakeSubstitution(std::make_shared<ExpandJPrim>(), "expand_jprim", prim::kPrimJ); | |||
| minmaximum_grad_ = MakeSubstitution(std::make_shared<MinMaximumGrad>(), "minmaximum_grad", prim::kPrimTupleGetItem); | |||
| // branch culling | |||
| switch_simplify_ = MakeSubstitution(SwitchSimplify(), "switch_simplify", prim::kPrimSwitch); | |||
| float_tuple_getitem_switch_ = | |||
| MakeSubstitution(FloatTupleGetItemSwitch(), "float_tuple_getitem_switch", prim::kPrimTupleGetItem); | |||
| switch_simplify_ = MakeSubstitution(std::make_shared<SwitchSimplify>(), "switch_simplify", prim::kPrimSwitch); | |||
| float_tuple_getitem_switch_ = MakeSubstitution(std::make_shared<FloatTupleGetItemSwitch>(), | |||
| "float_tuple_getitem_switch", prim::kPrimTupleGetItem); | |||
| float_env_getitem_switch_ = | |||
| MakeSubstitution(FloatEnvGetItemSwitch(), "float_env_getitem_switch", prim::kPrimEnvGetItem); | |||
| convert_switch_replacement_ = MakeSubstitution(ConvertSwitchReplacement(), "convert_switch_replacement", IsCNodeDup); | |||
| MakeSubstitution(std::make_shared<FloatEnvGetItemSwitch>(), "float_env_getitem_switch", prim::kPrimEnvGetItem); | |||
| convert_switch_replacement_ = | |||
| MakeSubstitution(std::make_shared<ConvertSwitchReplacement>(), "convert_switch_replacement", IsCNodeDup); | |||
| // Addn | |||
| merge_addn_ = MakeSubstitution(MergeAddN(), "merge_addn", prim::kPrimAddN); | |||
| addn_zero_filter_ = MakeSubstitution(AddNZeroFilter(), "addn_zero_filter", prim::kPrimAddN); | |||
| merge_addn_ = MakeSubstitution(std::make_shared<MergeAddN>(), "merge_addn", prim::kPrimAddN); | |||
| addn_zero_filter_ = MakeSubstitution(std::make_shared<AddNZeroFilter>(), "addn_zero_filter", prim::kPrimAddN); | |||
| // inline | |||
| inline_ = MakeSubstitution(Inliner(), "inline", IsCNodeGraph); | |||
| replace_applicator_ = MakeSubstitution(ReplaceApplicator(), "replace_applicator", IsValueNode<FuncGraph>); | |||
| specialize_transform_ = MakeSubstitution(SpecializeOnGraphArguments(), "specialize_transform", IsCNodeGraph); | |||
| inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph); | |||
| replace_applicator_ = | |||
| MakeSubstitution(std::make_shared<ReplaceApplicator>(), "replace_applicator", IsValueNode<FuncGraph>); | |||
| specialize_transform_ = | |||
| MakeSubstitution(std::make_shared<SpecializeOnGraphArguments>(), "specialize_transform", IsCNodeGraph); | |||
| // Incorporation | |||
| incorporate_getitem_set_ = | |||
| MakeSubstitution(IncorporateGetitemSet(), "incorporate_getitem_set", prim::kPrimTupleGetItem); | |||
| incorporate_getitem_from_param_ = | |||
| MakeSubstitution(IncorporateGetitemFromParam(), "incorporate_getitem_from_param", IsCNodeGraphKernel); | |||
| incorporate_call_ = MakeSubstitution(IncorporateCall(), "incorporate_call", IsCNodeDup); | |||
| incorporate_call_switch_ = MakeSubstitution(IncorporateCallSwitch(), "incorporate_call_switch", IsCNodeDup); | |||
| MakeSubstitution(std::make_shared<IncorporateGetitemSet>(), "incorporate_getitem_set", prim::kPrimTupleGetItem); | |||
| incorporate_getitem_from_param_ = MakeSubstitution(std::make_shared<IncorporateGetitemFromParam>(), | |||
| "incorporate_getitem_from_param", IsCNodeGraphKernel); | |||
| incorporate_call_ = MakeSubstitution(std::make_shared<IncorporateCall>(), "incorporate_call", IsCNodeDup); | |||
| incorporate_call_switch_ = | |||
| MakeSubstitution(std::make_shared<IncorporateCallSwitch>(), "incorporate_call_switch", IsCNodeDup); | |||
| // Virtual Dataset | |||
| virtual_dataset_eliminate_ = | |||
| MakeSubstitution(VirtualDatasetEliminater(), "virtual_dataset_eliminate", prim::kPrimVirtualDataset); | |||
| virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared<VirtualDatasetEliminater>(), | |||
| "virtual_dataset_eliminate", prim::kPrimVirtualDataset); | |||
| // Convert | |||
| print_tuple_wrapper_ = MakeSubstitution(PrintTupleWrapper(), "print_tuple_wrapper", prim::kPrimPrint); | |||
| print_tuple_wrapper_ = | |||
| MakeSubstitution(std::make_shared<PrintTupleWrapper>(), "print_tuple_wrapper", prim::kPrimPrint); | |||
| // Unused parameter eliminate | |||
| unused_parameter_eliminate_ = | |||
| MakeSubstitution(UnusedParasEliminater(), "unused_parameter_eliminate", IsCNodeGraphKernel); | |||
| unused_output_eliminate_ = MakeSubstitution(UnusedOutputEliminater(), "unused_output_eliminate", IsCNodeGraphKernel); | |||
| MakeSubstitution(std::make_shared<UnusedParasEliminater>(), "unused_parameter_eliminate", IsCNodeGraphKernel); | |||
| unused_output_eliminate_ = | |||
| MakeSubstitution(std::make_shared<UnusedOutputEliminater>(), "unused_output_eliminate", IsCNodeGraphKernel); | |||
| // AddN eliminate | |||
| addn_eliminate_ = MakeSubstitution(AddNEliminater(), "addn_eliminate", IsCNodeGraphKernel); | |||
| addn_eliminate_ = MakeSubstitution(std::make_shared<AddNEliminater>(), "addn_eliminate", IsCNodeGraphKernel); | |||
| // Mark interface fusion | |||
| mark_interface_fusion_ = MakeSubstitution(MarkInterfaceFusion(), "mark_interface_fusion", prim::kPrimSelect); | |||
| mark_interface_fusion_ = | |||
| MakeSubstitution(std::make_shared<MarkInterfaceFusion>(), "mark_interface_fusion", prim::kPrimSelect); | |||
| } | |||
| ResolveIRPassLib::ResolveIRPassLib() { | |||
| resolver_resolve_ = MakeSubstitution(ResolverResolve(), "resolver_resolve", prim::kPrimResolve); | |||
| resolver_getattr_ = MakeSubstitution(ResolverGetattr(), "resolver_getattr", prim::kPrimGetAttr); | |||
| resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve); | |||
| resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetattr>(), "resolver_getattr", prim::kPrimGetAttr); | |||
| } | |||
| InferenceOptPrepareLib::InferenceOptPrepareLib() { | |||
| grad_var_prepare_ = MakeSubstitution(GradVarPrepare(), "grad_var_prepare", IsCNode); | |||
| grad_var_prepare_ = MakeSubstitution(std::make_shared<GradVarPrepare>(), "grad_var_prepare", IsCNode); | |||
| } | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| @@ -17,15 +17,16 @@ | |||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ | |||
| #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "optimizer/optimizer.h" | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/irpass/prim_eliminate.h" | |||
| #include "ir/optimizer_caller.h" | |||
| #include "ir/visitor.h" | |||
| #include "operator/ops.h" | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/irpass/prim_eliminate.h" | |||
| #include "optimizer/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -739,17 +740,17 @@ class AdjustAllReduceMulAdd : public AnfVisitor { | |||
| FuncGraphPtr all_reduce_fg_{nullptr}; | |||
| }; | |||
| class ArithmeticSimplify { | |||
| class ArithmeticSimplify : public OptimizerCaller { | |||
| public: | |||
| ArithmeticSimplify() | |||
| : multiply_by_zero_or_one_(), | |||
| tensor_multiply_by_one_(), | |||
| add_by_zero_(), | |||
| tensor_add_by_zero_(), | |||
| identity_(prim::kPrimIdentity), | |||
| opt_update_zero_tensor_(), | |||
| constant_duplicate_mul_(), | |||
| power_one_() { | |||
| : multiply_by_zero_or_one_(std::make_shared<MultiplyByZeroOrOne>()), | |||
| tensor_multiply_by_one_(std::make_shared<TensorMultiplyByOne>()), | |||
| add_by_zero_(std::make_shared<AddByZero>()), | |||
| tensor_add_by_zero_(std::make_shared<TensorAddByZero>()), | |||
| identity_(std::make_shared<PrimEliminater>(prim::kPrimIdentity)), | |||
| opt_update_zero_tensor_(std::make_shared<OptUpdateZeroTensor>()), | |||
| constant_duplicate_mul_(std::make_shared<ConstantDuplicateMul>()), | |||
| power_one_(std::make_shared<PowerOneEliminate>()) { | |||
| eliminaters_.emplace_back(multiply_by_zero_or_one_); | |||
| eliminaters_.emplace_back(tensor_multiply_by_one_); | |||
| eliminaters_.emplace_back(add_by_zero_); | |||
| @@ -761,10 +762,10 @@ class ArithmeticSimplify { | |||
| } | |||
| ~ArithmeticSimplify() = default; | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||
| AnfNodePtr new_node; | |||
| for (auto &eliminater : eliminaters_) { | |||
| new_node = eliminater(optimizer, node); | |||
| new_node = (*eliminater)(optimizer, node); | |||
| if (new_node != nullptr) { | |||
| return new_node; | |||
| } | |||
| @@ -773,15 +774,9 @@ class ArithmeticSimplify { | |||
| } | |||
| private: | |||
| MultiplyByZeroOrOne multiply_by_zero_or_one_; | |||
| TensorMultiplyByOne tensor_multiply_by_one_; | |||
| AddByZero add_by_zero_; | |||
| TensorAddByZero tensor_add_by_zero_; | |||
| PrimEliminater identity_; | |||
| OptUpdateZeroTensor opt_update_zero_tensor_; | |||
| ConstantDuplicateMul constant_duplicate_mul_; | |||
| PowerOneEliminate power_one_; | |||
| std::vector<TransformFuncType> eliminaters_{}; | |||
| OptimizerCallerPtr multiply_by_zero_or_one_, tensor_multiply_by_one_, add_by_zero_, tensor_add_by_zero_, identity_, | |||
| opt_update_zero_tensor_, constant_duplicate_mul_, power_one_; | |||
| std::vector<OptimizerCallerPtr> eliminaters_{}; | |||
| }; | |||
| // Arithmetic Simplifications should be done after step_parallel. | |||
| @@ -789,15 +784,17 @@ class ArithmeticSimplify { | |||
| // with shape(weight), but after step_parallel, shape of weight may be changed, so the | |||
| // shape of the constant tensor should also be changed. So this pass is seperated from | |||
| // ArithmeticSimplify and deferred until step_parallel. | |||
| class ArithmeticSimplify2 { | |||
| class ArithmeticSimplify2 : public OptimizerCaller { | |||
| public: | |||
| ArithmeticSimplify2() : tensor_multiply_by_zero_() { eliminaters_.emplace_back(tensor_multiply_by_zero_); } | |||
| ArithmeticSimplify2() : tensor_multiply_by_zero_(std::make_shared<TensorMultiplyByZero>()) { | |||
| eliminaters_.emplace_back(tensor_multiply_by_zero_); | |||
| } | |||
| ~ArithmeticSimplify2() = default; | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||
| AnfNodePtr new_node; | |||
| for (auto &eliminater : eliminaters_) { | |||
| new_node = eliminater(optimizer, node); | |||
| new_node = (*eliminater)(optimizer, node); | |||
| if (new_node != nullptr) { | |||
| return new_node; | |||
| } | |||
| @@ -806,8 +803,8 @@ class ArithmeticSimplify2 { | |||
| } | |||
| private: | |||
| TensorMultiplyByZero tensor_multiply_by_zero_; | |||
| std::vector<TransformFuncType> eliminaters_{}; | |||
| OptimizerCallerPtr tensor_multiply_by_zero_; | |||
| std::vector<OptimizerCallerPtr> eliminaters_{}; | |||
| }; | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| @@ -17,9 +17,9 @@ | |||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ | |||
| #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ | |||
| #include "ir/visitor.h" | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/optimizer.h" | |||
| #include "ir/visitor.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -52,12 +52,12 @@ class TwoCastEliminater : public AnfVisitor { | |||
| AnfNodePtr x_{nullptr}, t_{nullptr}; | |||
| }; | |||
| class CastEliminater { | |||
| class CastEliminater : public OptimizerCaller { | |||
| public: | |||
| CastEliminater() : cast_same_type_eliminater_(), two_cast_eliminater_() {} | |||
| ~CastEliminater() = default; | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||
| auto new_node = cast_same_type_eliminater_(optimizer, node); | |||
| if (new_node != nullptr) { | |||
| return new_node; | |||
| @@ -17,18 +17,19 @@ | |||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ | |||
| #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ | |||
| #include <vector> | |||
| #include <utility> | |||
| #include <algorithm> | |||
| #include <unordered_map> | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/optimizer.h" | |||
| #include "ir/visitor.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/func_graph_cloner.h" | |||
| #include "ir/optimizer_caller.h" | |||
| #include "ir/visitor.h" | |||
| #include "operator/ops.h" | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/optimizer.h" | |||
| #include "utils/symbolic.h" | |||
| namespace mindspore { | |||
| @@ -225,19 +226,22 @@ class EnvGetSetItem : public AnfVisitor { | |||
| bool is_match_{false}; | |||
| }; | |||
| class EnvGetItemEliminater { | |||
| class EnvGetItemEliminater : public OptimizerCaller { | |||
| public: | |||
| EnvGetItemEliminater() : new_env_get_item_(), add_env_get_item_(), env_get_set_item_() { | |||
| EnvGetItemEliminater() | |||
| : new_env_get_item_(std::make_shared<NewEnvGetItem>()), | |||
| add_env_get_item_(std::make_shared<AddEnvGetItem>()), | |||
| env_get_set_item_(std::make_shared<EnvGetSetItem>()) { | |||
| eliminaters_.emplace_back(new_env_get_item_); | |||
| eliminaters_.emplace_back(add_env_get_item_); | |||
| eliminaters_.emplace_back(env_get_set_item_); | |||
| } | |||
| ~EnvGetItemEliminater() = default; | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||
| AnfNodePtr new_node; | |||
| for (auto &eliminater : eliminaters_) { | |||
| new_node = eliminater(optimizer, node); | |||
| new_node = (*eliminater)(optimizer, node); | |||
| if (new_node != nullptr) { | |||
| return new_node; | |||
| } | |||
| @@ -246,10 +250,8 @@ class EnvGetItemEliminater { | |||
| } | |||
| private: | |||
| NewEnvGetItem new_env_get_item_; | |||
| AddEnvGetItem add_env_get_item_; | |||
| EnvGetSetItem env_get_set_item_; | |||
| std::vector<TransformFuncType> eliminaters_{}; | |||
| OptimizerCallerPtr new_env_get_item_, add_env_get_item_, env_get_set_item_; | |||
| std::vector<OptimizerCallerPtr> eliminaters_{}; | |||
| }; | |||
| // {prim::kPrimEnvGetItem, {G, Xs}, C, Y} | |||
| @@ -17,18 +17,20 @@ | |||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ | |||
| #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include <unordered_map> | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| #include <unordered_set> | |||
| #include <vector> | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/optimizer.h" | |||
| #include "ir/visitor.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/func_graph_cloner.h" | |||
| #include "ir/optimizer_caller.h" | |||
| #include "ir/visitor.h" | |||
| #include "operator/ops.h" | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace irpass { | |||
| @@ -383,18 +385,20 @@ class IncorporateGetitemSwitch : public AnfVisitor { | |||
| internal::GetitemTransform getitem_transform_; | |||
| }; | |||
| class IncorporateGetitemSet { | |||
| class IncorporateGetitemSet : public OptimizerCaller { | |||
| public: | |||
| IncorporateGetitemSet() : incorporate_getitem_(), incorporate_getitem_switch_() { | |||
| IncorporateGetitemSet() | |||
| : incorporate_getitem_(std::make_shared<IncorporateGetitem>()), | |||
| incorporate_getitem_switch_(std::make_shared<IncorporateGetitemSwitch>()) { | |||
| eliminaters_.emplace_back(incorporate_getitem_); | |||
| eliminaters_.emplace_back(incorporate_getitem_switch_); | |||
| } | |||
| ~IncorporateGetitemSet() = default; | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||
| AnfNodePtr new_node; | |||
| for (auto &eliminater : eliminaters_) { | |||
| new_node = eliminater(optimizer, node); | |||
| new_node = (*eliminater)(optimizer, node); | |||
| if (new_node != nullptr) { | |||
| return new_node; | |||
| } | |||
| @@ -403,9 +407,8 @@ class IncorporateGetitemSet { | |||
| } | |||
| private: | |||
| IncorporateGetitem incorporate_getitem_; | |||
| IncorporateGetitemSwitch incorporate_getitem_switch_; | |||
| std::vector<TransformFuncType> eliminaters_{}; | |||
| OptimizerCallerPtr incorporate_getitem_, incorporate_getitem_switch_; | |||
| std::vector<OptimizerCallerPtr> eliminaters_{}; | |||
| }; | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| @@ -17,13 +17,15 @@ | |||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ | |||
| #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/optimizer.h" | |||
| #include "ir/optimizer_caller.h" | |||
| #include "ir/visitor.h" | |||
| #include "operator/ops.h" | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -261,14 +263,14 @@ class GetitemDependReorder : public AnfVisitor { | |||
| AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr}; | |||
| }; | |||
| class ItemTupleEliminater { | |||
| class ItemTupleEliminater : public OptimizerCaller { | |||
| public: | |||
| ItemTupleEliminater() | |||
| : get_item_eliminater_(), | |||
| get_item_const_eliminater_(), | |||
| set_item_eliminater_(), | |||
| get_set_item_eliminater_(), | |||
| get_item_depend_reorder_() { | |||
| : get_item_eliminater_(std::make_shared<GetitemEliminater>()), | |||
| get_item_const_eliminater_(std::make_shared<GetitemConstEliminater>()), | |||
| set_item_eliminater_(std::make_shared<SetitemEliminater>()), | |||
| get_set_item_eliminater_(std::make_shared<GetSetitemEliminater>()), | |||
| get_item_depend_reorder_(std::make_shared<GetitemDependReorder>()) { | |||
| eliminaters_.emplace_back(get_item_eliminater_); | |||
| eliminaters_.emplace_back(get_item_const_eliminater_); | |||
| eliminaters_.emplace_back(set_item_eliminater_); | |||
| @@ -277,10 +279,10 @@ class ItemTupleEliminater { | |||
| } | |||
| ~ItemTupleEliminater() = default; | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||
| AnfNodePtr new_node; | |||
| for (auto &eliminater : eliminaters_) { | |||
| new_node = eliminater(optimizer, node); | |||
| new_node = (*eliminater)(optimizer, node); | |||
| if (new_node != nullptr) { | |||
| return new_node; | |||
| } | |||
| @@ -289,12 +291,9 @@ class ItemTupleEliminater { | |||
| } | |||
| private: | |||
| GetitemEliminater get_item_eliminater_; | |||
| GetitemConstEliminater get_item_const_eliminater_; | |||
| SetitemEliminater set_item_eliminater_; | |||
| GetSetitemEliminater get_set_item_eliminater_; | |||
| GetitemDependReorder get_item_depend_reorder_; | |||
| std::vector<TransformFuncType> eliminaters_{}; | |||
| OptimizerCallerPtr get_item_eliminater_, get_item_const_eliminater_, set_item_eliminater_, get_set_item_eliminater_, | |||
| get_item_depend_reorder_; | |||
| std::vector<OptimizerCallerPtr> eliminaters_{}; | |||
| }; | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| @@ -19,9 +19,9 @@ | |||
| #include <memory> | |||
| #include "optimizer/optimizer.h" | |||
| #include "optimizer/irpass.h" | |||
| #include "ir/pattern_matcher.h" | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -19,11 +19,12 @@ | |||
| #include <vector> | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/optimizer.h" | |||
| #include "ir/visitor.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/optimizer_caller.h" | |||
| #include "ir/visitor.h" | |||
| #include "operator/ops.h" | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/optimizer.h" | |||
| #include "pipeline/static_analysis/dshape.h" | |||
| namespace mindspore { | |||
| @@ -124,12 +125,12 @@ class TwoReshapeEliminater : public AnfVisitor { | |||
| AnfNodePtr x_{nullptr}, shape_{nullptr}; | |||
| }; | |||
| class ReshapeEliminater { | |||
| class ReshapeEliminater : public OptimizerCaller { | |||
| public: | |||
| ReshapeEliminater() : reshape_same_shape_eliminater_(), two_reshape_eliminater_() {} | |||
| ~ReshapeEliminater() = default; | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||
| auto new_node = reshape_same_shape_eliminater_(optimizer, node); | |||
| if (new_node != nullptr) { | |||
| return new_node; | |||
| @@ -18,31 +18,31 @@ | |||
| #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_ | |||
| #include <securec.h> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "optimizer/optimizer.h" | |||
| #include "optimizer/irpass.h" | |||
| #include "ir/optimizer_caller.h" | |||
| #include "optimizer/irpass/prim_eliminate.h" | |||
| #include "ir/pattern_matcher.h" | |||
| #include "ir/visitor.h" | |||
| #include "operator/ops.h" | |||
| #include "ir/pattern_matcher.h" | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/irpass/prim_eliminate.h" | |||
| #include "optimizer/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace irpass { | |||
| class SpecialOpEliminater { | |||
| class SpecialOpEliminater : public OptimizerCaller { | |||
| public: | |||
| SpecialOpEliminater() | |||
| : insert_gradient_of_(prim::kPrimInsertGradientOf), | |||
| stop_gradient_(prim::kPrimStopGradient), | |||
| hook_backward_(prim::kPrimHookBackward), | |||
| print_shape_type_(prim::kPrimPrintShapeType), | |||
| get_ref_value_(prim::kPrimGetRefValue), | |||
| mirror_(prim::kPrimMirror), | |||
| virtual_div_(prim::kPrimVirtualDiv) { | |||
| : insert_gradient_of_(std::make_shared<PrimEliminater>(prim::kPrimInsertGradientOf)), | |||
| stop_gradient_(std::make_shared<PrimEliminater>(prim::kPrimStopGradient)), | |||
| hook_backward_(std::make_shared<PrimEliminater>(prim::kPrimHookBackward)), | |||
| print_shape_type_(std::make_shared<PrimEliminater>(prim::kPrimPrintShapeType)), | |||
| get_ref_value_(std::make_shared<PrimEliminater>(prim::kPrimGetRefValue)), | |||
| mirror_(std::make_shared<PrimEliminater>(prim::kPrimMirror)), | |||
| virtual_div_(std::make_shared<PrimEliminater>(prim::kPrimVirtualDiv)) { | |||
| eliminaters_.emplace_back(insert_gradient_of_); | |||
| eliminaters_.emplace_back(stop_gradient_); | |||
| eliminaters_.emplace_back(hook_backward_); | |||
| @@ -53,10 +53,10 @@ class SpecialOpEliminater { | |||
| } | |||
| ~SpecialOpEliminater() = default; | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||
| AnfNodePtr new_node; | |||
| for (auto &eliminater : eliminaters_) { | |||
| new_node = eliminater(optimizer, node); | |||
| new_node = (*eliminater)(optimizer, node); | |||
| if (new_node != nullptr) { | |||
| return new_node; | |||
| } | |||
| @@ -65,9 +65,9 @@ class SpecialOpEliminater { | |||
| } | |||
| private: | |||
| PrimEliminater insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, | |||
| OptimizerCallerPtr insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, | |||
| virtual_div_; | |||
| std::vector<TransformFuncType> eliminaters_{}; | |||
| std::vector<OptimizerCallerPtr> eliminaters_{}; | |||
| }; | |||
| // {PrimVirtualDataset, X} -> X | |||
| @@ -16,28 +16,27 @@ | |||
| #include "optimizer/opt.h" | |||
| #include <algorithm> | |||
| #include <deque> | |||
| #include <memory> | |||
| #include <unordered_set> | |||
| #include <deque> | |||
| #include <algorithm> | |||
| #include "ir/anf.h" | |||
| #include "ir/manager.h" | |||
| #include "utils/ordered_set.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "optimizer/optimizer.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/ordered_set.h" | |||
| namespace mindspore { | |||
| /* namespace to support opt */ | |||
| namespace opt { | |||
| SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim, | |||
| SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim, | |||
| const RenormAction &renorm_action) { | |||
| auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); }; | |||
| return std::make_shared<Substitution>(transform, name, fn, renorm_action); | |||
| } | |||
| SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, | |||
| SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, | |||
| const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action) { | |||
| auto fn = [prims](const AnfNodePtr &node) -> bool { | |||
| if (!node->isa<CNode>()) { | |||
| @@ -64,16 +63,16 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std:: | |||
| return std::make_shared<Substitution>(transform, name, fn, renorm_action); | |||
| } | |||
| SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, | |||
| SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, | |||
| const PredicateFuncType &predicate, const RenormAction &renorm_action) { | |||
| return std::make_shared<Substitution>(transform, name, predicate, renorm_action); | |||
| } | |||
| AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const { | |||
| AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||
| #ifdef ENABLE_PROFILE | |||
| double t = GetTime(); | |||
| #endif | |||
| AnfNodePtr result = transform_(optimizer, node); | |||
| AnfNodePtr result = (*transform_)(optimizer, node); | |||
| #ifdef ENABLE_PROFILE | |||
| if (optimizer != nullptr) { | |||
| auto time = GetTime(); | |||
| @@ -17,24 +17,18 @@ | |||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_OPT_H_ | |||
| #define MINDSPORE_CCSRC_OPTIMIZER_OPT_H_ | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "ir/anf.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/optimizer_caller.h" | |||
| #include "operator/ops.h" | |||
| namespace mindspore { | |||
| /* namespace to support opt */ | |||
| namespace opt { | |||
| class Optimizer; | |||
| using OptimizerPtr = std::shared_ptr<Optimizer>; | |||
| using OptimizerWeakPtr = std::weak_ptr<Optimizer>; | |||
| using PredicateFuncType = std::function<bool(const AnfNodePtr &)>; | |||
| using TransformFuncType = std::function<AnfNodePtr(const OptimizerPtr &, const AnfNodePtr &)>; | |||
| // Define the interaction mode between an Optimize pass and Renormalize pass | |||
| // FORCE_RENORM: if the pass modified the graph then the next Renormalize will be executed | |||
| @@ -43,26 +37,26 @@ enum RenormAction : int { FORCE_RENORM = 0, CHECK_RENORM }; | |||
| class Substitution { | |||
| public: | |||
| TransformFuncType transform_{nullptr}; | |||
| OptimizerCallerPtr transform_; | |||
| std::string name_; | |||
| PredicateFuncType predicate_{nullptr}; | |||
| // an enum to mark this Substitution relation to renormalize pass | |||
| RenormAction renorm_action_; | |||
| Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate, | |||
| Substitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate, | |||
| const RenormAction &renorm_action) | |||
| : transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {} | |||
| ~Substitution() = default; | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const; | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node); | |||
| }; | |||
| using SubstitutionPtr = std::shared_ptr<Substitution>; | |||
| SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim, | |||
| SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim, | |||
| const RenormAction &action_renorm = CHECK_RENORM); | |||
| SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, | |||
| SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, | |||
| const std::vector<PrimitivePtr> &prims, | |||
| const RenormAction &action_renorm = CHECK_RENORM); | |||
| SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, | |||
| SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, | |||
| const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM); | |||
| class SubstitutionList { | |||
| @@ -77,10 +77,10 @@ class TestOptOpt : public UT::Common { | |||
| }; | |||
| void SetUp() { | |||
| elim_Z = MakeSubstitution(irpass::AddByZero(), "elim_Z", prim::kPrimScalarAdd); | |||
| elim_R = MakeSubstitution(irpass::PrimEliminater(R), "elim_R", R); | |||
| idempotent_P = MakeSubstitution(IdempotentEliminater(), "idempotent_P", P); | |||
| Qct_to_P = MakeSubstitution(QctToP(), "Qct_to_P", Q); | |||
| elim_Z = MakeSubstitution(std::make_shared<irpass::AddByZero>(), "elim_Z", prim::kPrimScalarAdd); | |||
| elim_R = MakeSubstitution(std::make_shared<irpass::PrimEliminater>(R), "elim_R", R); | |||
| idempotent_P = MakeSubstitution(std::make_shared<IdempotentEliminater>(), "idempotent_P", P); | |||
| Qct_to_P = MakeSubstitution(std::make_shared<QctToP>(), "Qct_to_P", Q); | |||
| } | |||
| bool CheckTransform(FuncGraphPtr gbefore, FuncGraphPtr gafter, const SubstitutionList &transform) { | |||