You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

helper.h 10 kB

5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. /**
  2. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_HELPER_H_
  17. #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_HELPER_H_
  18. #include <vector>
  19. #include <memory>
  20. #include <utility>
  21. #include <string>
  22. #include <set>
  23. #include "utils/hash_set.h"
  24. #include "ir/func_graph.h"
  25. #include "backend/session/kernel_graph.h"
  26. #include "utils/ms_utils.h"
  27. #include "backend/optimizer/common/pattern_engine.h"
  28. namespace mindspore {
  29. namespace opt {
  30. constexpr size_t kTransOpInputTensorNum = 1;
  31. constexpr size_t kCastInputTensorNum = 1;
  32. constexpr size_t kDependInputTensorNum = 2;
  33. constexpr size_t kReluInputTensorNum = 1;
  34. constexpr size_t kReluGradInputTensorNum = 2;
  35. constexpr size_t kAddInputTensorNum = 2;
  36. constexpr size_t kTupleGetItemInputTensorNum = 2;
  37. constexpr size_t kConvInputTensorNum = 2;
  38. constexpr size_t kRealDivInputTensorNum = 2;
  39. constexpr size_t kSqrtInputTensorNum = 1;
  40. constexpr size_t kMatMulInputTensorNum = 2;
  41. constexpr size_t kMulInputTensorNum = 2;
  42. constexpr size_t kSubInputTensorNum = 2;
  43. constexpr size_t kAssignSubInputTensorNum = 2;
  44. constexpr size_t kDropoutInputTensorNum = 1;
  45. constexpr size_t kAssignInputTensorNum = 2;
  46. constexpr size_t kGradIndex = 3;
  47. constexpr size_t kAddNInputNum = 2;
  48. constexpr size_t kConvBn1OutputNum = 3;
  49. constexpr size_t kBn2ReluOutputNum = 4;
  50. constexpr size_t kBnInputTensorNum = 5;
  51. constexpr size_t kBnOutputNum = 5;
  52. constexpr size_t kBN1OutputNum = 2;
  53. constexpr size_t kBN2OutputNum = 3;
  54. constexpr size_t kBN3OutputNum = 1;
  55. constexpr size_t kBNGradInputTensorNum = 5;
  56. constexpr size_t kBNGradOutputNum = 3;
  57. constexpr size_t kBNGrad1OutputNum = 3;
  58. constexpr size_t kBNGrad2OutputNum = 5;
  59. constexpr size_t kBNGrad3OutputNum = 1;
  60. constexpr size_t kBNTrainingReduceOutputNum = 2;
  61. constexpr size_t kBNTrainingUpdateOutputNum = 5;
  62. constexpr size_t kBNTrainingUpdateV2OutputNum = 3;
  63. constexpr size_t kBNTrainingUpdateV3OutputNum = 5;
  64. constexpr size_t kBNTrainingUpdateGradOutputNum = 2;
  65. constexpr size_t kSingleOutputNum = 1;
  66. constexpr size_t kSumNodeInputTensorNum = 1;
  67. constexpr size_t kSquareNodeInputTensorNum = 1;
  68. constexpr size_t kSquareSumv2OutputNum = 2;
  69. constexpr size_t kMinimumInputTensorNum = 2;
  70. constexpr size_t kLambNextMVWithDecayInputNum = 7;
  71. constexpr size_t kLambNextMVWithDecayConstantMulInputNum = 5;
  72. constexpr size_t kLambNextMVWithDecayOutputNum = 4;
  73. constexpr size_t kLambNextMVWithDecayV1OutputNum = 4;
  74. constexpr size_t kLambNextRightOutputNum = 2;
  75. constexpr size_t kLambUpdateWithLrV2InputNum = 8;
  76. constexpr size_t kLambNextMVRuleInputNum = 14;
  77. constexpr size_t kLambNextMVRuleOutputNum = 4;
  78. constexpr size_t kBackendReshapeInputTensorNum = 1;
  79. constexpr size_t kBackendTransposeInputTensorNum = 1;
  80. constexpr size_t kAdamApplyOneWithDecayOutputNum = 3;
  81. constexpr size_t kLayerNormBetaGammaBackpropInputTensorNum = 4;
  82. constexpr size_t kLayerNormBetaGammaBackpropOutputNum = 2;
  83. constexpr size_t kLayerNormBetaGammaBackpropV2InputTensorNum = 2;
  84. constexpr size_t kLayerNormXBackpropOutputNum = 4;
  85. constexpr size_t kLayerNormXBackpropV2OutputNum = 2;
  86. constexpr size_t kLayerNormGradInputTensorNum = 5;
  87. constexpr size_t kAdamApplyOneOutputNum = 3;
  88. constexpr size_t kApplyMomentumInputTensorNum = 5;
  89. constexpr size_t kBiasAddInputTensorNum = 2;
  90. constexpr size_t kTopkInputTensorNum = 2;
  91. constexpr size_t kLarsV2InputTensorNum = 4;
  92. constexpr size_t kFusedMulApplyMomentumOutputNum = 2;
  93. constexpr size_t kSplitInputTensorNum = 1;
  94. constexpr size_t kGatherV2DynInputTensorNum = 3;
  95. constexpr size_t kUnsortedSegmentSumInputTensorNum = 2;
  96. constexpr size_t kSoftmaxCrossEntropyWithLogitsOutputNum = 2;
  97. constexpr size_t kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum = 2;
  98. constexpr size_t kOneHotOutputNum = 1;
  99. constexpr size_t kOneHotInputTensorNum = 4;
  100. enum FusedBatchNormInput {
  101. kX = 1,
  102. kVariance = 5,
  103. };
  104. enum FusedBatchNormOutput {
  105. kY = 0,
  106. kRunningMean,
  107. kRunningVariance,
  108. kSaveMean,
  109. kSaveInvVariance,
  110. };
  111. enum ConvBn1Output {
  112. kData = 0,
  113. kVarPart,
  114. kMean,
  115. };
  116. std::vector<int64_t> Convert2Int(const std::vector<size_t> &v);
  117. std::vector<int64_t> Convert2Long(const std::vector<size_t> &v);
  118. // check whether node depends on either of nodes or not
  119. bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector<AnfNodePtr> &nodes);
  120. bool UnVisited(const BaseRef &n);
  121. bool Visited(const BaseRef &n);
  122. CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg,
  123. const std::vector<AnfNodePtr> &orig_nodes);
  124. CNodePtr NewCNode(const CNodePtr &cnode, const KernelGraphPtr &fg, const std::vector<AnfNodePtr> &orig_nodes);
  125. // check if the input node is CNode, then check it's input_size, return CNodePtr if check success.
  126. CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, size_t input_size);
  127. void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_tensor_num);
  128. bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y);
  129. const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node);
  130. void CreateOutputsOfConvBn1(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode, const CNodePtr &bn_cnode,
  131. std::vector<AnfNodePtr> *conv_bn1_outputs);
  132. void CreateOutputsOfFusedBn2(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &fused_bn1_outputs,
  133. const CNodePtr &bn_node, std::vector<AnfNodePtr> *fused_bn2_outputs);
  134. void CreateOutputsOfFusedBn3(const FuncGraphPtr &graph, const AnfNodePtr &data_input,
  135. const std::vector<AnfNodePtr> &fused_bn1_outputs,
  136. const std::vector<AnfNodePtr> &fused_bn2_outputs, const CNodePtr &bn_node,
  137. std::vector<AnfNodePtr> *fused_bn3_outputs);
  138. void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &anf_node_ptr, size_t output_num,
  139. std::vector<AnfNodePtr> *outputs);
  140. tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_ptr, const TypePtr &type_ptr,
  141. size_t data_length);
  142. tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple);
  143. bool IsAllNopNode(const session::KernelGraph *const graph);
  144. bool IsNopNode(const AnfNodePtr &node);
  145. void HideNopNode(session::KernelGraph *const graph);
  146. void RemoveNopNode(session::KernelGraph *const graph);
  147. CNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx);
  148. ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &shape,
  149. bool to_tensor = false);
  150. bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node);
  151. std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
  152. const AnfNodePtr &node);
  153. size_t GetRealNodeNum(const FuncGraphPtr &graph, const AnfNodePtr &node);
  154. std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
  155. const AnfNodePtr &node,
  156. size_t output_index);
  157. bool IsNotRealUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node);
  158. void ConstInputToAttr(const CNodePtr &cnode, const mindspore::HashSet<size_t> &input_attrs);
  159. bool AnfEqual(const BaseRef &a, const BaseRef &b);
  160. bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b);
  161. AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
  162. bool multigraph = false);
  163. // Check var_node in two equivs is the same node
  164. bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node);
  165. // Get anf_node from equiv by var_node
  166. AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node);
  167. // Compare tuple getitem's index, return bool[n1's index < n2's index]
  168. bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2);
  169. // Get attr which is bool from cnode
  170. bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name);
  171. // Check node's data type is in supported data type set
  172. bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &supported_data_type_set);
  173. // Create a new value node of func graph,not kernel graph
  174. ValueNodePtr MakeValueNode(const ValueNodePtr &value_node);
  175. // Transfer depend or updatestate to the new node
  176. void TransferDependOrUpdateState(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node);
  177. AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list);
  178. // Generate kernel build info for created kernel
  179. kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector<AnfNodePtr> &node_list);
  180. // Get used number of node's each output
  181. std::vector<int64_t> GetNodeOutputUsedNum(const session::KernelGraph &kernel_graph, const AnfNodePtr &node);
  182. // Get total used number of node's output
  183. int64_t GetNodeOutputTotalUsedNum(const session::KernelGraph &kernel_graph, const AnfNodePtr &node);
  184. // Get custom operator attr input indexes
  185. void GetCustomOpAttrIndex(const PrimitivePtr &primitive, mindspore::HashSet<size_t> *indexes);
  186. } // namespace opt
  187. } // namespace mindspore
  188. #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_HELPER_H_