You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

anf_runtime_algorithm.h 11 kB

6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H
  17. #define MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H
  #include <functional>
  #include <iostream>
  #include <memory>
  #include <set>
  #include <string>
  #include <tuple>
  #include <utility>
  #include <vector>
  #include "ir/anf.h"
  #include "ir/dtype.h"
  #include "ir/base.h"
  #include "ir/primitive.h"
  #include "device/device_address.h"
  #include "kernel/kernel.h"
  #include "kernel/kernel_build_info.h"
  #include "operator/ops.h"
  #include "utils/contract.h"
  34. namespace mindspore {
  35. namespace session {
  36. using AnfVisitFuncion = std::function<Any(const AnfNodePtr &node, int index)>;
  37. using KernelWithIndex = std::pair<AnfNodePtr, size_t>;
  38. class AnfRuntimeAlgorithm {
  39. public:
  40. // get input_anf_node's real kernel by recurse
  41. static KernelWithIndex VisitKernel(const AnfNodePtr &input_anf_node, size_t output_index);
  42. static KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &input_anf_node, size_t output_index,
  43. bool visit_nop_node = false,
  44. const std::vector<PrimitivePtr> &return_types = {
  45. prim::kPrimMakeTuple});
  46. static std::vector<AnfNodePtr> GetAllOutput(const AnfNodePtr &node,
  47. const std::vector<PrimitivePtr> &return_types = {});
  48. // get cnode primitive
  49. static AnfNodePtr GetCNodePrimitiveNode(const CNodePtr &node);
  50. static void SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index);
  51. static PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node);
  52. // check whether anf node is a node of 'primitive_type',such as make_tuple is a cnode of kPrimMakeTuple
  53. static bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type);
  54. // get kernel_name of anf node
  55. static std::string GetCNodeName(const AnfNodePtr &node);
  56. // get detail info of anf node
  57. static std::string GetNodeDebugString(const AnfNodePtr &node);
  58. // get attr of anf node
  59. template <typename T>
  60. static T GetNodeAttr(const AnfNodePtr &node, const std::string &key) {
  61. MS_EXCEPTION_IF_NULL(node);
  62. if (!node->isa<CNode>()) {
  63. std::string node_debug_log = node->DebugString();
  64. MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node_debug_log.c_str();
  65. }
  66. auto primitive = GetCNodePrimitive(node);
  67. MS_EXCEPTION_IF_NULL(primitive);
  68. return GetValue<T>(primitive->GetAttr(key));
  69. }
  70. static bool IsTupleOutput(const AnfNodePtr &anf);
  71. // set attr of anf node
  72. static void SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node);
  73. // set attr of key from 'from' node to 'to' node
  74. static void CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to);
  75. // set a new key for attr from 'from' node to 'to' node
  76. static void CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from,
  77. const AnfNodePtr &to);
  78. // set all attrs from 'from' node to 'to' node
  79. static void CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to);
  80. // check whether a cnode has the specified attr.
  81. static bool HasNodeAttr(const std::string &key, const AnfNodePtr &node);
  82. // delete attr of anf node
  83. static void EraseNodeAttr(const std::string &key, AnfNodePtr node);
  84. // get the num of input real_kernel(which can be build and run in device)
  85. static size_t GetInputTensorNum(const AnfNodePtr &node);
  86. // get the num of output real_kernel(which can be build and run in device)
  87. static size_t GetOutputTensorNum(const AnfNodePtr &node);
  88. // get output format select of anf node
  89. static std::string GetOutputFormat(const AnfNodePtr &node, size_t output_idx);
  90. // get input format select of anf node
  91. static std::string GetInputFormat(const AnfNodePtr &node, size_t input_idx);
  92. // get prev node output width output index
  93. static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx);
  94. // get output format from prev node,input_index is the input index of current node related to prev node
  95. static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx);
  96. // get output shapes inferred by ME from input nodes.
  97. static std::vector<size_t> GetOutputInferShape(const AnfNodePtr &node, size_t output_idx);
  98. // get input shapes inferred by ME from input nodes.
  99. static std::vector<size_t> GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx);
  100. // get output shapes which will built and run in device
  101. static std::vector<size_t> GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx);
  102. // get input shapes which will built and run in device
  103. static std::vector<size_t> GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx);
  104. // Get Input Padding Axis
  105. static std::vector<kernel::Axis> GetInputReshapeType(const AnfNodePtr &node, size_t output_idx);
  106. // Get Output Padding Axis
  107. static std::vector<kernel::Axis> GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx);
  108. // get output data type inferred by ME of anf node
  109. static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx);
  110. // get output original data type from prev node,input_index is the input index of current node related to prev node
  111. static TypeId GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx);
  112. // get output select data type of anf node
  113. static TypeId GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx);
  114. // get input select data type of anf node
  115. static TypeId GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx);
  116. // get output select data type from prev node,input_index is the input index of current node related to prev node
  117. static TypeId GetPrevNodeOutputDeviceDataType(const AnfNodePtr &node, size_t input_idx);
  118. // get output device addr of anf_node
  119. static const DeviceAddress *GetOutputAddr(const AnfNodePtr &node, size_t output_idx);
  120. // get mutable output device addr of anf_node
  121. static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx);
  122. // check whether output addr is exist or not
  123. static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx);
  124. // get address from prev node,input_index is the input index of current node related to prev node
  125. static const DeviceAddress *GetPrevNodeOutputAddr(const AnfNodePtr &node, size_t input_idx);
  126. static DeviceAddressPtr GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx);
  127. // set output device addr of anf_node
  128. static void SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node);
  129. // set workspace device addr of anf_node
  130. static void SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node);
  131. // get workspace device addr of anf_node
  132. static DeviceAddress *GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx);
  133. // set infer shapes and types of anf node
  134. static void SetOutputInferTypeAndShape(const std::vector<TypeId> &types,
  135. const std::vector<std::vector<size_t>> &shapes, AnfNode *node);
  136. static void CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node);
  137. // get KernelBuildType of node ,such as ATT,RT,FWK and so on
  138. static KernelType GetKernelType(const AnfNodePtr &node);
  139. // get processor type:AICORE,AICPU...
  140. static kernel::Processor GetProcessor(const AnfNodePtr &node);
  141. // get fusion type:AICORE,AICPU...
  142. static kernel::FusionType GetFusionType(const AnfNodePtr &node);
  143. // set select kernel_build_info
  144. static void SetSelectKernelBuildInfo(const kernel::KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node);
  145. // get select kernel_build_info
  146. static kernel::KernelBuildInfoPtr GetSelectKernelBuildInfo(const AnfNodePtr &node);
  147. // get kernelMode
  148. static kernel::KernelMod *GetKernelMod(const AnfNodePtr &node);
  149. // set kernel mod
  150. static void SetKernelMod(const kernel::KernelModPtr &kernel_mod, AnfNode *node);
  151. // checkout whether the anf node is a real kernel that can run on device,parameter and constant is real kernel too
  152. static bool IsRealKernel(const AnfNodePtr &node);
  153. // checkout whether the anf node is a real kernel that is a cnode and can run on device
  154. static bool IsRealCNodeKernel(const AnfNodePtr &node);
  155. // check parameter is weight or data
  156. static bool IsParameterWeight(const ParameterPtr &node);
  157. // set stream id of kernel,which will be set in stream assign and be used in stream generate
  158. static void SetStreamId(uint32_t stream_id, AnfNode *node);
  159. // get stream id
  160. static uint32_t GetStreamId(const AnfNodePtr &node);
  161. // set stream distinction label to distinguish different ops in different streams
  162. static void SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node);
  163. // get stream distinction label
  164. static uint32_t GetStreamDistinctionLabel(const AnfNode *node);
  165. // set graph id
  166. static void SetGraphId(uint32_t graph_id, AnfNode *node);
  167. // get graph id
  168. static uint32_t GetGraphId(const AnfNode *node);
  169. static AnfNodePtr GetInputNode(const CNodePtr &node, size_t index);
  170. // charge if the node's output is a feature map output
  171. static bool IsFeatureMapOutput(const AnfNodePtr &node);
  172. // charge if the node's input is from a feature map output
  173. static bool IsFeatureMapInput(const AnfNodePtr &node, size_t input_index);
  174. // get real input index for some tbe ops which input order is different between me and tbe impl
  175. static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
  176. static bool IsCommunicationOp(const AnfNodePtr &node);
  177. static bool IsAllReduceOp(const AnfNodePtr &node);
  178. static bool IsGetNext(const NotNull<AnfNodePtr> &node);
  179. };
  180. } // namespace session
  181. using AnfAlgo = session::AnfRuntimeAlgorithm;
  182. } // namespace mindspore
  183. #endif // MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H