You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

anf_runtime_algorithm.h 11 kB

6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H
  17. #define MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H
#include <functional>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "ir/anf.h"
#include "ir/dtype.h"
#include "ir/base.h"
#include "ir/primitive.h"
#include "device/device_address.h"
#include "kernel/kernel.h"
#include "kernel/kernel_build_info.h"
#include "operator/ops.h"
#include "utils/contract.h"
#include "session/kernel_graph.h"
  35. namespace mindspore {
  36. namespace session {
  37. using AnfVisitFuncion = std::function<Any(const AnfNodePtr &node, int index)>;
  38. using KernelWithIndex = std::pair<AnfNodePtr, size_t>;
  39. class AnfRuntimeAlgorithm {
  40. public:
  41. // get input_anf_node's real kernel by recurse
  42. static KernelWithIndex VisitKernel(const AnfNodePtr &input_anf_node, size_t output_index);
  43. static KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &input_anf_node, size_t output_index,
  44. bool visit_nop_node = false,
  45. const std::vector<PrimitivePtr> &return_types = {
  46. prim::kPrimMakeTuple});
  47. static std::vector<AnfNodePtr> GetAllOutput(const AnfNodePtr &node,
  48. const std::vector<PrimitivePtr> &return_types = {});
  49. // get cnode primitive
  50. static AnfNodePtr GetCNodePrimitiveNode(const CNodePtr &node);
  51. static void SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index);
  52. static PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node);
  53. // check whether anf node is a node of 'primitive_type',such as make_tuple is a cnode of kPrimMakeTuple
  54. static bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type);
  55. // get kernel_name of anf node
  56. static std::string GetCNodeName(const AnfNodePtr &node);
  57. // get detail info of anf node
  58. static std::string GetNodeDebugString(const AnfNodePtr &node);
  59. // get attr of anf node
  60. template <typename T>
  61. static T GetNodeAttr(const AnfNodePtr &node, const std::string &key) {
  62. MS_EXCEPTION_IF_NULL(node);
  63. if (!node->isa<CNode>()) {
  64. std::string node_debug_log = node->DebugString();
  65. MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node_debug_log.c_str();
  66. }
  67. auto primitive = GetCNodePrimitive(node);
  68. MS_EXCEPTION_IF_NULL(primitive);
  69. return GetValue<T>(primitive->GetAttr(key));
  70. }
  71. static bool IsTupleOutput(const AnfNodePtr &anf);
  72. // set attr of anf node
  73. static void SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node);
  74. // set attr of key from 'from' node to 'to' node
  75. static void CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to);
  76. // set a new key for attr from 'from' node to 'to' node
  77. static void CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from,
  78. const AnfNodePtr &to);
  79. // set all attrs from 'from' node to 'to' node
  80. static void CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to);
  81. // check whether a cnode has the specified attr.
  82. static bool HasNodeAttr(const std::string &key, const CNodePtr &node);
  83. // delete attr of anf node
  84. static void EraseNodeAttr(const std::string &key, AnfNodePtr node);
  85. // get the num of input real_kernel(which can be build and run in device)
  86. static size_t GetInputTensorNum(const AnfNodePtr &node);
  87. // get the num of output real_kernel(which can be build and run in device)
  88. static size_t GetOutputTensorNum(const AnfNodePtr &node);
  89. // get output format select of anf node
  90. static std::string GetOutputFormat(const AnfNodePtr &node, size_t output_idx);
  91. // get input format select of anf node
  92. static std::string GetInputFormat(const AnfNodePtr &node, size_t input_idx);
  93. // get prev node output width output index
  94. static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx);
  95. // get output format from prev node,input_index is the input index of current node related to prev node
  96. static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx);
  97. // get reshape_type of from the output of input node.
  98. static std::vector<kernel::Axis> GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx);
  99. // get output shapes inferred by ME from input nodes.
  100. static std::vector<size_t> GetOutputInferShape(const AnfNodePtr &node, size_t output_idx);
  101. // get input shapes inferred by ME from input nodes.
  102. static std::vector<size_t> GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx);
  103. // get output shapes which will built and run in device
  104. static std::vector<size_t> GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx);
  105. // get input shapes which will built and run in device
  106. static std::vector<size_t> GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx);
  107. // Get Input Padding Axis
  108. static std::vector<kernel::Axis> GetInputReshapeType(const AnfNodePtr &node, size_t output_idx);
  109. // Get Output Padding Axis
  110. static std::vector<kernel::Axis> GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx);
  111. // get output data type inferred by ME of anf node
  112. static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx);
  113. // get output original data type from prev node,input_index is the input index of current node related to prev node
  114. static TypeId GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx);
  115. // get output select data type of anf node
  116. static TypeId GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx);
  117. // get input select data type of anf node
  118. static TypeId GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx);
  119. // get output select data type from prev node,input_index is the input index of current node related to prev node
  120. static TypeId GetPrevNodeOutputDeviceDataType(const AnfNodePtr &node, size_t input_idx);
  121. // get output device addr of anf_node
  122. static const DeviceAddress *GetOutputAddr(const AnfNodePtr &node, size_t output_idx);
  123. // get mutable output device addr of anf_node
  124. static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx);
  125. // check whether output addr is exist or not
  126. static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx);
  127. // get address from prev node,input_index is the input index of current node related to prev node
  128. static const DeviceAddress *GetPrevNodeOutputAddr(const AnfNodePtr &node, size_t input_idx);
  129. static DeviceAddressPtr GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx);
  130. // set output device addr of anf_node
  131. static void SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node);
  132. // set workspace device addr of anf_node
  133. static void SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node);
  134. // get workspace device addr of anf_node
  135. static DeviceAddress *GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx);
  136. // set infer shapes and types of anf node
  137. static void SetOutputInferTypeAndShape(const std::vector<TypeId> &types,
  138. const std::vector<std::vector<size_t>> &shapes, AnfNode *node);
  139. static void CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node);
  140. // get op pattern of the node
  141. static kernel::OpPattern GetOpPattern(const AnfNodePtr &node);
  142. // get KernelBuildType of node ,such as ATT,RT,FWK and so on
  143. static KernelType GetKernelType(const AnfNodePtr &node);
  144. // get processor type:AICORE,AICPU...
  145. static kernel::Processor GetProcessor(const AnfNodePtr &node);
  146. // get fusion type:AICORE,AICPU...
  147. static kernel::FusionType GetFusionType(const AnfNodePtr &node);
  148. // set select kernel_build_info
  149. static void SetSelectKernelBuildInfo(const kernel::KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node);
  150. // get select kernel_build_info
  151. static kernel::KernelBuildInfoPtr GetSelectKernelBuildInfo(const AnfNodePtr &node);
  152. // get kernelMode
  153. static kernel::KernelMod *GetKernelMod(const AnfNodePtr &node);
  154. // set kernel mod
  155. static void SetKernelMod(const kernel::KernelModPtr &kernel_mod, AnfNode *node);
  156. // checkout whether the anf node is a real kernel that can run on device,parameter and constant is real kernel too
  157. static bool IsRealKernel(const AnfNodePtr &node);
  158. // checkout whether the anf node is a real kernel that is a cnode and can run on device
  159. static bool IsRealCNodeKernel(const AnfNodePtr &node);
  160. // check parameter is weight or data
  161. static bool IsParameterWeight(const ParameterPtr &node);
  162. // set stream id of kernel,which will be set in stream assign and be used in stream generate
  163. static void SetStreamId(uint32_t stream_id, AnfNode *node);
  164. // get stream id
  165. static uint32_t GetStreamId(const AnfNodePtr &node);
  166. // set stream distinction label to distinguish different ops in different streams
  167. static void SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node);
  168. // get stream distinction label
  169. static uint32_t GetStreamDistinctionLabel(const AnfNode *node);
  170. // set graph id
  171. static void SetGraphId(uint32_t graph_id, AnfNode *node);
  172. // get graph id
  173. static uint32_t GetGraphId(const AnfNode *node);
  174. static AnfNodePtr GetInputNode(const CNodePtr &node, size_t index);
  175. // charge if the node's output is a feature map output
  176. static bool IsFeatureMapOutput(const AnfNodePtr &node);
  177. // charge if the node's input is from a feature map output
  178. static bool IsFeatureMapInput(const AnfNodePtr &node, size_t input_index);
  179. // get real input index for some tbe ops which input order is different between me and tbe impl
  180. static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
  181. static bool IsCommunicationOp(const AnfNodePtr &node);
  182. static bool IsGetNext(const NotNull<AnfNodePtr> &node);
  183. static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
  184. static std::vector<KernelGraphPtr> GetCallNodeKernelGraph(const CNodePtr &call_node);
  185. static bool IsSwitchCall(const CNodePtr &call_node);
  186. };
  187. } // namespace session
// Short alias declared in namespace mindspore, so code there can write AnfAlgo::Xxx(...)
// instead of session::AnfRuntimeAlgorithm::Xxx(...).
using AnfAlgo = session::AnfRuntimeAlgorithm;
  189. } // namespace mindspore
  190. #endif // MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H