You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

rt_kernel_info.cc 3.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "kernel/mng/rt_kernel_info.h"
  17. #include <unordered_map>
  18. #include <algorithm>
  19. #include "utils/convert_utils.h"
  20. #include "utils/utils.h"
  21. #include "common/utils.h"
  22. #include "session/anf_runtime_algorithm.h"
  23. namespace mindspore {
  24. namespace kernel {
  25. void RtKerDescFactory::Register(const std::string &name, RtKerDescCreater &&fun) {
  26. if (fmap_.find(name) == fmap_.end()) {
  27. (void)fmap_.emplace(name, std::move(fun));
  28. }
  29. }
  30. std::shared_ptr<RtKerDesc> RtKerDescFactory::Create(const std::string &name) {
  31. const auto &map = Get().fmap_;
  32. auto it = map.find(name);
  33. if (it != map.end() && it->second) {
  34. return (it->second)();
  35. }
  36. return nullptr;
  37. }
  38. RtKerDescFactory &RtKerDescFactory::Get() {
  39. static RtKerDescFactory _this;
  40. return _this;
  41. }
  42. void GetRtKelInfo(const CNodePtr &kernel_node,
  43. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
  44. MS_EXCEPTION_IF_NULL(kernel_info_list);
  45. MS_EXCEPTION_IF_NULL(kernel_node);
  46. std::string opNameLower = AnfAlgo::GetCNodeName(kernel_node);
  47. (void)std::transform(opNameLower.begin(), opNameLower.end(), opNameLower.begin(), ::tolower);
  48. auto ker_desc_ptr = RtKerDescFactory::Create(opNameLower);
  49. if (ker_desc_ptr != nullptr && !ker_desc_ptr->GetKernelInfo().empty()) {
  50. *kernel_info_list = ker_desc_ptr->GetKernelInfo();
  51. return;
  52. }
  53. // if can't find kernel info in kernel info database, use the default kernel info
  54. auto node_name = AnfAlgo::GetCNodeName(kernel_node);
  55. if (node_name == "StreamSwitch" || node_name == "StreamActive") {
  56. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  57. // set input infos
  58. auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  59. kernel_build_info_builder->SetInputsFormat(std::vector<std::string>(input_num, kOpFormat_DEFAULT));
  60. std::vector<TypeId> input_types = {};
  61. for (size_t i = 0; i < input_num; i++) {
  62. input_types.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, i));
  63. }
  64. kernel_build_info_builder->SetInputsDeviceType(input_types);
  65. // set output info
  66. auto output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  67. kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>(output_num, kOpFormat_DEFAULT));
  68. kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>(output_num, TypeId::kTypeUnknown));
  69. // set ohter info
  70. kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE);
  71. kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE);
  72. kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);
  73. kernel_info_list->push_back(kernel_build_info_builder->Build());
  74. return;
  75. }
  76. MS_LOG(DEBUG) << "Rt dose not have op [" << opNameLower << "].";
  77. }
  78. } // namespace kernel
  79. } // namespace mindspore