You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

irpass.cc 9.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <string>
  17. #include "optimizer/irpass.h"
  18. #include "optimizer/irpass/arithmetic_simplify.h"
  19. #include "optimizer/irpass/branch_culling.h"
  20. #include "optimizer/irpass/cast_eliminate.h"
  21. #include "optimizer/irpass/convert.h"
  22. #include "optimizer/irpass/env_item_eliminate.h"
  23. #include "optimizer/irpass/grad_var_prepare.h"
  24. #include "optimizer/irpass/gradient_eliminate.h"
  25. #include "optimizer/irpass/inline.h"
  26. #include "optimizer/irpass/incorporate_call.h"
  27. #include "optimizer/irpass/incorporate_getitem.h"
  28. #include "optimizer/irpass/item_tuple_eliminate.h"
  29. #include "optimizer/irpass/mark_interface_fusion.h"
  30. #include "optimizer/irpass/merge_addn.h"
  31. #include "optimizer/irpass/minmax_grad.h"
  32. #include "optimizer/irpass/param_replace.h"
  33. #include "optimizer/irpass/partial_eliminate.h"
  34. #include "optimizer/irpass/reduce_eliminate.h"
  35. #include "optimizer/irpass/ref_eliminate.h"
  36. #include "optimizer/irpass/reshape_eliminate.h"
  37. #include "optimizer/irpass/special_op_eliminate.h"
  38. #include "optimizer/irpass/specialize_transform.h"
  39. #include "optimizer/irpass/symbol_resolver.h"
  40. #include "optimizer/irpass/tile_eliminate.h"
  41. #include "optimizer/irpass/transpose_eliminate.h"
  42. #include "optimizer/opt.h"
  43. #include "optimizer/irpass/indexed_slices_eliminate.h"
  44. namespace mindspore {
  45. namespace opt {
  46. namespace irpass {
  47. OptimizeIRPassLib::OptimizeIRPassLib() {
  48. arithmetic_simplify_ = MakeSubstitution(std::make_shared<ArithmeticSimplify>(), "arithmetic_simplify",
  49. {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
  50. prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow});
  51. arithmetic_simplify2_ =
  52. MakeSubstitution(std::make_shared<ArithmeticSimplify2>(), "arithmetic_simplify2", {prim::kPrimMul});
  53. special_op_eliminate_ =
  54. MakeSubstitution(std::make_shared<SpecialOpEliminater>(), "special_op_eliminate",
  55. {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward,
  56. prim::kPrimPrintShapeType, prim::kPrimGetRefValue, prim::kPrimMirror, prim::kPrimVirtualDiv});
  57. zero_like_fill_zero_ =
  58. MakeSubstitution(std::make_shared<ZeroLikeFillZero>(), "zero_like_fill_zero", prim::kPrimZerosLike);
  59. adjust_all_reduce_mul_add_ =
  60. MakeSubstitution(std::make_shared<AdjustAllReduceMulAdd>(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
  61. // ops eliminate
  62. item_tuple_eliminate_ = MakeSubstitution(std::make_shared<ItemTupleEliminater>(), "item_tuple_eliminate",
  63. {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem});
  64. tile_eliminate_ = MakeSubstitution(std::make_shared<TileMultiplyByOne>(), "tile_eliminate", prim::kPrimTile);
  65. cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast);
  66. reshape_eliminate_ = MakeSubstitution(std::make_shared<ReshapeEliminater>(), "reshape_eliminate", prim::kPrimReshape);
  67. transpose_eliminate_ =
  68. MakeSubstitution(std::make_shared<TransposeSameIOEliminater>(), "transpose_eliminate", prim::kPrimTranspose);
  69. reduce_eliminate_ = MakeSubstitution(
  70. std::make_shared<ReduceOneEliminater>(), "reduce_eliminate",
  71. {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin});
  72. partial_eliminate_ = MakeSubstitution(std::make_shared<PartialEliminater>(), "partial_eliminate", IsCNodeDup);
  73. same_eliminate_ = MakeSubstitution(std::make_shared<SameEliminater>(), "same_eliminate", prim::kPrimSameTypeShape);
  74. check_bprop_eliminate_ =
  75. MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop);
  76. reset_defer_inline_ =
  77. MakeSubstitution(std::make_shared<ResetDeferInline>(), "reset_defer_inline", IsValueNode<FuncGraph>);
  78. depend_value_elim_ = MakeSubstitution(std::make_shared<DependValueElim>(), "depend_value_elim", prim::kPrimDepend);
  79. // Env Item Eliminate
  80. env_get_item_eliminate_ =
  81. MakeSubstitution(std::make_shared<EnvGetItemEliminater>(), "env_get_item_eliminate", prim::kPrimEnvGetItem);
  82. new_env_get_item_ = MakeSubstitution(std::make_shared<NewEnvGetItem>(), "new_env_get_item", prim::kPrimEnvGetItem);
  83. incorporate_env_getitem_ =
  84. MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(), "incorporate_env_get_item", prim::kPrimEnvGetItem);
  85. incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared<IncorporateEnvGetitemSwitch>(),
  86. "incorporate_env_getitem_switch", prim::kPrimEnvGetItem);
  87. // Ref eliminate
  88. make_ref_eliminate_ =
  89. MakeSubstitution(std::make_shared<MakeRefEliminater>(), "make_ref_eliminate", prim::kPrimMakeRef);
  90. get_ref_param_eliminate_ = MakeSubstitution(std::make_shared<GetRefParamEliminater>(), "get_ref_param_eliminate",
  91. {prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
  92. get_make_ref_eliminate_ = MakeSubstitution(std::make_shared<GetMakeRefEliminater>(), "get_make_ref_eliminate",
  93. {prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
  94. replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param",
  95. IsValueNode<RefKey>, opt::FORCE_RENORM);
  96. replace_old_param_ = MakeSubstitution(std::make_shared<ReplaceOldParam>(), "replace_old_param", IsParam);
  97. // Gradient transforms
  98. expand_jprim_ = MakeSubstitution(std::make_shared<ExpandJPrim>(), "expand_jprim", prim::kPrimJ);
  99. minmaximum_grad_ = MakeSubstitution(std::make_shared<MinMaximumGrad>(), "minmaximum_grad", prim::kPrimTupleGetItem);
  100. // branch culling
  101. switch_simplify_ = MakeSubstitution(std::make_shared<SwitchSimplify>(), "switch_simplify", prim::kPrimSwitch);
  102. float_tuple_getitem_switch_ = MakeSubstitution(std::make_shared<FloatTupleGetItemSwitch>(),
  103. "float_tuple_getitem_switch", prim::kPrimTupleGetItem);
  104. float_env_getitem_switch_ =
  105. MakeSubstitution(std::make_shared<FloatEnvGetItemSwitch>(), "float_env_getitem_switch", prim::kPrimEnvGetItem);
  106. convert_switch_replacement_ =
  107. MakeSubstitution(std::make_shared<ConvertSwitchReplacement>(), "convert_switch_replacement", IsCNodeDup);
  108. // Addn
  109. merge_addn_ = MakeSubstitution(std::make_shared<MergeAddN>(), "merge_addn", prim::kPrimAddN);
  110. addn_zero_filter_ = MakeSubstitution(std::make_shared<AddNZeroFilter>(), "addn_zero_filter", prim::kPrimAddN);
  111. // inline
  112. inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph);
  113. replace_applicator_ =
  114. MakeSubstitution(std::make_shared<ReplaceApplicator>(), "replace_applicator", IsValueNode<FuncGraph>);
  115. specialize_transform_ =
  116. MakeSubstitution(std::make_shared<SpecializeOnGraphArguments>(), "specialize_transform", IsCNodeGraph);
  117. // Incorporation
  118. incorporate_getitem_set_ =
  119. MakeSubstitution(std::make_shared<IncorporateGetitemSet>(), "incorporate_getitem_set", prim::kPrimTupleGetItem);
  120. incorporate_getitem_from_param_ = MakeSubstitution(std::make_shared<IncorporateGetitemFromParam>(),
  121. "incorporate_getitem_from_param", IsCNodeGraphKernel);
  122. incorporate_call_ = MakeSubstitution(std::make_shared<IncorporateCall>(), "incorporate_call", IsCNodeDup);
  123. incorporate_call_switch_ =
  124. MakeSubstitution(std::make_shared<IncorporateCallSwitch>(), "incorporate_call_switch", IsCNodeDup);
  125. // Virtual Dataset
  126. virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared<VirtualDatasetEliminater>(),
  127. "virtual_dataset_eliminate", prim::kPrimVirtualDataset);
  128. // Convert
  129. print_tuple_wrapper_ =
  130. MakeSubstitution(std::make_shared<PrintTupleWrapper>(), "print_tuple_wrapper", prim::kPrimPrint);
  131. // Unused parameter eliminate
  132. unused_parameter_eliminate_ =
  133. MakeSubstitution(std::make_shared<UnusedParasEliminater>(), "unused_parameter_eliminate", IsCNodeGraphKernel);
  134. unused_output_eliminate_ =
  135. MakeSubstitution(std::make_shared<UnusedOutputEliminater>(), "unused_output_eliminate", IsCNodeGraphKernel);
  136. // AddN eliminate
  137. addn_eliminate_ = MakeSubstitution(std::make_shared<AddNEliminater>(), "addn_eliminate", IsCNodeGraphKernel);
  138. // Mark interface fusion
  139. mark_interface_fusion_ =
  140. MakeSubstitution(std::make_shared<MarkInterfaceFusion>(), "mark_interface_fusion", prim::kPrimSelect);
  141. // IndexedSlices Eliminate
  142. indexed_slices_eliminate_ = MakeSubstitution(
  143. std::make_shared<IndexedSlicesEliminater>(), "indexed_slices_eliminate",
  144. {prim::kPrimIndexedSlicesGetIndices, prim::kPrimIndexedSlicesGetValues, prim::kPrimIndexedSlicesGetDenseShape});
  145. }
  146. ResolveIRPassLib::ResolveIRPassLib() {
  147. resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve);
  148. resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetattr>(), "resolver_getattr", prim::kPrimGetAttr);
  149. }
  150. InferenceOptPrepareLib::InferenceOptPrepareLib() {
  151. grad_var_prepare_ = MakeSubstitution(std::make_shared<GradVarPrepare>(), "grad_var_prepare", IsCNode);
  152. }
  153. } // namespace irpass
  154. } // namespace opt
  155. } // namespace mindspore