You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

kernel_query.cc 7.8 kB

4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "kernel/kernel_query.h"
  17. #include <algorithm>
  18. #include "plugin/device/ascend/kernel/aicpu/aicpu_kernel_metadata.h"
  19. #include "plugin/device/ascend/kernel/host/host_kernel_metadata.h"
  20. #include "plugin/device/ascend/kernel/rts/rt_kernel_info.h"
  21. #include "plugin/device/ascend/kernel/hccl/hccl_kernel_metadata.h"
  22. #include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_kernel_select.h"
  23. #include "kernel/akg/akg_kernel_metadata.h"
  24. #include "backend/common/session/anf_runtime_algorithm.h"
  25. #include "include/common/utils/anfalgo.h"
  26. #include "utils/ms_context.h"
  27. #include "utils/trace_base.h"
  28. namespace mindspore {
  29. namespace kernel {
  30. namespace {
  31. void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
  32. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
  33. MS_EXCEPTION_IF_NULL(kernel_info_list);
  34. if (kernel_info_list->empty()) {
  35. return;
  36. }
  37. MS_EXCEPTION_IF_NULL(kernel_node);
  38. size_t output_tensor_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
  39. size_t input_tensor_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
  40. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_list;
  41. (void)std::copy_if(
  42. kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list),
  43. [output_tensor_num, input_tensor_num](const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info) {
  44. return kernel_build_info->GetOutputNum() == output_tensor_num &&
  45. kernel_build_info->GetInputNum() == input_tensor_num;
  46. });
  47. if (!filtered_list.empty()) {
  48. kernel_info_list->clear();
  49. (void)std::copy(filtered_list.begin(), filtered_list.end(), std::back_inserter(*kernel_info_list));
  50. } else {
  51. for (size_t index = 0; index < kernel_info_list->size(); ++index) {
  52. std::ostringstream buffer;
  53. auto &kernel_info = kernel_info_list->at(index);
  54. MS_EXCEPTION_IF_NULL(kernel_info);
  55. if (kernel_info->GetOutputNum() != output_tensor_num) {
  56. buffer << "Kernel node's output size [" << output_tensor_num << "]"
  57. << " cannot match the kernel's output size [" << kernel_info->GetOutputNum() << "]";
  58. } else {
  59. buffer << "Kernel node's input size [" << input_tensor_num << "]"
  60. << " cannot match the kernel's input size [" << kernel_info->GetInputNum() << "]";
  61. }
  62. MS_LOG(INFO) << "Kernel [ " << index << " ] :" << kernel_info->ToString() << buffer.str();
  63. }
  64. kernel_info_list->clear();
  65. MS_LOG(INFO) << "Node: " << kernel_node->DebugString() << "'s output size : [" << output_tensor_num << "]"
  66. << "input size : [" << input_tensor_num << "] can not match any kernelInfo !";
  67. }
  68. }
  69. bool SelectAicpuReshapeInTaskSink(const CNodePtr &kernel_node) {
  70. MS_EXCEPTION_IF_NULL(kernel_node);
  71. if (common::AnfAlgo::GetCNodeName(kernel_node) != "Reshape") {
  72. return false;
  73. }
  74. const size_t AicpuReshapeSize = 2;
  75. if (kernel_node->size() != AicpuReshapeSize) {
  76. return false;
  77. }
  78. auto context_ptr = MsContext::GetInstance();
  79. MS_EXCEPTION_IF_NULL(context_ptr);
  80. auto is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
  81. return is_task_sink;
  82. }
  83. } // namespace
  84. void CheckKernelInfoListEmpty(const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list,
  85. const std::string &type) {
  86. MS_EXCEPTION_IF_NULL(kernel_info_list);
  87. if (kernel_info_list->empty()) {
  88. MS_LOG(INFO) << "Warning: kernel info list is empty, kernel type: " << type;
  89. }
  90. }
  91. void KernelQueryAll(const CNodePtr &kernel_node,
  92. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
  93. MS_EXCEPTION_IF_NULL(kernel_node);
  94. MS_EXCEPTION_IF_NULL(kernel_info_list);
  95. TbeMetadataInfo(kernel_node, kernel_info_list);
  96. if (kernel_info_list->empty()) {
  97. GetRtKelInfo(kernel_node, kernel_info_list);
  98. CheckKernelInfoListEmpty(kernel_info_list, "RT_Kernel");
  99. }
  100. if (kernel_info_list->empty()) {
  101. HcclMetadataInfo(kernel_node, kernel_info_list);
  102. CheckKernelInfoListEmpty(kernel_info_list, "HCCL_Kernel");
  103. }
  104. if (SelectAicpuReshapeInTaskSink(kernel_node)) {
  105. return;
  106. }
  107. if (kernel_info_list->empty()) {
  108. HostMetadataInfo(kernel_node, kernel_info_list);
  109. CheckKernelInfoListEmpty(kernel_info_list, "HOST_Kernel");
  110. }
  111. }
  112. void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list,
  113. KernelType kernel_type) {
  114. MS_EXCEPTION_IF_NULL(kernel_node);
  115. MS_EXCEPTION_IF_NULL(kernel_info_list);
  116. auto context_ptr = MsContext::GetInstance();
  117. MS_EXCEPTION_IF_NULL(context_ptr);
  118. const PrimitivePtr kPrimProdForceSeA = std::make_shared<Primitive>("ProdForceSeA");
  119. if (IsPrimitiveCNode(kernel_node, kPrimProdForceSeA)) {
  120. kernel_type = KernelType::AKG_KERNEL;
  121. }
  122. const PrimitivePtr kPrimLoadIm2Col = std::make_shared<Primitive>("LoadIm2Col");
  123. if (IsPrimitiveCNode(kernel_node, kPrimLoadIm2Col)) {
  124. kernel_type = KernelType::AKG_KERNEL;
  125. } // use LoadIm2Col only for THOR optimizer
  126. switch (kernel_type) {
  127. case KernelType::AKG_KERNEL:
  128. AkgMetadataInfo(kernel_node, kernel_info_list);
  129. break;
  130. default:
  131. KernelQueryAll(kernel_node, kernel_info_list);
  132. break;
  133. }
  134. // check output
  135. FilterInvalidKernelInfo(kernel_node, kernel_info_list);
  136. }
  137. void AICPUQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
  138. MS_EXCEPTION_IF_NULL(kernel_node);
  139. MS_EXCEPTION_IF_NULL(kernel_info_list);
  140. kernel_info_list->clear();
  141. AicpuMetadataInfo(kernel_node, kernel_info_list);
  142. FilterInvalidKernelInfo(kernel_node, kernel_info_list);
  143. }
  144. bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
  145. MS_EXCEPTION_IF_NULL(kernel_node);
  146. MS_EXCEPTION_IF_NULL(select_kernel_build_info);
  147. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
  148. auto cnode = kernel_node->cast<CNodePtr>();
  149. MS_EXCEPTION_IF_NULL(cnode);
  150. AICPUQuery(cnode, &kernel_info_list);
  151. return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
  152. [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
  153. MS_EXCEPTION_IF_NULL(item);
  154. return item->IsSimilarityKernelBuildInfo(*select_kernel_build_info);
  155. });
  156. }
  157. bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
  158. MS_EXCEPTION_IF_NULL(kernel_node);
  159. MS_EXCEPTION_IF_NULL(select_kernel_build_info);
  160. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
  161. auto cnode = kernel_node->cast<CNodePtr>();
  162. MS_EXCEPTION_IF_NULL(cnode);
  163. TbeMetadataInfo(cnode, &kernel_info_list);
  164. return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
  165. [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
  166. MS_EXCEPTION_IF_NULL(item);
  167. return *item == *select_kernel_build_info;
  168. });
  169. }
  170. } // namespace kernel
  171. } // namespace mindspore