You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

anf_runtime_algorithm.h 12 kB

6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H
  17. #define MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H
  #include <functional>
  #include <iostream>
  #include <memory>
  #include <set>
  #include <string>
  #include <tuple>
  #include <utility>
  #include <vector>
  #include "ir/anf.h"
  #include "ir/dtype.h"
  #include "ir/base.h"
  #include "ir/primitive.h"
  #include "device/device_address.h"
  #include "kernel/kernel.h"
  #include "kernel/kernel_build_info.h"
  #include "operator/ops.h"
  #include "utils/contract.h"
  #include "session/kernel_graph.h"
  35. namespace mindspore {
  36. namespace session {
  37. using AnfVisitFuncion = std::function<Any(const AnfNodePtr &node, int index)>;
  38. using KernelWithIndex = std::pair<AnfNodePtr, size_t>;
  39. class AnfRuntimeAlgorithm {
  40. public:
  41. // get input_anf_node's real kernel by recurse
  42. static KernelWithIndex VisitKernel(const AnfNodePtr &input_anf_node, size_t output_index);
  43. static KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &input_anf_node, size_t output_index,
  44. bool visit_nop_node = false,
  45. const std::vector<PrimitivePtr> &return_types = {
  46. prim::kPrimMakeTuple});
  47. static std::vector<AnfNodePtr> GetAllOutput(const AnfNodePtr &node,
  48. const std::vector<PrimitivePtr> &return_types = {});
  49. // get cnode primitive
  50. static AnfNodePtr GetCNodePrimitiveNode(const CNodePtr &node);
  51. static void SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index);
  52. static PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node);
  53. // check whether anf node is a node of 'primitive_type',such as make_tuple is a cnode of kPrimMakeTuple
  54. static bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type);
  55. // get cnode primitive
  56. static FuncGraphPtr GetCNodeFuncGraphPtr(const AnfNodePtr &node);
  57. // get kernel_name of anf node
  58. static std::string GetCNodeName(const AnfNodePtr &node);
  59. // get detail info of anf node
  60. static std::string GetNodeDebugString(const AnfNodePtr &node);
  61. // get attr of anf node
  62. template <typename T>
  63. static T GetNodeAttr(const AnfNodePtr &node, const std::string &key) {
  64. MS_EXCEPTION_IF_NULL(node);
  65. if (!node->isa<CNode>()) {
  66. std::string node_debug_log = node->DebugString();
  67. MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node_debug_log.c_str();
  68. }
  69. // single op cnode.
  70. if (auto primitive = GetCNodePrimitive(node); primitive != nullptr) {
  71. return GetValue<T>(primitive->GetAttr(key));
  72. }
  73. // graph kernel cnode.
  74. auto fg = GetCNodeFuncGraphPtr(node);
  75. MS_EXCEPTION_IF_NULL(fg);
  76. return GetValue<T>(fg->get_attr(key));
  77. }
  78. static bool IsTupleOutput(const AnfNodePtr &anf);
  79. // set attr of anf node
  80. static void SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node);
  81. // set attr of key from 'from' node to 'to' node
  82. static void CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to);
  83. // set a new key for attr from 'from' node to 'to' node
  84. static void CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from,
  85. const AnfNodePtr &to);
  86. // set all attrs from 'from' node to 'to' node
  87. static void CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to);
  88. // check whether a cnode has the specified attr.
  89. static bool HasNodeAttr(const std::string &key, const CNodePtr &node);
  90. // delete attr of anf node
  91. static void EraseNodeAttr(const std::string &key, AnfNodePtr node);
  92. // get the num of input real_kernel(which can be build and run in device)
  93. static size_t GetInputTensorNum(const AnfNodePtr &node);
  94. // get the num of output real_kernel(which can be build and run in device)
  95. static size_t GetOutputTensorNum(const AnfNodePtr &node);
  96. // get output format select of anf node
  97. static std::string GetOutputFormat(const AnfNodePtr &node, size_t output_idx);
  98. // get input format select of anf node
  99. static std::string GetInputFormat(const AnfNodePtr &node, size_t input_idx);
  100. // get prev node output width output index
  101. static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx);
  102. // get output format from prev node,input_index is the input index of current node related to prev node
  103. static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx);
  104. // get reshape_type of from the output of input node.
  105. static std::vector<kernel::Axis> GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx);
  106. // get output shapes inferred by ME from input nodes.
  107. static std::vector<size_t> GetOutputInferShape(const AnfNodePtr &node, size_t output_idx);
  108. // get input shapes inferred by ME from input nodes.
  109. static std::vector<size_t> GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx);
  110. // get output shapes which will built and run in device
  111. static std::vector<size_t> GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx);
  112. // get input shapes which will built and run in device
  113. static std::vector<size_t> GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx);
  114. // Get Input Padding Axis
  115. static std::vector<kernel::Axis> GetInputReshapeType(const AnfNodePtr &node, size_t output_idx);
  116. // Get Output Padding Axis
  117. static std::vector<kernel::Axis> GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx);
  118. // get output data type inferred by ME of anf node
  119. static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx);
  120. // get output original data type from prev node,input_index is the input index of current node related to prev node
  121. static TypeId GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx);
  122. // get output select data type of anf node
  123. static TypeId GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx);
  124. // get input select data type of anf node
  125. static TypeId GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx);
  126. // get output select data type from prev node,input_index is the input index of current node related to prev node
  127. static TypeId GetPrevNodeOutputDeviceDataType(const AnfNodePtr &node, size_t input_idx);
  128. // get output device addr of anf_node
  129. static const DeviceAddress *GetOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true);
  130. // get mutable output device addr of anf_node
  131. static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true);
  132. // check whether output addr is exist or not
  133. static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx);
  134. // get address from prev node,input_index is the input index of current node related to prev node
  135. static const DeviceAddress *GetPrevNodeOutputAddr(const AnfNodePtr &node, size_t input_idx,
  136. bool visit_nop_node = true);
  137. static DeviceAddressPtr GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
  138. bool visit_nop_node = true);
  139. // set output device addr of anf_node
  140. static void SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node);
  141. // set workspace device addr of anf_node
  142. static void SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node);
  143. // get workspace device addr of anf_node
  144. static DeviceAddress *GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx);
  145. // set infer shapes and types of anf node
  146. static void SetOutputInferTypeAndShape(const std::vector<TypeId> &types,
  147. const std::vector<std::vector<size_t>> &shapes, AnfNode *node);
  148. static void CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node);
  149. // get op pattern of the node
  150. static kernel::OpPattern GetOpPattern(const AnfNodePtr &node);
  151. // get KernelBuildType of node ,such as ATT,RT,FWK and so on
  152. static KernelType GetKernelType(const AnfNodePtr &node);
  153. // get processor type:AICORE,AICPU...
  154. static kernel::Processor GetProcessor(const AnfNodePtr &node);
  155. // get fusion type:AICORE,AICPU...
  156. static kernel::FusionType GetFusionType(const AnfNodePtr &node);
  157. // set select kernel_build_info
  158. static void SetSelectKernelBuildInfo(const kernel::KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node);
  159. // get select kernel_build_info
  160. static kernel::KernelBuildInfoPtr GetSelectKernelBuildInfo(const AnfNodePtr &node);
  161. // get kernelMode
  162. static kernel::KernelMod *GetKernelMod(const AnfNodePtr &node);
  163. // set kernel mod
  164. static void SetKernelMod(const kernel::KernelModPtr &kernel_mod, AnfNode *node);
  165. // checkout whether the anf node is a real kernel that can run on device,parameter and constant is real kernel too
  166. static bool IsRealKernel(const AnfNodePtr &node);
  167. // checkout whether the anf node is a real kernel that is a cnode and can run on device
  168. static bool IsRealCNodeKernel(const AnfNodePtr &node);
  169. // checkout whether the anf node is a graph kernel.
  170. static bool IsGraphKernel(const AnfNodePtr &node);
  171. // check parameter is weight or data
  172. static bool IsParameterWeight(const ParameterPtr &node);
  173. // set stream id of kernel,which will be set in stream assign and be used in stream generate
  174. static void SetStreamId(uint32_t stream_id, AnfNode *node);
  175. // get stream id
  176. static uint32_t GetStreamId(const AnfNodePtr &node);
  177. // set stream distinction label to distinguish different ops in different streams
  178. static void SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node);
  179. // get stream distinction label
  180. static uint32_t GetStreamDistinctionLabel(const AnfNode *node);
  181. // set graph id
  182. static void SetGraphId(uint32_t graph_id, AnfNode *node);
  183. // get graph id
  184. static uint32_t GetGraphId(const AnfNode *node);
  185. static AnfNodePtr GetInputNode(const CNodePtr &node, size_t index);
  186. // charge if the node's output is a feature map output
  187. static bool IsFeatureMapOutput(const AnfNodePtr &node);
  188. // charge if the node's input is from a feature map output
  189. static bool IsFeatureMapInput(const AnfNodePtr &node, size_t input_index);
  190. // get real input index for some tbe ops which input order is different between me and tbe impl
  191. static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
  192. static bool IsCommunicationOp(const AnfNodePtr &node);
  193. static bool IsGetNext(const NotNull<AnfNodePtr> &node);
  194. static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
  195. static std::vector<KernelGraphPtr> GetCallNodeKernelGraph(const CNodePtr &call_node);
  196. static bool IsSwitchCall(const CNodePtr &call_node);
  197. static bool IsScalarInput(const CNodePtr &cnode, size_t index);
  198. static bool IsScalarOutput(const CNodePtr &cnode, size_t index);
  199. static void ReorderExecList(NotNull<std::vector<CNodePtr> *> node_list);
  200. // get fix output precision of cnode.
  201. static TypeId GetCNodeOutputPrecision(const AnfNodePtr &node);
  202. // get fix output precision from prev node, input_idx is the input index of current node related to prev node.
  203. static TypeId GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx);
  204. };
  205. } // namespace session
  206. using AnfAlgo = session::AnfRuntimeAlgorithm;
  207. } // namespace mindspore
  208. #endif // MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H