You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

kernel_query.cc 2.9 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "kernel/kernel_query.h"
  17. #include <memory>
  18. #include <algorithm>
  19. #include "kernel/aicpu/aicpu_kernel_metadata.h"
  20. #include "kernel/mng/rt_kernel_info.h"
  21. #include "kernel/hccl/hccl_kernel_metadata.h"
  22. #include "kernel/tbe/tbe_kernel_select.h"
  23. #include "session/anf_runtime_algorithm.h"
  24. namespace mindspore {
  25. namespace kernel {
  26. namespace {
  27. void FilterInvaildKernelInfo(const CNodePtr &kernel_node,
  28. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
  29. MS_EXCEPTION_IF_NULL(kernel_info_list);
  30. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_list;
  31. (void)std::copy_if(kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list),
  32. [&](const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info) {
  33. return AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() &&
  34. AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum();
  35. });
  36. kernel_info_list->clear();
  37. if (!filtered_list.empty()) {
  38. (void)std::copy(filtered_list.begin(), filtered_list.end(), std::back_inserter(*kernel_info_list));
  39. } else {
  40. MS_LOG(EXCEPTION) << "node" << kernel_node->DebugString() << "'s output size : ["
  41. << AnfAlgo::GetOutputTensorNum(kernel_node) << "]"
  42. << "input size : [" << AnfAlgo::GetInputTensorNum(kernel_node)
  43. << "] cannot match any kernelInfo !";
  44. }
  45. }
  46. } // namespace
  47. void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
  48. MS_EXCEPTION_IF_NULL(kernel_node);
  49. MS_EXCEPTION_IF_NULL(kernel_info_list);
  50. TbeMetadataInfo(kernel_node, kernel_info_list);
  51. if (kernel_info_list->empty()) {
  52. AicpuMetadataInfo(kernel_node, kernel_info_list);
  53. }
  54. if (kernel_info_list->empty()) {
  55. GetRtKelInfo(kernel_node, kernel_info_list);
  56. }
  57. if (kernel_info_list->empty()) {
  58. HcclMetadataInfo(kernel_node, kernel_info_list);
  59. }
  60. if (kernel_info_list->empty()) {
  61. MS_LOG(EXCEPTION) << "op" << kernel_node->DebugString() << "kernel query fail!";
  62. }
  63. FilterInvaildKernelInfo(kernel_node, kernel_info_list);
  64. }
  65. } // namespace kernel
  66. } // namespace mindspore