You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

kernel_build_info.h 6.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_BUILD_INFO_H_
  17. #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_BUILD_INFO_H_
  18. #include <iostream>
  19. #include <vector>
  20. #include <memory>
  21. #include <string>
  22. #include <utility>
  23. #include "ir/dtype.h"
  24. #include "ir/kernel_info_dev.h"
  25. #include "kernel/kernel.h"
  26. namespace mindspore {
  27. namespace kernel {
  28. class KernelBuildInfo {
  29. public:
  30. class KernelBuildInfoBuilder;
  31. KernelBuildInfo() {
  32. kernel_type_ = TBE_KERNEL;
  33. fusion_type_ = OPAQUE;
  34. processor_ = AICORE;
  35. op_pattern_ = kCommonPattern;
  36. input_reshape_type_ = {};
  37. output_reshape_type_ = {};
  38. origin_data_format_ = kOpFormat_DEFAULT;
  39. inputs_format_ = {};
  40. outputs_format_ = {};
  41. inputs_device_type_ = {};
  42. outputs_device_type_ = {};
  43. output_data_desc_ = {};
  44. }
  45. ~KernelBuildInfo() = default;
  46. KernelType kernel_type() const { return kernel_type_; }
  47. std::string GetInputFormat(size_t input_index) const;
  48. std::string GetOutputFormat(size_t output_index) const;
  49. TypeId GetInputDeviceType(size_t input_index) const;
  50. TypeId GetOutputDeviceType(size_t output_index) const;
  51. std::string GetInputReshapeType(size_t input_index) const;
  52. std::string GetInputValueDepend(size_t input_index) const;
  53. bool IsInputDefaultPadding() const;
  54. bool IsOutputDefaultPadding() const;
  55. std::string GetOutputReshapeType(size_t input_index) const;
  56. const std::string &GetOriginDataFormat() const;
  57. const std::vector<std::string> &GetAllInputFormats() const;
  58. const std::vector<std::string> &GetAllOutputFormats() const;
  59. const std::vector<TypeId> &GetAllInputDeviceTypes() const;
  60. const std::vector<TypeId> &GetAllOutputDeviceTypes() const;
  61. std::vector<std::string> GetAllOutputReshapeType() const;
  62. std::vector<std::string> GetAllInputReshapeType() const;
  63. OpPattern op_pattern() const { return op_pattern_; }
  64. std::vector<nlohmann::json> output_data_desc() const { return output_data_desc_; }
  65. FusionType fusion_type() const { return fusion_type_; }
  66. Processor processor() const { return processor_; }
  67. size_t GetInputNum() const;
  68. size_t GetOutputNum() const;
  69. size_t GetOutputNumWithoutMonad() const;
  70. std::string ToString() const;
  71. bool IsSimilarityKernelBuildInfo(const KernelBuildInfo &other) const;
  72. bool operator==(const KernelBuildInfo &other) const;
  73. bool operator!=(const KernelBuildInfo &other) const;
  74. static auto constexpr kInvalidFormat = "InvalidFormat";
  75. private:
  76. KernelType kernel_type_;
  77. std::string origin_data_format_;
  78. std::vector<std::string> inputs_format_;
  79. OpPattern op_pattern_;
  80. std::vector<std::string> outputs_format_;
  81. std::vector<std::string> input_reshape_type_;
  82. std::vector<std::string> output_reshape_type_;
  83. std::vector<TypeId> inputs_device_type_;
  84. std::vector<TypeId> outputs_device_type_;
  85. std::vector<nlohmann::json> output_data_desc_;
  86. std::vector<std::string> input_value_depend_;
  87. FusionType fusion_type_;
  88. Processor processor_;
  89. };
  90. using KernelBuildInfoPtr = std::shared_ptr<KernelBuildInfo>;
  91. class KernelBuildInfo::KernelBuildInfoBuilder {
  92. public:
  93. KernelBuildInfoBuilder() { kernel_build_info_ = std::make_shared<KernelBuildInfo>(); }
  94. explicit KernelBuildInfoBuilder(const KernelBuildInfoPtr &kernel_build_info)
  95. : kernel_build_info_(std::make_shared<KernelBuildInfo>()) {
  96. SetKernelType(kernel_build_info->kernel_type());
  97. SetFusionType(kernel_build_info->fusion_type());
  98. SetProcessor(kernel_build_info->processor());
  99. SetOpPattern(kernel_build_info->op_pattern());
  100. for (size_t index = 0; index < kernel_build_info->GetInputNum(); ++index) {
  101. kernel_build_info_->inputs_device_type_.emplace_back(kernel_build_info->GetInputDeviceType(index));
  102. kernel_build_info_->inputs_format_.emplace_back(kernel_build_info->GetInputFormat(index));
  103. kernel_build_info_->input_reshape_type_.emplace_back(kernel_build_info->GetInputReshapeType(index));
  104. kernel_build_info_->input_value_depend_.emplace_back(kernel_build_info->GetInputValueDepend(index));
  105. }
  106. for (size_t index = 0; index < kernel_build_info->GetOutputNum(); ++index) {
  107. kernel_build_info_->outputs_device_type_.emplace_back(kernel_build_info->GetOutputDeviceType(index));
  108. kernel_build_info_->outputs_format_.emplace_back(kernel_build_info->GetOutputFormat(index));
  109. kernel_build_info_->output_reshape_type_.emplace_back(kernel_build_info->GetOutputReshapeType(index));
  110. }
  111. }
  112. ~KernelBuildInfoBuilder() = default;
  113. void SetKernelType(const KernelType &kernel_type);
  114. void SetOriginDataFormat(const std::string &origin_data_format);
  115. void SetInputsFormat(const std::vector<std::string> &inputs_format);
  116. void SetOutputsFormat(const std::vector<std::string> &outputs_format);
  117. void SetInputsDeviceType(const std::vector<TypeId> &inputs_device_type);
  118. void SetOutputsDeviceType(const std::vector<TypeId> &outputs_device_type);
  119. void SetInputsReshapeType(const std::vector<std::string> &input_reshape_type);
  120. void SetInputsValueDepend(const std::vector<std::string> &input_value_depend);
  121. void SetOutputsReshapeType(const std::vector<std::string> &output_reshape_type);
  122. void SetFusionType(FusionType fusion_type);
  123. // save prebuild result
  124. void SetOutputDataDesc(const std::vector<nlohmann::json> &data_desc);
  125. void SetProcessor(Processor processor);
  126. void SetOpPattern(OpPattern pattern);
  127. void SetInputFormat(const std::string &format, size_t index);
  128. void SetOutputFormat(const std::string &format, size_t index);
  129. void SetInputReshapeType(const std::string &input_reshape_type, size_t index);
  130. void SetOutputReshapeType(const std::string &output_reshape_type, size_t index);
  131. void SetInputDeviceType(const TypeId &input_device_type, size_t index);
  132. void SetOutputDeviceType(const TypeId &output_device_type, size_t index);
  133. std::shared_ptr<KernelBuildInfo> Build();
  134. private:
  135. std::shared_ptr<KernelBuildInfo> kernel_build_info_;
  136. };
  137. } // namespace kernel
  138. } // namespace mindspore
  139. #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_BUILD_INFO_H_