/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include "optimizer/irpass.h" #include "optimizer/irpass/arithmetic_simplify.h" #include "optimizer/irpass/branch_culling.h" #include "optimizer/irpass/cast_eliminate.h" #include "optimizer/irpass/convert.h" #include "optimizer/irpass/env_item_eliminate.h" #include "optimizer/irpass/grad_var_prepare.h" #include "optimizer/irpass/gradient_eliminate.h" #include "optimizer/irpass/inline.h" #include "optimizer/irpass/incorporate_call.h" #include "optimizer/irpass/incorporate_getitem.h" #include "optimizer/irpass/item_tuple_eliminate.h" #include "optimizer/irpass/mark_interface_fusion.h" #include "optimizer/irpass/merge_addn.h" #include "optimizer/irpass/minmax_grad.h" #include "optimizer/irpass/param_replace.h" #include "optimizer/irpass/partial_eliminate.h" #include "optimizer/irpass/reduce_eliminate.h" #include "optimizer/irpass/ref_eliminate.h" #include "optimizer/irpass/reshape_eliminate.h" #include "optimizer/irpass/special_op_eliminate.h" #include "optimizer/irpass/specialize_transform.h" #include "optimizer/irpass/symbol_resolver.h" #include "optimizer/irpass/tile_eliminate.h" #include "optimizer/irpass/transpose_eliminate.h" #include "optimizer/opt.h" #include "optimizer/irpass/indexed_slices_eliminate.h" namespace mindspore { namespace opt { namespace irpass { OptimizeIRPassLib::OptimizeIRPassLib() { arithmetic_simplify_ = MakeSubstitution(std::make_shared(), "arithmetic_simplify", {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow}); arithmetic_simplify2_ = MakeSubstitution(std::make_shared(), "arithmetic_simplify2", {prim::kPrimMul}); special_op_eliminate_ = MakeSubstitution(std::make_shared(), "special_op_eliminate", {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward, prim::kPrimPrintShapeType, prim::kPrimGetRefValue, prim::kPrimMirror, prim::kPrimVirtualDiv}); zero_like_fill_zero_ = MakeSubstitution(std::make_shared(), "zero_like_fill_zero", prim::kPrimZerosLike); adjust_all_reduce_mul_add_ = MakeSubstitution(std::make_shared(), "adjust_all_reduce_mul_add", prim::kPrimAddN); // ops eliminate item_tuple_eliminate_ = MakeSubstitution(std::make_shared(), "item_tuple_eliminate", {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem}); tile_eliminate_ = MakeSubstitution(std::make_shared(), "tile_eliminate", prim::kPrimTile); cast_eliminate_ = MakeSubstitution(std::make_shared(), "cast_eliminate", prim::kPrimCast); reshape_eliminate_ = MakeSubstitution(std::make_shared(), "reshape_eliminate", prim::kPrimReshape); transpose_eliminate_ = MakeSubstitution(std::make_shared(), "transpose_eliminate", prim::kPrimTranspose); reduce_eliminate_ = MakeSubstitution( std::make_shared(), "reduce_eliminate", {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin}); partial_eliminate_ = MakeSubstitution(std::make_shared(), "partial_eliminate", IsCNodeDup); same_eliminate_ = MakeSubstitution(std::make_shared(), "same_eliminate", prim::kPrimSameTypeShape); check_bprop_eliminate_ = MakeSubstitution(std::make_shared(), "check_bprop_eliminate", prim::kPrimCheckBprop); reset_defer_inline_ = MakeSubstitution(std::make_shared(), "reset_defer_inline", IsValueNode); depend_value_elim_ = MakeSubstitution(std::make_shared(), "depend_value_elim", prim::kPrimDepend); // Env Item Eliminate env_get_item_eliminate_ = MakeSubstitution(std::make_shared(), "env_get_item_eliminate", prim::kPrimEnvGetItem); new_env_get_item_ = MakeSubstitution(std::make_shared(), "new_env_get_item", prim::kPrimEnvGetItem); incorporate_env_getitem_ = MakeSubstitution(std::make_shared(), "incorporate_env_get_item", prim::kPrimEnvGetItem); incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared(), "incorporate_env_getitem_switch", prim::kPrimEnvGetItem); // Ref eliminate make_ref_eliminate_ = MakeSubstitution(std::make_shared(), "make_ref_eliminate", prim::kPrimMakeRef); get_ref_param_eliminate_ = MakeSubstitution(std::make_shared(), "get_ref_param_eliminate", {prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); get_make_ref_eliminate_ = MakeSubstitution(std::make_shared(), "get_make_ref_eliminate", {prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); replace_refkey_by_param_ = MakeSubstitution(std::make_shared(), "replace_refkey_by_param", IsValueNode, opt::FORCE_RENORM); replace_old_param_ = MakeSubstitution(std::make_shared(), "replace_old_param", IsParam); // Gradient transforms expand_jprim_ = MakeSubstitution(std::make_shared(), "expand_jprim", prim::kPrimJ); minmaximum_grad_ = MakeSubstitution(std::make_shared(), "minmaximum_grad", prim::kPrimTupleGetItem); // branch culling switch_simplify_ = MakeSubstitution(std::make_shared(), "switch_simplify", prim::kPrimSwitch); float_tuple_getitem_switch_ = MakeSubstitution(std::make_shared(), "float_tuple_getitem_switch", prim::kPrimTupleGetItem); float_env_getitem_switch_ = MakeSubstitution(std::make_shared(), "float_env_getitem_switch", prim::kPrimEnvGetItem); convert_switch_replacement_ = MakeSubstitution(std::make_shared(), "convert_switch_replacement", IsCNodeDup); // Addn merge_addn_ = MakeSubstitution(std::make_shared(), "merge_addn", prim::kPrimAddN); addn_zero_filter_ = MakeSubstitution(std::make_shared(), "addn_zero_filter", prim::kPrimAddN); // inline inline_ = MakeSubstitution(std::make_shared(), "inline", IsCNodeGraph); replace_applicator_ = MakeSubstitution(std::make_shared(), "replace_applicator", IsValueNode); specialize_transform_ = MakeSubstitution(std::make_shared(), "specialize_transform", IsCNodeGraph); // Incorporation incorporate_getitem_set_ = MakeSubstitution(std::make_shared(), "incorporate_getitem_set", prim::kPrimTupleGetItem); incorporate_getitem_from_param_ = MakeSubstitution(std::make_shared(), "incorporate_getitem_from_param", IsCNodeGraphKernel); incorporate_call_ = MakeSubstitution(std::make_shared(), "incorporate_call", IsCNodeDup); incorporate_call_switch_ = MakeSubstitution(std::make_shared(), "incorporate_call_switch", IsCNodeDup); // Virtual Dataset virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared(), "virtual_dataset_eliminate", prim::kPrimVirtualDataset); // Convert print_tuple_wrapper_ = MakeSubstitution(std::make_shared(), "print_tuple_wrapper", prim::kPrimPrint); // Unused parameter eliminate unused_parameter_eliminate_ = MakeSubstitution(std::make_shared(), "unused_parameter_eliminate", IsCNodeGraphKernel); unused_output_eliminate_ = MakeSubstitution(std::make_shared(), "unused_output_eliminate", IsCNodeGraphKernel); // AddN eliminate addn_eliminate_ = MakeSubstitution(std::make_shared(), "addn_eliminate", IsCNodeGraphKernel); // Mark interface fusion mark_interface_fusion_ = MakeSubstitution(std::make_shared(), "mark_interface_fusion", prim::kPrimSelect); // IndexedSlices Eliminate indexed_slices_eliminate_ = MakeSubstitution( std::make_shared(), "indexed_slices_eliminate", {prim::kPrimIndexedSlicesGetIndices, prim::kPrimIndexedSlicesGetValues, prim::kPrimIndexedSlicesGetDenseShape}); } ResolveIRPassLib::ResolveIRPassLib() { resolver_resolve_ = MakeSubstitution(std::make_shared(), "resolver_resolve", prim::kPrimResolve); resolver_getattr_ = MakeSubstitution(std::make_shared(), "resolver_getattr", prim::kPrimGetAttr); } InferenceOptPrepareLib::InferenceOptPrepareLib() { grad_var_prepare_ = MakeSubstitution(std::make_shared(), "grad_var_prepare", IsCNode); } } // namespace irpass } // namespace opt } // namespace mindspore