You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

kernel_query.cc 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "kernel/kernel_query.h"
  17. #include <memory>
  18. #include <algorithm>
  19. #include "kernel/aicpu/aicpu_kernel_metadata.h"
  20. #include "kernel/rts/rt_kernel_info.h"
  21. #include "kernel/hccl/hccl_kernel_metadata.h"
  22. #include "kernel/tbe/tbe_kernel_select.h"
  23. #include "session/anf_runtime_algorithm.h"
  24. namespace mindspore {
  25. namespace kernel {
  26. namespace {
  27. void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
  28. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
  29. MS_EXCEPTION_IF_NULL(kernel_info_list);
  30. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_list;
  31. (void)std::copy_if(kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list),
  32. [&](const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info) {
  33. return AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() &&
  34. AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum();
  35. });
  36. if (!filtered_list.empty()) {
  37. kernel_info_list->clear();
  38. (void)std::copy(filtered_list.begin(), filtered_list.end(), std::back_inserter(*kernel_info_list));
  39. } else {
  40. MS_LOG(WARNING) << "All kernel Info list does not match any kernel info ";
  41. for (size_t index = 0; index < kernel_info_list->size(); ++index) {
  42. MS_EXCEPTION_IF_NULL(kernel_info_list->at(index));
  43. MS_LOG(WARNING) << "kernel [ " << index << " ] :" << kernel_info_list->at(index)->ToString();
  44. }
  45. MS_LOG(WARNING) << "node" << kernel_node->DebugString() << "'s output size : ["
  46. << AnfAlgo::GetOutputTensorNum(kernel_node) << "]"
  47. << "input size : [" << AnfAlgo::GetInputTensorNum(kernel_node) << "] cannot match any kernelInfo !";
  48. }
  49. }
  50. } // namespace
  51. void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
  52. MS_EXCEPTION_IF_NULL(kernel_node);
  53. MS_EXCEPTION_IF_NULL(kernel_info_list);
  54. TbeMetadataInfo(kernel_node, kernel_info_list);
  55. if (kernel_info_list->empty()) {
  56. AicpuMetadataInfo(kernel_node, kernel_info_list);
  57. }
  58. if (kernel_info_list->empty()) {
  59. GetRtKelInfo(kernel_node, kernel_info_list);
  60. }
  61. if (kernel_info_list->empty()) {
  62. HcclMetadataInfo(kernel_node, kernel_info_list);
  63. }
  64. if (kernel_info_list->empty()) {
  65. MS_LOG(EXCEPTION) << "Op " << kernel_node->DebugString() << "kernel query fail!";
  66. }
  67. FilterInvalidKernelInfo(kernel_node, kernel_info_list);
  68. }
  69. void AICPUQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
  70. MS_EXCEPTION_IF_NULL(kernel_node);
  71. MS_EXCEPTION_IF_NULL(kernel_info_list);
  72. kernel_info_list->clear();
  73. AicpuMetadataInfo(kernel_node, kernel_info_list);
  74. FilterInvalidKernelInfo(kernel_node, kernel_info_list);
  75. }
  76. bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
  77. MS_EXCEPTION_IF_NULL(kernel_node);
  78. MS_EXCEPTION_IF_NULL(select_kernel_build_info);
  79. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
  80. auto cnode = kernel_node->cast<CNodePtr>();
  81. MS_EXCEPTION_IF_NULL(cnode);
  82. AICPUQuery(cnode, &kernel_info_list);
  83. return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
  84. [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
  85. MS_EXCEPTION_IF_NULL(item);
  86. return *item == *select_kernel_build_info;
  87. });
  88. }
  89. bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
  90. MS_EXCEPTION_IF_NULL(kernel_node);
  91. MS_EXCEPTION_IF_NULL(select_kernel_build_info);
  92. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
  93. auto cnode = kernel_node->cast<CNodePtr>();
  94. MS_EXCEPTION_IF_NULL(cnode);
  95. TbeMetadataInfo(cnode, &kernel_info_list);
  96. FilterInvalidKernelInfo(cnode, &kernel_info_list);
  97. return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
  98. [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
  99. MS_EXCEPTION_IF_NULL(item);
  100. return *item == *select_kernel_build_info;
  101. });
  102. }
  103. } // namespace kernel
  104. } // namespace mindspore