You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

helper.h 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_
  17. #define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_
  18. #include <vector>
  19. #include <memory>
  20. #include <string>
  21. #include <unordered_set>
  22. #include "ir/func_graph.h"
  23. #include "session/kernel_graph.h"
  24. #include "common/utils.h"
  25. namespace mindspore {
  26. namespace opt {
  27. constexpr size_t kTransOpInputNum = 2;
  28. constexpr size_t kCastInputNum = 2;
  29. constexpr size_t kDependInputNum = 3;
  30. constexpr size_t kReluInputNum = 2;
  31. constexpr size_t kReluGradInputNum = 3;
  32. constexpr size_t kAddInputNum = 3;
  33. constexpr size_t kAddNInputNum = 3;
  34. constexpr size_t kTupleGetitemInputNum = 3;
  35. constexpr size_t kConvInputNum = 3;
  36. constexpr size_t kRealDivInputNum = 3;
  37. constexpr size_t kSqrtInputNum = 2;
  38. constexpr size_t kMulInputNum = 3;
  39. constexpr size_t kRsqrtInputNum = 2;
  40. constexpr size_t kSubInputNum = 3;
  41. constexpr size_t kAssignSubInputNum = 3;
  42. constexpr size_t kConvBn1OutputNum = 3;
  43. constexpr size_t kBn2ReluOutputNum = 4;
  44. constexpr size_t kBnInputNum = 6;
  45. constexpr size_t kBnOutputNum = 5;
  46. constexpr size_t kBN1OutputNum = 2;
  47. constexpr size_t kBN2OutputNum = 3;
  48. constexpr size_t kBN3OutputNum = 1;
  49. constexpr size_t kBNGradInputNum = 6;
  50. constexpr size_t kBNGradOutputNum = 3;
  51. constexpr size_t kBNGrad1OutputNum = 3;
  52. constexpr size_t kBNGrad2OutputNum = 5;
  53. constexpr size_t kBNGrad3OutputNum = 1;
  54. constexpr size_t kBNTrainingReduceOutputNum = 2;
  55. constexpr size_t kBNTrainingUpdateOutputNum = 5;
  56. constexpr size_t kBNTrainingUpdateGradOutputNum = 2;
  57. constexpr size_t kSingleOutputNum = 1;
  58. constexpr size_t kSumNodeInputNum = 2;
  59. constexpr size_t kSquareNodeInputNum = 2;
  60. constexpr size_t kSquareSumv2OutputNum = 2;
  61. constexpr size_t kMinimumInputNum = 3;
  62. constexpr size_t kLambNextMVWithDecayInputNum = 7;
  63. constexpr size_t kLambNextMVWithDecayConstantMulInputNum = 5;
  64. constexpr size_t kLambNextMVWithDecayOutputNum = 4;
  65. constexpr size_t kLambNextMVWithDecayV1OutputNum = 4;
  66. constexpr size_t kLambNextRightOutputNum = 2;
  67. constexpr size_t kLambUpdateWithLrV2InputNum = 8;
  68. constexpr size_t kLambNextMVRuleInputNum = 14;
  69. constexpr size_t kLambNextMVRuleOutputNum = 4;
  70. constexpr size_t kBackendReshapeInputNum = 2;
  71. constexpr size_t kBackendTransposeInputNum = 2;
  72. constexpr size_t kAdamApplyOneWithDecayOutputNum = 3;
  73. constexpr size_t kLayerNormBetaGammaBackpropInputNum = 5;
  74. constexpr size_t kLayerNormBetaGammaBackpropOutputNum = 2;
  75. constexpr size_t kLayerNormGradInputNum = 6;
  76. constexpr size_t kAdamApplyOneOutputNum = 3;
  77. constexpr size_t kBackendTransDataInputNum = 2;
  78. constexpr size_t kApplyMomentumInputNum = 6;
  79. constexpr size_t kBiasAddInputNum = 3;
  80. constexpr size_t kTopkInputNum = 3;
  81. enum FusedBatchNormInput {
  82. kX = 1,
  83. kVariance = 5,
  84. };
  85. enum FusedBatchNormOutput {
  86. kY = 0,
  87. kRunningMean,
  88. kRunningVariance,
  89. kSaveMean,
  90. kSaveInvVariance,
  91. };
  92. enum ConvBn1Output {
  93. kData = 0,
  94. kVarPart,
  95. kMean,
  96. };
  97. std::vector<int> Convert2Int(const std::vector<size_t> &v);
  98. bool UnVisited(const BaseRef &n);
  99. bool Visited(const BaseRef &n);
  100. // check if the input node is CNode, then check it's input_size, if meet condition above, return true, otherwise return
  101. // false. cnode can only be used when return true.
  102. bool CheckIfCNodeAndInputSize(const AnfNodePtr &node, int input_size, CNodePtr *cnode);
  103. // check if the input node is CNode, then check it's input_size, return CNodePtr if check success.
  104. CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, int input_size);
  105. void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_size);
  106. bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y);
  107. const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node);
  108. void CreateOutputsOfConvBn1(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode, const CNodePtr &bn_cnode,
  109. std::vector<AnfNodePtr> *conv_bn1_outputs);
  110. void CreateOutputsOfFusedBn2(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &fused_bn1_outputs,
  111. const CNodePtr &bn_node, std::vector<AnfNodePtr> *fused_bn2_outputs);
  112. void CreateOutputsOfFusedBn3(const FuncGraphPtr &graph, const AnfNodePtr &data_input,
  113. const std::vector<AnfNodePtr> &fused_bn1_outputs,
  114. const std::vector<AnfNodePtr> &fused_bn2_outputs, const CNodePtr &bn_node,
  115. std::vector<AnfNodePtr> *fused_bn3_outputs);
  116. void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &anf_node_ptr, size_t output_num,
  117. std::vector<AnfNodePtr> *outputs);
  118. tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_ptr, const TypePtr &type_ptr,
  119. size_t data_length);
  120. tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple);
  121. bool IsNopNode(const AnfNodePtr &node);
  122. void HideNopNode(session::KernelGraph *const graph);
  123. void RemoveNopNode(session::KernelGraph *const graph);
  124. AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx);
  125. bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node);
  126. void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs);
  127. } // namespace opt
  128. } // namespace mindspore
  129. #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_