You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

kernel_query.cc 7.9 kB

5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "backend/kernel_compiler/kernel_query.h"
  #include <algorithm>
  #include <iterator>
  #include <memory>
  #include <sstream>
  #include <string>
  #include <utility>
  #include <vector>
  #include "backend/kernel_compiler/aicpu/aicpu_kernel_metadata.h"
  #include "backend/kernel_compiler/host/host_kernel_metadata.h"
  #include "backend/kernel_compiler/rts/rt_kernel_info.h"
  #include "backend/kernel_compiler/hccl/hccl_kernel_metadata.h"
  #include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h"
  #include "backend/kernel_compiler/akg/akg_kernel_metadata.h"
  #include "backend/session/anf_runtime_algorithm.h"
  #include "utils/ms_context.h"
  #include "utils/trace_base.h"
  28. namespace mindspore {
  29. namespace kernel {
  30. namespace {
  31. void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
  32. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
  33. MS_EXCEPTION_IF_NULL(kernel_info_list);
  34. MS_EXCEPTION_IF_NULL(kernel_node);
  35. size_t output_tensor_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  36. size_t input_tensor_num = AnfAlgo::GetInputTensorNum(kernel_node);
  37. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_list;
  38. (void)std::copy_if(
  39. kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list),
  40. [output_tensor_num, input_tensor_num](const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info) {
  41. return kernel_build_info->GetOutputNum() == output_tensor_num &&
  42. kernel_build_info->GetInputNum() == input_tensor_num;
  43. });
  44. if (!filtered_list.empty()) {
  45. kernel_info_list->clear();
  46. (void)std::copy(filtered_list.begin(), filtered_list.end(), std::back_inserter(*kernel_info_list));
  47. } else {
  48. MS_LOG(INFO) << "All kernel Info list does not match any kernel info ";
  49. for (size_t index = 0; index < kernel_info_list->size(); ++index) {
  50. std::ostringstream buffer;
  51. auto &kernel_info = kernel_info_list->at(index);
  52. MS_EXCEPTION_IF_NULL(kernel_info);
  53. if (kernel_info->GetOutputNum() != output_tensor_num) {
  54. buffer << "Kernel node's output size [" << output_tensor_num << "]"
  55. << " cannot match the kernel's output size [" << kernel_info->GetOutputNum() << "]";
  56. } else {
  57. buffer << "Kernel node's output size [" << input_tensor_num << "]"
  58. << " cannot match the kernel's output size [" << kernel_info->GetInputNum() << "]";
  59. }
  60. MS_LOG(INFO) << "kernel [ " << index << " ] :" << kernel_info->ToString() << buffer.str();
  61. }
  62. kernel_info_list->clear();
  63. MS_LOG(INFO) << "node" << kernel_node->DebugString() << "'s output size : [" << output_tensor_num << "]"
  64. << "input size : [" << input_tensor_num << "] cannot match any kernelInfo !";
  65. }
  66. }
  67. } // namespace
  68. void KernelQueryAll(const CNodePtr &kernel_node,
  69. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
  70. MS_EXCEPTION_IF_NULL(kernel_node);
  71. MS_EXCEPTION_IF_NULL(kernel_info_list);
  72. std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
  73. TbeMetadataInfo(kernel_node, kernel_info_list);
  74. if (kernel_info_list->empty()) {
  75. AicpuMetadataInfo(kernel_node, kernel_info_list);
  76. if (!kernel_info_list->empty()) {
  77. MS_LOG(INFO) << "The node [" << kernel_node->DebugString()
  78. << "] cannot find valid TBE kernel info, try to get aicpu kernel info";
  79. AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node);
  80. }
  81. }
  82. if (kernel_info_list->empty()) {
  83. GetRtKelInfo(kernel_node, kernel_info_list);
  84. }
  85. if (kernel_info_list->empty()) {
  86. HcclMetadataInfo(kernel_node, kernel_info_list);
  87. }
  88. if (kernel_info_list->empty()) {
  89. HostMetadataInfo(kernel_node, kernel_info_list);
  90. }
  91. if (kernel_info_list->empty()) {
  92. MS_EXCEPTION(NotExistsError) << "Can not find any available operator info for op [" << op_name << ", "
  93. << kernel_node->fullname_with_scope()
  94. << "]. Node DebugString:" << kernel_node->DebugString()
  95. << ", maybe the operator can not supported on current platform. \n trace "
  96. << trace::DumpSourceLines(kernel_node);
  97. }
  98. }
  99. void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list,
  100. KernelType kernel_type) {
  101. MS_EXCEPTION_IF_NULL(kernel_node);
  102. MS_EXCEPTION_IF_NULL(kernel_info_list);
  103. auto context_ptr = MsContext::GetInstance();
  104. MS_EXCEPTION_IF_NULL(context_ptr);
  105. const PrimitivePtr kPrimProdForceSeA = std::make_shared<Primitive>("ProdForceSeA");
  106. if (IsPrimitiveCNode(kernel_node, kPrimProdForceSeA)) {
  107. kernel_type = KernelType::AKG_KERNEL;
  108. }
  109. switch (kernel_type) {
  110. case KernelType::AKG_KERNEL:
  111. AkgMetadataInfo(kernel_node, kernel_info_list);
  112. break;
  113. default:
  114. KernelQueryAll(kernel_node, kernel_info_list);
  115. break;
  116. }
  117. if (kernel_info_list->empty()) {
  118. MS_EXCEPTION(NotExistsError) << "Can not find any available operator info for op ["
  119. << AnfAlgo::GetCNodeName(kernel_node) << ", " << kernel_node->fullname_with_scope()
  120. << "]. Node DebugString:" << kernel_node->DebugString()
  121. << ", maybe the operator can not supported on current platform. \n trace "
  122. << trace::DumpSourceLines(kernel_node);
  123. }
  124. // check output
  125. FilterInvalidKernelInfo(kernel_node, kernel_info_list);
  126. }
  127. void AICPUQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
  128. MS_EXCEPTION_IF_NULL(kernel_node);
  129. MS_EXCEPTION_IF_NULL(kernel_info_list);
  130. kernel_info_list->clear();
  131. AicpuMetadataInfo(kernel_node, kernel_info_list);
  132. FilterInvalidKernelInfo(kernel_node, kernel_info_list);
  133. }
  134. bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
  135. MS_EXCEPTION_IF_NULL(kernel_node);
  136. MS_EXCEPTION_IF_NULL(select_kernel_build_info);
  137. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
  138. auto cnode = kernel_node->cast<CNodePtr>();
  139. MS_EXCEPTION_IF_NULL(cnode);
  140. AICPUQuery(cnode, &kernel_info_list);
  141. return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
  142. [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
  143. MS_EXCEPTION_IF_NULL(item);
  144. return item->IsSimilarityKernelBuildInfo(*select_kernel_build_info);
  145. });
  146. }
  147. bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
  148. MS_EXCEPTION_IF_NULL(kernel_node);
  149. MS_EXCEPTION_IF_NULL(select_kernel_build_info);
  150. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
  151. auto cnode = kernel_node->cast<CNodePtr>();
  152. MS_EXCEPTION_IF_NULL(cnode);
  153. TbeMetadataInfo(cnode, &kernel_info_list);
  154. return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
  155. [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
  156. MS_EXCEPTION_IF_NULL(item);
  157. return *item == *select_kernel_build_info;
  158. });
  159. }
  160. } // namespace kernel
  161. } // namespace mindspore