You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

irpass.cc 7.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "optimizer/irpass.h"
  17. #include <string>
  18. #include "optimizer/irpass/symbol_resolver.h"
  19. #include "optimizer/irpass/arithmetic_simplify.h"
  20. #include "optimizer/irpass/special_op_eliminate.h"
  21. #include "optimizer/irpass/item_tuple_eliminate.h"
  22. #include "optimizer/irpass/env_item_eliminate.h"
  23. #include "optimizer/irpass/tile_eliminate.h"
  24. #include "optimizer/irpass/cast_eliminate.h"
  25. #include "optimizer/irpass/reshape_eliminate.h"
  26. #include "optimizer/irpass/transpose_eliminate.h"
  27. #include "optimizer/irpass/reduce_eliminate.h"
  28. #include "optimizer/irpass/partial_eliminate.h"
  29. #include "optimizer/irpass/ref_eliminate.h"
  30. #include "optimizer/irpass/merge_addn.h"
  31. #include "optimizer/irpass/branch_culling.h"
  32. #include "optimizer/irpass/gradient_eliminate.h"
  33. #include "optimizer/irpass/minmax_grad.h"
  34. #include "optimizer/irpass/inline.h"
  35. #include "optimizer/irpass/convert.h"
  36. #include "optimizer/irpass/specialize_transform.h"
  37. #include "optimizer/irpass/incorporate_getitem.h"
  38. #include "optimizer/irpass/incorporate_call.h"
  39. #include "optimizer/irpass/grad_var_prepare.h"
  40. #include "optimizer/irpass/param_replace.h"
  41. namespace mindspore {
  42. namespace opt {
  43. namespace irpass {
  44. OptimizeIRPassLib::OptimizeIRPassLib() {
  45. arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify",
  46. {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
  47. prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul});
  48. special_op_eliminate_ =
  49. MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
  50. {prim::kPrimInsertGradientOf, prim::kPrimHookBackward, prim::kPrimPrintShapeType,
  51. prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
  52. zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLike);
  53. adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
  54. // ops eliminate
  55. item_tuple_eliminate_ =
  56. MakeSubstitution(ItemTupleEliminater(), "item_tuple_eliminate", {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem});
  57. tile_eliminate_ = MakeSubstitution(TileMultiplyByOne(), "tile_eliminate", prim::kPrimTile);
  58. cast_eliminate_ = MakeSubstitution(CastEliminater(), "cast_eliminate", prim::kPrimCast);
  59. reshape_eliminate_ = MakeSubstitution(ReshapeEliminater(), "reshape_eliminate", prim::kPrimReshape);
  60. transpose_eliminate_ = MakeSubstitution(TransposeSameIOEliminater(), "transpose_eliminate", prim::kPrimTranspose);
  61. reduce_eliminate_ = MakeSubstitution(
  62. ReduceOneEliminater(), "reduce_eliminate",
  63. {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin});
  64. partial_eliminate_ = MakeSubstitution(PartialEliminater(), "partial_eliminate", IsCNodeDup);
  65. same_eliminate_ = MakeSubstitution(SameEliminater(), "same_eliminate", prim::kPrimSameTypeShape);
  66. check_bprop_eliminate_ = MakeSubstitution(CheckBpropEliminater(), "check_bprop_eliminate", prim::kPrimCheckBprop);
  67. reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode<FuncGraph>);
  68. // Env Item Eliminate
  69. new_env_get_item_ = MakeSubstitution(NewEnvGetItem(), "new_env_get_item", prim::kPrimEnvGetItem);
  70. add_env_get_item_ = MakeSubstitution(AddEnvGetItem(), "add_env_get_item", prim::kPrimEnvGetItem);
  71. env_get_set_item_ = MakeSubstitution(EnvGetSetItem(), "env_get_set_item", prim::kPrimEnvGetItem);
  72. incorporate_env_getitem_ =
  73. MakeSubstitution(IncorporateEnvGetitem(), "incorporate_env_get_item", prim::kPrimEnvGetItem);
  74. incorporate_env_getitem_switch_ =
  75. MakeSubstitution(IncorporateEnvGetitemSwitch(), "incorporate_env_getitem_switch", prim::kPrimEnvGetItem);
  76. // Ref eliminate
  77. make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef);
  78. get_make_ref_eliminate_ = MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate",
  79. {prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
  80. replace_refkey_by_param_ =
  81. MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM);
  82. replace_old_param_ = MakeSubstitution(ReplaceOldParam(), "replace_old_param", IsParam);
  83. // Gradient transforms
  84. expand_jprim_ = MakeSubstitution(ExpandJPrim(), "expand_jprim", prim::kPrimJ);
  85. stop_gradient_eliminate_ =
  86. MakeSubstitution(StopGradientEliminater(), "stop_gradient_eliminate", prim::kPrimStopGradient);
  87. minmaximum_grad_ = MakeSubstitution(MinMaximumGrad(), "minmaximum_grad", prim::kPrimTupleGetItem);
  88. // branch culling
  89. switch_simplify_ = MakeSubstitution(SwitchSimplify(), "switch_simplify", prim::kPrimSwitch);
  90. float_tuple_getitem_switch_ =
  91. MakeSubstitution(FloatTupleGetItemSwitch(), "float_tuple_getitem_switch", prim::kPrimTupleGetItem);
  92. float_env_getitem_switch_ =
  93. MakeSubstitution(FloatEnvGetItemSwitch(), "float_env_getitem_switch", prim::kPrimEnvGetItem);
  94. convert_switch_replacement_ = MakeSubstitution(ConvertSwitchReplacement(), "convert_switch_replacement", IsCNodeDup);
  95. // Addn
  96. merge_addn_ = MakeSubstitution(MergeAddN(), "merge_addn", prim::kPrimAddN);
  97. addn_zero_filter_ = MakeSubstitution(AddNZeroFilter(), "addn_zero_filter", prim::kPrimAddN);
  98. // inline
  99. inline_ = MakeSubstitution(Inliner(), "inline", IsCNodeGraph);
  100. replace_applicator_ = MakeSubstitution(ReplaceApplicator(), "replace_applicator", IsValueNode<FuncGraph>);
  101. specialize_transform_ = MakeSubstitution(SpecializeOnGraphArguments(), "specialize_transform", IsCNodeGraph);
  102. // Incorporation
  103. incorporate_getitem_ = MakeSubstitution(IncorporateGetitem(), "incorporate_getitem", prim::kPrimTupleGetItem);
  104. incorporate_getitem_switch_ =
  105. MakeSubstitution(IncorporateGetitemSwitch(), "incorporate_getitem_switch", prim::kPrimTupleGetItem);
  106. incorporate_call_ = MakeSubstitution(IncorporateCall(), "incorporate_call", IsCNodeDup);
  107. incorporate_call_switch_ = MakeSubstitution(IncorporateCallSwitch(), "incorporate_call_switch", IsCNodeDup);
  108. // Virtual Dataset
  109. virtual_dataset_eliminate_ =
  110. MakeSubstitution(VirtualDatasetEliminater(), "virtual_dataset_eliminate", prim::kPrimVirtualDataset);
  111. // Convert
  112. print_tuple_wrapper_ = MakeSubstitution(PrintTupleWrapper(), "print_tuple_wrapper", prim::kPrimPrint);
  113. }
  114. ResolveIRPassLib::ResolveIRPassLib() {
  115. resolver_resolve_ = MakeSubstitution(ResolverResolve(), "resolver_resolve", prim::kPrimResolve);
  116. resolver_getattr_ = MakeSubstitution(ResolverGetattr(), "resolver_getattr", prim::kPrimGetAttr);
  117. }
  118. InferenceOptPrepareLib::InferenceOptPrepareLib() {
  119. grad_var_prepare_ = MakeSubstitution(GradVarPrepare(), "grad_var_prepare", IsCNode);
  120. }
  121. } // namespace irpass
  122. } // namespace opt
  123. } // namespace mindspore