You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

ascend_helper.h 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ASCEND_HELPER_H_
  17. #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ASCEND_HELPER_H_
  18. #include <memory>
  19. #include <string>
  20. #include <vector>
  21. #include "runtime/device/ascend/kernel_select_ascend.h"
  22. #include "backend/kernel_compiler/kernel_query.h"
  23. #include "backend/kernel_compiler/oplib/oplib.h"
  24. #include "backend/session/anf_runtime_algorithm.h"
  25. #include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
  26. namespace mindspore {
  27. namespace opt {
  28. class KernelSelect {
  29. public:
  30. KernelSelect() = default;
  31. virtual ~KernelSelect() = default;
  32. virtual void SelectKernel(const CNodePtr &cnode) { device::ascend::SelectKernelInfo(cnode); }
  33. };
  34. using KernelSelectPtr = std::shared_ptr<KernelSelect>;
  35. class SupportedChecker {
  36. public:
  37. SupportedChecker() = default;
  38. virtual ~SupportedChecker() = default;
  39. virtual bool CheckAICoreSupported(const AnfNodePtr &anf_node,
  40. const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
  41. return kernel::IsSupportedByAICore(anf_node, select_kernel_build_info);
  42. }
  43. virtual bool CheckAICPUSupported(const AnfNodePtr &anf_node,
  44. const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
  45. return kernel::IsSupportedByAICPU(anf_node, select_kernel_build_info);
  46. }
  47. };
  48. using SupportedCheckerPtr = std::shared_ptr<SupportedChecker>;
  49. class KernelQuery {
  50. public:
  51. KernelQuery() = default;
  52. virtual ~KernelQuery() = default;
  53. virtual void Query(const CNodePtr &kernel_node,
  54. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
  55. kernel::KernelQuery(kernel_node, kernel_info_list);
  56. }
  57. virtual bool IsTbeRef(const AnfNodePtr &node) {
  58. MS_EXCEPTION_IF_NULL(node);
  59. if (!node->isa<CNode>()) {
  60. return false;
  61. }
  62. auto op_info = mindspore::kernel::tbe::TbeDynamicShapeUtil::FindOp(AnfAlgo::GetCNodeName(node), node);
  63. if (op_info != nullptr) {
  64. return op_info->is_ref();
  65. }
  66. return false;
  67. }
  68. };
  69. using KernelQueryPtr = std::shared_ptr<KernelQuery>;
  70. class OpFinder {
  71. public:
  72. OpFinder() = default;
  73. virtual ~OpFinder() = default;
  74. virtual int GetOpRegisteredOutputNum(const std::string &op_name, const CNodePtr &cnode) {
  75. auto op_info = kernel::tbe::TbeDynamicShapeUtil::FindOp(op_name, cnode);
  76. if (op_info == nullptr) {
  77. return -1;
  78. }
  79. return op_info->outputs_ptr().size();
  80. }
  81. };
  82. using OpFinderPtr = std::shared_ptr<OpFinder>;
  83. void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
  84. const AnfNodePtr &trans_data, const std::vector<Axis> &reshape_type = {},
  85. const TypeId &type_id = kTypeUnknown);
  86. CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
  87. const bool need_padding, const std::string &op_name);
  88. AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
  89. const TypeId &input_type, const TypeId &output_type,
  90. const std::vector<size_t> &origin_shape, const TypeId &origin_type);
  91. AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
  92. const KernelSelectPtr &kernel_select);
  93. AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
  94. const KernelSelectPtr &kernel_select);
  95. CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
  96. AnfNodePtr CreateMemcpyAsyncOp(const FuncGraphPtr &graph, const AnfNodePtr &node);
  97. } // namespace opt
  98. } // namespace mindspore
  99. #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ASCEND_HELPER_H_