You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

kernel_build_info.h 7.3 kB

4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_BUILD_INFO_H_
  17. #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_BUILD_INFO_H_
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "ir/dtype.h"
#include "ir/kernel_info_dev.h"
#include "kernel/kernel.h"
  26. namespace mindspore {
  27. namespace kernel {
  28. class KernelBuildInfo {
  29. public:
  30. class KernelBuildInfoBuilder;
  31. KernelBuildInfo() {
  32. kernel_type_ = TBE_KERNEL;
  33. fusion_type_ = OPAQUE;
  34. processor_ = AICORE;
  35. op_pattern_ = kCommonPattern;
  36. core_type_ = "";
  37. input_reshape_type_ = {};
  38. output_reshape_type_ = {};
  39. origin_data_format_ = kOpFormat_DEFAULT;
  40. inputs_format_ = {};
  41. outputs_format_ = {};
  42. inputs_device_type_ = {};
  43. outputs_device_type_ = {};
  44. output_data_desc_ = {};
  45. }
  46. ~KernelBuildInfo() = default;
  47. KernelType kernel_type() const { return kernel_type_; }
  48. std::string GetInputFormat(size_t input_index) const;
  49. std::string GetOutputFormat(size_t output_index) const;
  50. TypeId GetInputDeviceType(size_t input_index) const;
  51. TypeId GetOutputDeviceType(size_t output_index) const;
  52. std::string GetInputReshapeType(size_t input_index) const;
  53. std::string GetInputValueDepend(size_t input_index) const;
  54. bool IsInputDefaultPadding() const;
  55. bool IsOutputDefaultPadding() const;
  56. std::string GetOutputReshapeType(size_t input_index) const;
  57. const std::string &GetOriginDataFormat() const;
  58. const std::vector<std::string> &GetAllInputFormats() const;
  59. const std::vector<std::string> &GetAllOutputFormats() const;
  60. const std::vector<TypeId> &GetAllInputDeviceTypes() const;
  61. const std::vector<TypeId> &GetAllOutputDeviceTypes() const;
  62. std::vector<std::string> GetAllOutputReshapeType() const;
  63. std::vector<std::string> GetAllInputReshapeType() const;
  64. std::string core_type() const { return core_type_; }
  65. void SetOutputFormat(const std::string &format, size_t index);
  66. void SetOutputDeviceType(const TypeId &output_device_type, size_t index);
  67. void SetOutputsFormat(const std::vector<std::string> &outputs_format);
  68. void SetOutputsDeviceType(const std::vector<TypeId> &outputs_device_type);
  69. OpPattern op_pattern() const { return op_pattern_; }
  70. std::vector<nlohmann::json> output_data_desc() const { return output_data_desc_; }
  71. FusionType fusion_type() const { return fusion_type_; }
  72. Processor processor() const { return processor_; }
  73. size_t GetInputNum() const;
  74. size_t GetOutputNum() const;
  75. size_t GetOutputNumWithoutMonad() const;
  76. std::string ToString() const;
  77. bool IsSimilarityKernelBuildInfo(const KernelBuildInfo &other) const;
  78. bool operator==(const KernelBuildInfo &other) const;
  79. bool operator!=(const KernelBuildInfo &other) const;
  80. static auto constexpr kInvalidFormat = "InvalidFormat";
  81. private:
  82. KernelType kernel_type_;
  83. std::string origin_data_format_;
  84. std::string core_type_;
  85. std::vector<std::string> inputs_format_;
  86. OpPattern op_pattern_;
  87. std::vector<std::string> outputs_format_;
  88. std::vector<std::string> input_reshape_type_;
  89. std::vector<std::string> output_reshape_type_;
  90. std::vector<TypeId> inputs_device_type_;
  91. std::vector<TypeId> outputs_device_type_;
  92. std::vector<nlohmann::json> output_data_desc_;
  93. std::vector<std::string> input_value_depend_;
  94. FusionType fusion_type_;
  95. Processor processor_;
  96. };
  97. using KernelBuildInfoPtr = std::shared_ptr<KernelBuildInfo>;
  98. class KernelBuildInfo::KernelBuildInfoBuilder {
  99. public:
  100. KernelBuildInfoBuilder() { kernel_build_info_ = std::make_shared<KernelBuildInfo>(); }
  101. explicit KernelBuildInfoBuilder(const KernelBuildInfoPtr &kernel_build_info)
  102. : kernel_build_info_(std::make_shared<KernelBuildInfo>()) {
  103. SetKernelType(kernel_build_info->kernel_type());
  104. SetFusionType(kernel_build_info->fusion_type());
  105. SetProcessor(kernel_build_info->processor());
  106. SetOpPattern(kernel_build_info->op_pattern());
  107. SetCoreType(kernel_build_info->core_type());
  108. SetOutputDataDesc(kernel_build_info->output_data_desc());
  109. for (size_t index = 0; index < kernel_build_info->GetInputNum(); ++index) {
  110. (void)kernel_build_info_->inputs_device_type_.emplace_back(kernel_build_info->GetInputDeviceType(index));
  111. (void)kernel_build_info_->inputs_format_.emplace_back(kernel_build_info->GetInputFormat(index));
  112. (void)kernel_build_info_->input_reshape_type_.emplace_back(kernel_build_info->GetInputReshapeType(index));
  113. (void)kernel_build_info_->input_value_depend_.emplace_back(kernel_build_info->GetInputValueDepend(index));
  114. }
  115. for (size_t index = 0; index < kernel_build_info->GetOutputNum(); ++index) {
  116. (void)kernel_build_info_->outputs_device_type_.emplace_back(kernel_build_info->GetOutputDeviceType(index));
  117. (void)kernel_build_info_->outputs_format_.emplace_back(kernel_build_info->GetOutputFormat(index));
  118. (void)kernel_build_info_->output_reshape_type_.emplace_back(kernel_build_info->GetOutputReshapeType(index));
  119. }
  120. }
  121. ~KernelBuildInfoBuilder() = default;
  122. void SetKernelType(const KernelType &kernel_type);
  123. void SetOriginDataFormat(const std::string &origin_data_format);
  124. void SetInputsFormat(const std::vector<std::string> &inputs_format);
  125. void SetOutputsFormat(const std::vector<std::string> &outputs_format);
  126. void SetInputsDeviceType(const std::vector<TypeId> &inputs_device_type);
  127. void SetOutputsDeviceType(const std::vector<TypeId> &outputs_device_type);
  128. void SetInputsReshapeType(const std::vector<std::string> &input_reshape_type);
  129. void SetInputsValueDepend(const std::vector<std::string> &input_value_depend);
  130. void SetOutputsReshapeType(const std::vector<std::string> &output_reshape_type);
  131. void SetCoreType(const std::string &core_type);
  132. void SetFusionType(FusionType fusion_type);
  133. // save prebuild result
  134. void SetOutputDataDesc(const std::vector<nlohmann::json> &data_desc);
  135. void SetProcessor(Processor processor);
  136. void SetOpPattern(OpPattern pattern);
  137. void SetInputFormat(const std::string &format, size_t index);
  138. void SetOutputFormat(const std::string &format, size_t index);
  139. void SetInputReshapeType(const std::string &input_reshape_type, size_t index);
  140. void SetOutputReshapeType(const std::string &output_reshape_type, size_t index);
  141. void SetInputDeviceType(const TypeId &input_device_type, size_t index);
  142. void SetOutputDeviceType(const TypeId &output_device_type, size_t index);
  143. std::shared_ptr<KernelBuildInfo> Build();
  144. private:
  145. std::shared_ptr<KernelBuildInfo> kernel_build_info_;
  146. };
  147. } // namespace kernel
  148. } // namespace mindspore
  149. #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_BUILD_INFO_H_