You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

irpass.cc 17 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. /**
  2. * Copyright 2020-2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "frontend/optimizer/irpass.h"
  17. #include "frontend/optimizer/irpass/arithmetic_simplify.h"
  18. #include "frontend/optimizer/irpass/branch_culling.h"
  19. #include "frontend/optimizer/irpass/cast_eliminate.h"
  20. #include "frontend/optimizer/irpass/convert.h"
  21. #include "frontend/optimizer/irpass/environ_eliminate.h"
  22. #include "frontend/optimizer/irpass/grad_var_prepare.h"
  23. #include "frontend/optimizer/irpass/gradient_eliminate.h"
  24. #include "frontend/optimizer/irpass/inline.h"
  25. #include "frontend/optimizer/irpass/updatestate_eliminate.h"
  26. #include "frontend/optimizer/irpass/load_eliminate.h"
  27. #include "frontend/optimizer/irpass/stopgrad_eliminate.h"
  28. #include "frontend/optimizer/irpass/incorporate_call.h"
  29. #include "frontend/optimizer/irpass/incorporate_getitem.h"
  30. #include "frontend/optimizer/irpass/item_tuple_or_list_eliminate.h"
  31. #include "frontend/optimizer/irpass/merge_addn.h"
  32. #include "frontend/optimizer/irpass/accumulaten_eliminate.h"
  33. #include "frontend/optimizer/irpass/less_batch_normalization.h"
  34. #include "frontend/optimizer/irpass/minmax_grad.h"
  35. #include "frontend/optimizer/irpass/param_replace.h"
  36. #include "frontend/optimizer/irpass/partial_eliminate.h"
  37. #include "frontend/optimizer/irpass/reduce_eliminate.h"
  38. #include "frontend/optimizer/irpass/ref_eliminate.h"
  39. #include "frontend/optimizer/irpass/reshape_eliminate.h"
  40. #include "frontend/optimizer/irpass/special_op_eliminate.h"
  41. #include "frontend/optimizer/irpass/specialize_transform.h"
  42. #include "frontend/optimizer/irpass/symbol_resolver.h"
  43. #include "frontend/optimizer/irpass/tile_eliminate.h"
  44. #include "frontend/optimizer/irpass/transpose_eliminate.h"
  45. #include "frontend/optimizer/irpass/value_based_eliminate.h"
  46. #include "frontend/optimizer/opt.h"
  47. #include "frontend/optimizer/irpass/row_tensor_eliminate.h"
  48. #include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"
  49. #include "frontend/optimizer/irpass/switch_or_switch_layer_defer_inline.h"
  50. #include "frontend/optimizer/irpass/call_graph_tuple_transform.h"
  51. #include "frontend/optimizer/irpass/recompute_prepare.h"
  52. #include "frontend/optimizer/irpass/real_op_eliminate.h"
  53. namespace mindspore {
  54. namespace opt {
  55. namespace irpass {
  56. OptimizeIRPassLib::OptimizeIRPassLib() {
  57. arithmetic_simplify_ = MakeSubstitution(std::make_shared<ArithmeticSimplify>(), "arithmetic_simplify",
  58. {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimAdd,
  59. prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow});
  60. arithmetic_simplify2_ =
  61. MakeSubstitution(std::make_shared<ArithmeticSimplify2>(), "arithmetic_simplify2", {prim::kPrimMul});
  62. special_op_eliminate_ =
  63. MakeSubstitution(std::make_shared<SpecialOpEliminater>(), "special_op_eliminate",
  64. {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward,
  65. prim::kPrimPrintShapeType, prim::kPrimGetRefValue, prim::kPrimMirror, prim::kPrimVirtualDiv});
  66. pynative_eliminate_ = MakeSubstitution(std::make_shared<PynativeEliminater>(), "pynative_eliminate", IsCNodeDup);
  67. zero_like_fill_zero_ =
  68. MakeSubstitution(std::make_shared<ZeroLikeFillZero>(), "zero_like_fill_zero", prim::kPrimZerosLike);
  69. adjust_all_reduce_mul_add_ =
  70. MakeSubstitution(std::make_shared<AdjustAllReduceMulAdd>(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
  71. float_depend_g_call_ = MakeSubstitution(std::make_shared<FloatDependGCall>(), "float_depend_g_call", IsCNodeDup);
  72. // ops eliminate
  73. tuple_list_get_item_eliminator_ =
  74. MakeSubstitution(std::make_shared<TupleListGetitemEliminator>(), "tuple_list_get_item_eliminator",
  75. {prim::kPrimTupleGetItem, prim::kPrimListGetItem});
  76. tuple_list_get_item_const_eliminator_ =
  77. MakeSubstitution(std::make_shared<TupleListGetitemConstEliminator>(), "tuple_list_get_item_const_eliminator",
  78. {prim::kPrimTupleGetItem, prim::kPrimListGetItem});
  79. tuple_list_set_item_eliminator_ =
  80. MakeSubstitution(std::make_shared<TupleListSetitemEliminator>(), "tuple_list_set_item_eliminator",
  81. {prim::kPrimTupleSetItem, prim::kPrimListSetItem});
  82. tuple_list_get_set_item_eliminator_ =
  83. MakeSubstitution(std::make_shared<TupleListGetSetitemEliminator>(), "tuple_list_get_set_item_eliminator",
  84. {prim::kPrimTupleGetItem, prim::kPrimListGetItem});
  85. tuple_list_get_item_depend_reorder_ =
  86. MakeSubstitution(std::make_shared<TupleListGetitemDependReorder>(), "tuple_list_get_item_depend_reorder",
  87. {prim::kPrimTupleGetItem, prim::kPrimListGetItem});
  88. tuple_list_convert_item_index_to_positive_ = MakeSubstitution(
  89. std::make_shared<TupleListConvertItemIndexToPositive>(), "tuple_list_convert_item_index_to_positive",
  90. {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem, prim::kPrimListGetItem, prim::kPrimListSetItem});
  91. make_slice_get_slice_eliminator_ = MakeSubstitution(std::make_shared<MakeSliceSliceGetItemEliminator>(),
  92. "make_slice_get_slice_eliminator", {prim::kPrimSliceGetItem});
  93. tile_eliminate_ = MakeSubstitution(std::make_shared<TileEliminater>(), "tile_eliminate", prim::kPrimTile);
  94. cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast);
  95. reshape_eliminate_ = MakeSubstitution(std::make_shared<ReshapeEliminater>(), "reshape_eliminate", prim::kPrimReshape);
  96. transpose_eliminate_ =
  97. MakeSubstitution(std::make_shared<TransposeSameIOEliminater>(), "transpose_eliminate", prim::kPrimTranspose);
  98. reduce_eliminate_ = MakeSubstitution(
  99. std::make_shared<ReduceOneEliminater>(), "reduce_eliminate",
  100. {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin});
  101. partial_eliminate_ = MakeSubstitution(std::make_shared<PartialEliminater>(), "partial_eliminate", IsCNodeDup);
  102. same_eliminate_ = MakeSubstitution(std::make_shared<SameEliminater>(), "same_eliminate", prim::kPrimSameTypeShape);
  103. mirror_mini_step_elim_ = MakeSubstitution(std::make_shared<MirrorMiniStepEliminater>(), "mirror_mini_step_eliminate",
  104. prim::kPrimMirrorMiniStep);
  105. mini_step_allgather_replace_ = MakeSubstitution(std::make_shared<MiniStepAllGatherPass>(),
  106. "mini_step_allgather_replace", prim::kPrimMiniStepAllGather);
  107. micro_step_allgather_replace_ = MakeSubstitution(std::make_shared<MicroStepAllGatherPass>(),
  108. "micro_step_allgather_replace", prim::kPrimMicroStepAllGather);
  109. virtual_add_elim_ = MakeSubstitution(std::make_shared<VirtualAddEliminater>(), "virtual_add", prim::kPrimVirtualAdd);
  110. check_bprop_eliminate_ =
  111. MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop);
  112. reset_defer_inline_ =
  113. MakeSubstitution(std::make_shared<ResetDeferInline>(), "reset_defer_inline", IsValueNode<FuncGraph>);
  114. depend_value_elim_ = MakeSubstitution(std::make_shared<DependValueElim>(), "depend_value_elim", prim::kPrimDepend);
  115. all_reduce_const_elim_ =
  116. MakeSubstitution(std::make_shared<AllReduceConstElim>(), "reduce_all_const_elim", prim::kPrimAllReduce);
  117. real_op_eliminate_ = MakeSubstitution(std::make_shared<RealOpEliminate>(), "real_op_eliminate", prim::kPrimRealInner);
  118. // Environ Item Eliminate
  119. environ_get_eliminate_ =
  120. MakeSubstitution(std::make_shared<EnvironGetEliminater>(), "environ_get_eliminate", prim::kPrimEnvironGet);
  121. environ_get_add_eliminate_ =
  122. MakeSubstitution(std::make_shared<EnvironGetAddEliminater>(), "environ_get_add_eliminate", prim::kPrimEnvironGet);
  123. environ_get_set_eliminate_ =
  124. MakeSubstitution(std::make_shared<EnvironGetSetEliminater>(), "environ_get_set_eliminate", prim::kPrimEnvironGet);
  125. environ_get_depend_swap_ =
  126. MakeSubstitution(std::make_shared<EnvironGetDependSwap>(), "environ_get_depend_swap", prim::kPrimEnvironGet);
  127. environ_add_const_eliminate_ = MakeSubstitution(std::make_shared<EnvironAddConstEliminater>(),
  128. "environ_add_const_eliminate_", prim::kPrimEnvironAdd);
  129. incorporate_environ_get_bypass_recursive_ =
  130. MakeSubstitution(std::make_shared<IncorporateEnvironGet>(true), "incorporate_environ_get", prim::kPrimEnvironGet);
  131. incorporate_environ_get_switch_ = MakeSubstitution(std::make_shared<IncorporateEnvironGetSwitch>(),
  132. "incorporate_environ_get_switch", prim::kPrimEnvironGet);
  133. incorporate_environ_get_ =
  134. MakeSubstitution(std::make_shared<IncorporateEnvironGet>(), "incorporate_environ_get", prim::kPrimEnvironGet);
  135. incorporate_environ_get_switch_layer_ =
  136. MakeSubstitution(std::make_shared<IncorporateEnvironGetSwitchLayer>(), "incorporate_environ_get_switch_layer",
  137. prim::kPrimEnvironGet);
  138. split_environ_get_set_with_tuple_value_ =
  139. MakeSubstitution(std::make_shared<SplitEnvironGetSetWithTupleValue>(), "split_environ_get_set_with_tuple_value",
  140. {prim::kPrimEnvironGet, prim::kPrimEnvironSet});
  141. // Ref eliminate
  142. make_ref_eliminate_ =
  143. MakeSubstitution(std::make_shared<MakeRefEliminater>(), "make_ref_eliminate", prim::kPrimMakeRef);
  144. get_ref_param_eliminate_ =
  145. MakeSubstitution(std::make_shared<GetRefParamEliminater>(), "get_ref_param_eliminate", {prim::kPrimGetRefValue});
  146. get_make_ref_eliminate_ = MakeSubstitution(std::make_shared<GetMakeRefEliminater>(), "get_make_ref_eliminate",
  147. {prim::kPrimGetRefKey, prim::kPrimGetRefValue});
  148. replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param",
  149. IsValueNode<RefKey>, opt::FORCE_RENORM);
  150. replace_old_param_ = MakeSubstitution(std::make_shared<ReplaceOldParam>(), "replace_old_param", IsParam);
  151. minmaximum_grad_ = MakeSubstitution(std::make_shared<MinMaximumGrad>(), "minmaximum_grad", prim::kPrimTupleGetItem);
  152. // branch culling
  153. switch_simplify_ = MakeSubstitution(std::make_shared<SwitchSimplify>(), "switch_simplify", prim::kPrimSwitch);
  154. float_tuple_getitem_switch_ = MakeSubstitution(std::make_shared<FloatTupleGetItemSwitch>(),
  155. "float_tuple_getitem_switch", prim::kPrimTupleGetItem);
  156. float_environ_get_switch_ =
  157. MakeSubstitution(std::make_shared<FloatEnvironGetSwitch>(), "float_environ_get_switch", prim::kPrimEnvironGet);
  158. exchange_switch_depend_value_ =
  159. MakeSubstitution(std::make_shared<ExchangeSwitchDependValue>(), "exchange_switch_depend_value", prim::kPrimSwitch);
  160. switch_partial_eliminater_ =
  161. MakeSubstitution(std::make_shared<SwitchPartialEliminater>(), "eliminate_switch_partial_", IsCNodeDup);
  162. switch_layer_partial_eliminater_ =
  163. MakeSubstitution(std::make_shared<SwitchLayerPartialEliminater>(), "eliminate_switch_layer_partial_", IsCNodeDup);
  164. // Addn
  165. merge_addn_ = MakeSubstitution(std::make_shared<MergeAddN>(), "merge_addn", prim::kPrimAddN);
  166. addn_zero_filter_ = MakeSubstitution(std::make_shared<AddNZeroFilter>(), "addn_zero_filter", prim::kPrimAddN);
  167. // AccumulateNV2
  168. accumulaten_eliminater_ =
  169. MakeSubstitution(std::make_shared<AccumulateNV2Eliminater>(), "accumulaten_eliminater", prim::kPrimAccumulateNV2);
  170. // Accelerated Algorithm
  171. less_batch_normalization_ =
  172. MakeSubstitution(std::make_shared<LessBatchNormalization>(), "less_batch_normalization",
  173. {prim::kPrimAdd, prim::kPrimRelu6, prim::kPrimMatMul, prim::kPrimMakeTuple, prim::kPrimMaxPool});
  174. // inline
  175. inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph);
  176. inline_without_move_ = MakeSubstitution(std::make_shared<DirectInliner>(false), "inline", IsCNodeGraph);
  177. replace_applicator_ =
  178. MakeSubstitution(std::make_shared<ReplaceApplicator>(), "replace_applicator", IsValueNode<FuncGraph>);
  179. specialize_transform_ =
  180. MakeSubstitution(std::make_shared<SpecializeOnGraphArguments>(), "specialize_transform", IsCNodeGraph);
  181. // UpdateState eliminate
  182. updatestate_useless_node_eliminater_ =
  183. MakeSubstitution(std::make_shared<UpdatestateUselessNodeEliminater>(), "updatestate_useless_node_eliminater",
  184. prim::kPrimUpdateState);
  185. updatestate_pure_node_eliminater_ = MakeSubstitution(std::make_shared<UpdatestatePureNodeEliminater>(),
  186. "updatestate_pure_node_eliminater", prim::kPrimUpdateState);
  187. switch_call_monad_eliminater_ = MakeSubstitution(std::make_shared<SwitchCallMonadParameterEliminater>(),
  188. "switch_call_monad_eliminater", IsCNodeDup);
  189. // Load eliminate
  190. load_eliminater_ = MakeSubstitution(std::make_shared<LoadEliminater>(), "load_eliminater", prim::kPrimLoad);
  191. // StopGradient eliminate
  192. stopgrad_eliminater_ =
  193. MakeSubstitution(std::make_shared<StopGradientEliminater>(), "stopgrad_eliminater", prim::kPrimStopGradient);
  194. // Incorporation
  195. incorporate_getitem_set_ =
  196. MakeSubstitution(std::make_shared<IncorporateGetitemSet>(), "incorporate_getitem_set", prim::kPrimTupleGetItem);
  197. incorporate_call_ = MakeSubstitution(std::make_shared<IncorporateCall>(), "incorporate_call", IsCNodeDup);
  198. incorporate_call_switch_ =
  199. MakeSubstitution(std::make_shared<IncorporateCallSwitch>(), "incorporate_call_switch", IsCNodeDup);
  200. // Virtual Dataset
  201. virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared<VirtualDatasetEliminater>(),
  202. "virtual_dataset_eliminate", prim::kPrimVirtualDataset);
  203. // Virtual Dataset
  204. virtual_output_eliminate_ =
  205. MakeSubstitution(std::make_shared<VirtualOutputEliminater>(), "virtual_output_eliminate", prim::kPrimVirtualOutput);
  206. // PipelineSplit
  207. receive_eliminate_ = MakeSubstitution(std::make_shared<ReceiveEliminater>(), "receive_eliminate", prim::kPrimReceive);
  208. virtual_accu_grad_ =
  209. MakeSubstitution(std::make_shared<VirtualAccuGradEliminater>(), "virtual_accu_grad", prim::kPrimVirtualAccuGrad);
  210. virtual_assign_add_ =
  211. MakeSubstitution(std::make_shared<VirtualAssignAddEliminater>(), "virtual_assign_add", prim::kPrimVirtualAssignAdd);
  212. mirror_micro_step_ =
  213. MakeSubstitution(std::make_shared<MirrorMicroStepEliminater>(), "mirror_micro_step", prim::kPrimMirrorMicroStep);
  214. // Convert
  215. print_tuple_wrapper_ =
  216. MakeSubstitution(std::make_shared<PrintTupleWrapper>(), "print_tuple_wrapper", prim::kPrimPrint);
  217. // tuple parameter graph transform
  218. call_graph_tuple_transform_ =
  219. MakeSubstitution(std::make_shared<CallGraphTupleTransform>(), "graph_param_transorm", IsCNode);
  220. // RowTensor Eliminate
  221. row_tensor_eliminate_ = MakeSubstitution(
  222. std::make_shared<RowTensorEliminater>(), "row_tensor_eliminate",
  223. {prim::kPrimRowTensorGetIndices, prim::kPrimRowTensorGetValues, prim::kPrimRowTensorGetDenseShape});
  224. // RowTensorAddZerosLike Eliminate
  225. row_tensor_add_zeros_like_ =
  226. MakeSubstitution(std::make_shared<RowTensorAddZerosLike>(), "row_tensor_add_zeros_like", prim::kPrimRowTensorAdd);
  227. // SparseTensor Eliminate
  228. sparse_tensor_eliminate_ = MakeSubstitution(
  229. std::make_shared<SparseTensorEliminater>(), "sparse_tensor_eliminate",
  230. {prim::kPrimSparseTensorGetIndices, prim::kPrimSparseTensorGetValues, prim::kPrimSparseTensorGetDenseShape});
  231. // Value_Based Eliminate
  232. value_based_eliminate_ = MakeSubstitution(std::make_shared<ValueBasedEliminate>(), "value_based_eliminate",
  233. {prim::kPrimSelect, prim::kPrimMinimum, prim::kPrimMaximum});
  234. // switch defer inline
  235. switch_defer_inline_ =
  236. MakeSubstitution(std::make_shared<SwitchDeferInline>(), "switch_defer_inline", prim::kPrimSwitch);
  237. // switch_layer defer inline
  238. switch_layer_defer_inline_ =
  239. MakeSubstitution(std::make_shared<SwitchLayerDeferInline>(), "switch_layer_defer_inline", prim::kPrimSwitchLayer);
  240. // recompute
  241. set_cell_output_no_recompute_ = MakeSubstitution(std::make_shared<SetCellOutputNoRecompute>(),
  242. "set_cell_output_no_recompute", IsValueNode<FuncGraph>);
  243. }
  244. ResolveIRPassLib::ResolveIRPassLib() {
  245. // In resolver_, some patterns have priority over others.
  246. resolver_ = MakeSubstitution(std::make_shared<Resolver>(), "getattr_resolve",
  247. {prim::kPrimGetAttr, prim::kPrimResolve}, opt::CHECK_RENORM, true);
  248. }
  249. InferenceOptPrepareLib::InferenceOptPrepareLib() {
  250. grad_var_prepare_ = MakeSubstitution(std::make_shared<GradVarPrepare>(), "grad_var_prepare", IsCNode);
  251. }
  252. } // namespace irpass
  253. } // namespace opt
  254. } // namespace mindspore