You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

anf_runtime_algorithm.h 12 kB

6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H
  17. #define MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H
#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "ir/anf.h"
#include "ir/dtype.h"
#include "ir/base.h"
#include "ir/primitive.h"
#include "device/device_address.h"
#include "kernel/kernel.h"
#include "kernel/kernel_build_info.h"
#include "operator/ops.h"
#include "utils/contract.h"
#include "session/kernel_graph.h"
  35. namespace mindspore {
  36. namespace session {
  37. using AnfVisitFuncion = std::function<Any(const AnfNodePtr &node, int index)>;
  38. using KernelWithIndex = std::pair<AnfNodePtr, size_t>;
  39. class AnfRuntimeAlgorithm {
  40. public:
  41. // get input_anf_node's real kernel by recurse
  42. static KernelWithIndex VisitKernel(const AnfNodePtr &input_anf_node, size_t output_index);
  43. static KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &input_anf_node, size_t output_index,
  44. bool visit_nop_node = false,
  45. const std::vector<PrimitivePtr> &return_types = {
  46. prim::kPrimMakeTuple});
  47. static std::vector<AnfNodePtr> GetAllOutput(const AnfNodePtr &node,
  48. const std::vector<PrimitivePtr> &return_types = {});
  49. // get cnode primitive
  50. static AnfNodePtr GetCNodePrimitiveNode(const CNodePtr &node);
  51. static void SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index);
  52. static PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node);
  53. // check whether anf node is a node of 'primitive_type',such as make_tuple is a cnode of kPrimMakeTuple
  54. static bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type);
  55. // get cnode primitive
  56. static FuncGraphPtr GetCNodeFuncGraphPtr(const AnfNodePtr &node);
  57. // get kernel_name of anf node
  58. static std::string GetCNodeName(const AnfNodePtr &node);
  59. // get detail info of anf node
  60. static std::string GetNodeDebugString(const AnfNodePtr &node);
  61. // get attr of anf node
  62. template <typename T>
  63. static T GetNodeAttr(const AnfNodePtr &node, const std::string &key) {
  64. MS_EXCEPTION_IF_NULL(node);
  65. if (!node->isa<CNode>()) {
  66. std::string node_debug_log = node->DebugString();
  67. MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node_debug_log.c_str();
  68. }
  69. auto primitive = GetCNodePrimitive(node);
  70. MS_EXCEPTION_IF_NULL(primitive);
  71. return GetValue<T>(primitive->GetAttr(key));
  72. }
  73. static bool IsTupleOutput(const AnfNodePtr &anf);
  74. // set attr of anf node
  75. static void SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node);
  76. // set attr of key from 'from' node to 'to' node
  77. static void CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to);
  78. // set a new key for attr from 'from' node to 'to' node
  79. static void CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from,
  80. const AnfNodePtr &to);
  81. // set all attrs from 'from' node to 'to' node
  82. static void CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to);
  83. // check whether a cnode has the specified attr.
  84. static bool HasNodeAttr(const std::string &key, const CNodePtr &node);
  85. // delete attr of anf node
  86. static void EraseNodeAttr(const std::string &key, AnfNodePtr node);
  87. // get the num of input real_kernel(which can be build and run in device)
  88. static size_t GetInputTensorNum(const AnfNodePtr &node);
  89. // get the num of output real_kernel(which can be build and run in device)
  90. static size_t GetOutputTensorNum(const AnfNodePtr &node);
  91. // get output format select of anf node
  92. static std::string GetOutputFormat(const AnfNodePtr &node, size_t output_idx);
  93. // get input format select of anf node
  94. static std::string GetInputFormat(const AnfNodePtr &node, size_t input_idx);
  95. // get prev node output width output index
  96. static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx);
  97. // get output format from prev node,input_index is the input index of current node related to prev node
  98. static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx);
  99. // get reshape_type of from the output of input node.
  100. static std::vector<kernel::Axis> GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx);
  101. // get output shapes inferred by ME from input nodes.
  102. static std::vector<size_t> GetOutputInferShape(const AnfNodePtr &node, size_t output_idx);
  103. // get input shapes inferred by ME from input nodes.
  104. static std::vector<size_t> GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx);
  105. // get output shapes which will built and run in device
  106. static std::vector<size_t> GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx);
  107. // get input shapes which will built and run in device
  108. static std::vector<size_t> GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx);
  109. // Get Input Padding Axis
  110. static std::vector<kernel::Axis> GetInputReshapeType(const AnfNodePtr &node, size_t output_idx);
  111. // Get Output Padding Axis
  112. static std::vector<kernel::Axis> GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx);
  113. // get output data type inferred by ME of anf node
  114. static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx);
  115. // get output original data type from prev node,input_index is the input index of current node related to prev node
  116. static TypeId GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx);
  117. // get output select data type of anf node
  118. static TypeId GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx);
  119. // get input select data type of anf node
  120. static TypeId GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx);
  121. // get output select data type from prev node,input_index is the input index of current node related to prev node
  122. static TypeId GetPrevNodeOutputDeviceDataType(const AnfNodePtr &node, size_t input_idx);
  123. // get output device addr of anf_node
  124. static const DeviceAddress *GetOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true);
  125. // get mutable output device addr of anf_node
  126. static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true);
  127. // check whether output addr is exist or not
  128. static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx);
  129. // get address from prev node,input_index is the input index of current node related to prev node
  130. static const DeviceAddress *GetPrevNodeOutputAddr(const AnfNodePtr &node, size_t input_idx,
  131. bool visit_nop_node = true);
  132. static DeviceAddressPtr GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
  133. bool visit_nop_node = true);
  134. // set output device addr of anf_node
  135. static void SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node);
  136. // set workspace device addr of anf_node
  137. static void SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node);
  138. // get workspace device addr of anf_node
  139. static DeviceAddress *GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx);
  140. // set infer shapes and types of anf node
  141. static void SetOutputInferTypeAndShape(const std::vector<TypeId> &types,
  142. const std::vector<std::vector<size_t>> &shapes, AnfNode *node);
  143. static void CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node);
  144. // get op pattern of the node
  145. static kernel::OpPattern GetOpPattern(const AnfNodePtr &node);
  146. // get KernelBuildType of node ,such as ATT,RT,FWK and so on
  147. static KernelType GetKernelType(const AnfNodePtr &node);
  148. // get processor type:AICORE,AICPU...
  149. static kernel::Processor GetProcessor(const AnfNodePtr &node);
  150. // get fusion type:AICORE,AICPU...
  151. static kernel::FusionType GetFusionType(const AnfNodePtr &node);
  152. // set select kernel_build_info
  153. static void SetSelectKernelBuildInfo(const kernel::KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node);
  154. // get select kernel_build_info
  155. static kernel::KernelBuildInfoPtr GetSelectKernelBuildInfo(const AnfNodePtr &node);
  156. // get kernelMode
  157. static kernel::KernelMod *GetKernelMod(const AnfNodePtr &node);
  158. // set kernel mod
  159. static void SetKernelMod(const kernel::KernelModPtr &kernel_mod, AnfNode *node);
  160. // checkout whether the anf node is a real kernel that can run on device,parameter and constant is real kernel too
  161. static bool IsRealKernel(const AnfNodePtr &node);
  162. // checkout whether the anf node is a real kernel that is a cnode and can run on device
  163. static bool IsRealCNodeKernel(const AnfNodePtr &node);
  164. // checkout whether the anf node is a graph kernel.
  165. static bool IsGraphKernel(const AnfNodePtr &node);
  166. // check parameter is weight or data
  167. static bool IsParameterWeight(const ParameterPtr &node);
  168. // set stream id of kernel,which will be set in stream assign and be used in stream generate
  169. static void SetStreamId(uint32_t stream_id, AnfNode *node);
  170. // get stream id
  171. static uint32_t GetStreamId(const AnfNodePtr &node);
  172. // set stream distinction label to distinguish different ops in different streams
  173. static void SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node);
  174. // get stream distinction label
  175. static uint32_t GetStreamDistinctionLabel(const AnfNode *node);
  176. // set graph id
  177. static void SetGraphId(uint32_t graph_id, AnfNode *node);
  178. // get graph id
  179. static uint32_t GetGraphId(const AnfNode *node);
  180. static AnfNodePtr GetInputNode(const CNodePtr &node, size_t index);
  181. // charge if the node's output is a feature map output
  182. static bool IsFeatureMapOutput(const AnfNodePtr &node);
  183. // charge if the node's input is from a feature map output
  184. static bool IsFeatureMapInput(const AnfNodePtr &node, size_t input_index);
  185. // get real input index for some tbe ops which input order is different between me and tbe impl
  186. static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
  187. static bool IsCommunicationOp(const AnfNodePtr &node);
  188. static bool IsGetNext(const NotNull<AnfNodePtr> &node);
  189. static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
  190. static std::vector<KernelGraphPtr> GetCallNodeKernelGraph(const CNodePtr &call_node);
  191. static bool IsSwitchCall(const CNodePtr &call_node);
  192. static bool IsScalarInput(const CNodePtr &cnode, size_t index);
  193. static bool IsScalarOutput(const CNodePtr &cnode, size_t index);
  194. static void ReorderExecList(NotNull<std::vector<CNodePtr> *> node_list);
  195. static bool IsWhileTrueGraph(const KernelGraphPtr &child_graph);
  196. // get fix output precision of cnode.
  197. static TypeId GetCNodeOutputPrecision(const AnfNodePtr &node);
  198. // get fix output precision from prev node, input_idx is the input index of current node related to prev node.
  199. static TypeId GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx);
  200. };
  201. } // namespace session
  202. using AnfAlgo = session::AnfRuntimeAlgorithm;
  203. } // namespace mindspore
  204. #endif // MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H