You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

rt_kernel_info.cc 3.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "kernel/rts/rt_kernel_info.h"
  #include <algorithm>
  #include <cctype>
  #include <memory>
  #include <set>
  #include <string>
  #include <unordered_map>
  #include <utility>
  #include <vector>
  #include "utils/convert_utils.h"
  #include "utils/utils.h"
  #include "common/utils.h"
  #include "session/anf_runtime_algorithm.h"
  23. namespace mindspore {
  24. namespace kernel {
  25. void RtKerDescFactory::Register(const std::string &name, RtKerDescCreater &&fun) {
  26. if (fmap_.find(name) == fmap_.end()) {
  27. (void)fmap_.emplace(name, std::move(fun));
  28. }
  29. }
  30. std::shared_ptr<RtKerDesc> RtKerDescFactory::Create(const std::string &name) {
  31. const auto &map = Get().fmap_;
  32. auto it = map.find(name);
  33. if (it != map.end() && it->second) {
  34. return (it->second)();
  35. }
  36. return nullptr;
  37. }
  38. RtKerDescFactory &RtKerDescFactory::Get() {
  39. static RtKerDescFactory _this;
  40. return _this;
  41. }
  42. static bool IsDefaultKernelInfo(const std::string &name) {
  43. static const std::set<std::string> white_list = {kStreamSwitchOpName, kStreamActiveOpName, kLabelSetOpName,
  44. kLabelGotoOpName};
  45. return white_list.find(name) != white_list.end();
  46. }
  47. void GetRtKelInfo(const CNodePtr &kernel_node,
  48. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
  49. MS_EXCEPTION_IF_NULL(kernel_info_list);
  50. MS_EXCEPTION_IF_NULL(kernel_node);
  51. std::string opNameLower = AnfAlgo::GetCNodeName(kernel_node);
  52. (void)std::transform(opNameLower.begin(), opNameLower.end(), opNameLower.begin(), ::tolower);
  53. auto ker_desc_ptr = RtKerDescFactory::Create(opNameLower);
  54. if (ker_desc_ptr != nullptr && !ker_desc_ptr->GetKernelInfo().empty()) {
  55. *kernel_info_list = ker_desc_ptr->GetKernelInfo();
  56. return;
  57. }
  58. // if can't find kernel info in kernel info database, use the default kernel info
  59. auto node_name = AnfAlgo::GetCNodeName(kernel_node);
  60. if (IsDefaultKernelInfo(node_name)) {
  61. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  62. // set input infos
  63. auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  64. kernel_build_info_builder->SetInputsFormat(std::vector<std::string>(input_num, kOpFormat_DEFAULT));
  65. std::vector<TypeId> input_types = {};
  66. for (size_t i = 0; i < input_num; i++) {
  67. input_types.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, i));
  68. }
  69. kernel_build_info_builder->SetInputsDeviceType(input_types);
  70. // set output info
  71. auto output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  72. kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>(output_num, kOpFormat_DEFAULT));
  73. kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>(output_num, TypeId::kTypeUnknown));
  74. // set ohter info
  75. kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE);
  76. kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE);
  77. kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);
  78. kernel_info_list->push_back(kernel_build_info_builder->Build());
  79. return;
  80. }
  81. MS_LOG(DEBUG) << "Rt dose not have op [" << opNameLower << "].";
  82. }
  83. } // namespace kernel
  84. } // namespace mindspore