You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

ascend_helper.h 3.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_HELPER_H_
  17. #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_HELPER_H_
  18. #include <memory>
  19. #include <string>
  20. #include <vector>
  21. #include "device/ascend/kernel_select_ascend.h"
  22. #include "kernel/kernel_query.h"
  23. #include "kernel/tbe/tbe_kernel_select.h"
  24. namespace mindspore {
  25. namespace opt {
  26. class KernelSelect {
  27. public:
  28. KernelSelect() = default;
  29. virtual ~KernelSelect() = default;
  30. virtual void SelectKernel(const CNodePtr &cnode) { device::ascend::SelectKernelInfo(cnode); }
  31. virtual bool CheckKernelAccuracySupported(const CNodePtr &kernel_node,
  32. const kernel::KernelBuildInfoPtr &new_kernel_build_info) {
  33. return device::ascend::CheckKernelAccuracySupported(kernel_node, new_kernel_build_info);
  34. }
  35. };
  36. using KernelSelectPtr = std::shared_ptr<KernelSelect>;
  37. class SupportedChecker {
  38. public:
  39. SupportedChecker() = default;
  40. virtual ~SupportedChecker() = default;
  41. virtual bool CheckSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
  42. return kernel::CheckSupported(anf_node, select_kernel_build_info);
  43. }
  44. };
  45. using SupportedCheckerPtr = std::shared_ptr<SupportedChecker>;
  46. class KernelQuery {
  47. public:
  48. KernelQuery() = default;
  49. virtual ~KernelQuery() = default;
  50. virtual void Query(const CNodePtr &kernel_node,
  51. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
  52. kernel::KernelQuery(kernel_node, kernel_info_list);
  53. }
  54. };
  55. using KernelQueryPtr = std::shared_ptr<KernelQuery>;
  56. AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
  57. const KernelSelectPtr &kernel_select, size_t insert_index,
  58. const std::string &origin_format, const std::string &dest_format,
  59. const std::string &op_name, bool is_insert_input);
  60. AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
  61. const TypeId &input_type, const TypeId &output_type,
  62. const std::vector<size_t> &origin_shape, const TypeId &origin_type);
  63. AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
  64. const KernelSelectPtr &kernel_select);
  65. AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
  66. const KernelSelectPtr &kernel_select);
  67. CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
  68. AnfNodePtr CreateMemcpyAsyncOp(const FuncGraphPtr &graph, const AnfNodePtr &node);
  69. } // namespace opt
  70. } // namespace mindspore
  71. #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_HELPER_H_