You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

kernel_build_info.cc 13 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "kernel/kernel_build_info.h"
  17. #include <algorithm>
  18. #include "utils/log_adapter.h"
  19. #include "include/common/debug/anf_dump_utils.h"
  20. namespace mindspore {
  21. namespace kernel {
  22. std::string KernelBuildInfo::GetInputFormat(size_t input_index) const {
  23. if (input_index >= inputs_format_.size()) {
  24. MS_LOG(ERROR) << "The index [" << input_index << "] is exceed the number of input node";
  25. return kInvalidFormat;
  26. }
  27. return inputs_format_[input_index];
  28. }
  29. std::string KernelBuildInfo::GetOutputFormat(size_t output_index) const {
  30. if (output_index >= outputs_format_.size()) {
  31. MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of output";
  32. return kInvalidFormat;
  33. }
  34. return outputs_format_[output_index];
  35. }
  36. TypeId KernelBuildInfo::GetInputDeviceType(size_t input_index) const {
  37. if (input_index >= inputs_device_type_.size()) {
  38. MS_LOG(ERROR) << "The index [" << input_index << "] is exceed the number of input";
  39. return TypeId::kNumberTypeEnd;
  40. }
  41. return inputs_device_type_[input_index];
  42. }
  43. TypeId KernelBuildInfo::GetOutputDeviceType(size_t output_index) const {
  44. if (output_index >= outputs_device_type_.size()) {
  45. MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of output";
  46. return TypeId::kNumberTypeEnd;
  47. }
  48. return outputs_device_type_[output_index];
  49. }
  50. const std::string &KernelBuildInfo::GetOriginDataFormat() const { return origin_data_format_; }
  51. const std::vector<std::string> &KernelBuildInfo::GetAllInputFormats() const { return inputs_format_; }
  52. const std::vector<std::string> &KernelBuildInfo::GetAllOutputFormats() const { return outputs_format_; }
  53. const std::vector<TypeId> &KernelBuildInfo::GetAllInputDeviceTypes() const { return inputs_device_type_; }
  54. const std::vector<TypeId> &KernelBuildInfo::GetAllOutputDeviceTypes() const { return outputs_device_type_; }
  55. void KernelBuildInfo::SetOutputFormat(const std::string &format, size_t index) {
  56. if (index >= outputs_format_.size()) {
  57. MS_LOG(EXCEPTION) << "The index [" << index << "] is exceed the number of output";
  58. }
  59. outputs_format_[index] = format;
  60. }
  61. void KernelBuildInfo::SetOutputsFormat(const std::vector<std::string> &outputs_format) {
  62. outputs_format_ = outputs_format;
  63. }
  64. void KernelBuildInfo::SetOutputDeviceType(const TypeId &output_device_type, size_t index) {
  65. if (index >= outputs_device_type_.size()) {
  66. MS_LOG(EXCEPTION) << "The index [" << index << "] is exceed the number of output";
  67. }
  68. outputs_device_type_[index] = output_device_type;
  69. }
  70. void KernelBuildInfo::SetOutputsDeviceType(const std::vector<TypeId> &outputs_device_type) {
  71. outputs_device_type_ = outputs_device_type;
  72. }
  73. size_t KernelBuildInfo::GetInputNum() const { return inputs_format_.size(); }
  74. size_t KernelBuildInfo::GetOutputNum() const { return outputs_format_.size(); }
  75. size_t KernelBuildInfo::GetOutputNumWithoutMonad() const {
  76. const auto count = std::count_if(outputs_device_type_.begin(), outputs_device_type_.end(),
  77. [](TypeId type) { return type != TypeId::kObjectTypeUMonad; });
  78. return static_cast<size_t>(count);
  79. }
  80. std::string KernelBuildInfo::GetInputReshapeType(size_t input_index) const {
  81. if (input_reshape_type_.empty()) {
  82. return "";
  83. }
  84. if (input_index >= input_reshape_type_.size()) {
  85. MS_LOG(EXCEPTION) << "The index [" << input_index << "] is exceed the number of input node size "
  86. << input_reshape_type_.size();
  87. }
  88. return input_reshape_type_[input_index];
  89. }
  90. std::string KernelBuildInfo::GetInputValueDepend(size_t input_index) const {
  91. if (input_value_depend_.empty()) {
  92. return "";
  93. }
  94. if (input_index >= input_value_depend_.size()) {
  95. MS_LOG(EXCEPTION) << "The index [" << input_index << "] is exceed the number of input node size "
  96. << input_value_depend_.size();
  97. }
  98. return input_value_depend_[input_index];
  99. }
  100. std::string KernelBuildInfo::GetOutputReshapeType(size_t output_index) const {
  101. if (output_reshape_type_.empty()) {
  102. return "";
  103. }
  104. if (output_index >= output_reshape_type_.size()) {
  105. MS_LOG(EXCEPTION) << "The index [" << output_index << "] is exceed the number of output node size "
  106. << output_reshape_type_.size();
  107. }
  108. return output_reshape_type_[output_index];
  109. }
  110. std::string KernelBuildInfo::ToString() const {
  111. std::ostringstream output_buffer;
  112. output_buffer << "(";
  113. for (size_t index = 0; index < GetInputNum(); ++index) {
  114. if (index != 0) {
  115. output_buffer << ", ";
  116. }
  117. output_buffer << "<" << TypeToShortString(GetInputDeviceType(index)) << "x" << GetInputFormat(index) << ">";
  118. }
  119. output_buffer << ") -> (";
  120. for (size_t index = 0; index < GetOutputNum(); ++index) {
  121. if (index != 0) {
  122. output_buffer << ", ";
  123. }
  124. output_buffer << "<" << TypeToShortString(GetOutputDeviceType(index)) << "x" << GetOutputFormat(index) << ">";
  125. }
  126. output_buffer << ")";
  127. return output_buffer.str();
  128. }
  129. bool KernelBuildInfo::IsSimilarityKernelBuildInfo(const KernelBuildInfo &other) const {
  130. if (inputs_format_ != other.inputs_format_ || outputs_format_ != other.outputs_format_) {
  131. if (op_pattern_ != kFormatAgnosticPattern) {
  132. return false;
  133. } else {
  134. MS_LOG(INFO) << "This kernel build info:" << this->ToString()
  135. << ", other kernel build info: " << other.ToString();
  136. }
  137. }
  138. return !(inputs_device_type_ != other.inputs_device_type_ || outputs_device_type_ != other.outputs_device_type_);
  139. }
  140. bool KernelBuildInfo::operator==(const KernelBuildInfo &other) const {
  141. if (kernel_type_ != other.kernel_type_ || fusion_type_ != other.fusion_type_ || processor_ != other.processor_) {
  142. return false;
  143. }
  144. return IsSimilarityKernelBuildInfo(other);
  145. }
  146. bool KernelBuildInfo::IsInputDefaultPadding() const { return input_reshape_type_.empty(); }
  147. bool KernelBuildInfo::IsOutputDefaultPadding() const { return output_reshape_type_.empty(); }
  148. bool KernelBuildInfo::operator!=(const KernelBuildInfo &other) const { return !((*this) == other); }
  149. void KernelBuildInfo::KernelBuildInfoBuilder::SetKernelType(const KernelType &kernel_type) {
  150. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  151. kernel_build_info_->kernel_type_ = kernel_type;
  152. }
  153. void KernelBuildInfo::KernelBuildInfoBuilder::SetOriginDataFormat(const std::string &origin_data_format) {
  154. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  155. kernel_build_info_->origin_data_format_ = origin_data_format;
  156. }
  157. void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsFormat(const std::vector<std::string> &inputs_format) {
  158. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  159. kernel_build_info_->inputs_format_ = inputs_format;
  160. }
  161. void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsFormat(const std::vector<std::string> &outputs_format) {
  162. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  163. kernel_build_info_->outputs_format_ = outputs_format;
  164. }
  165. void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsDeviceType(const std::vector<TypeId> &inputs_device_type) {
  166. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  167. kernel_build_info_->inputs_device_type_ = inputs_device_type;
  168. }
  169. void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsDeviceType(const std::vector<TypeId> &outputs_device_type) {
  170. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  171. kernel_build_info_->outputs_device_type_ = outputs_device_type;
  172. }
  173. void KernelBuildInfo::KernelBuildInfoBuilder::SetFusionType(FusionType fusion_type) {
  174. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  175. kernel_build_info_->fusion_type_ = fusion_type;
  176. }
  177. void KernelBuildInfo::KernelBuildInfoBuilder::SetCoreType(const std::string &core_type) {
  178. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  179. kernel_build_info_->core_type_ = core_type;
  180. }
  181. void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputDataDesc(const std::vector<nlohmann::json> &data_desc) {
  182. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  183. kernel_build_info_->output_data_desc_ = data_desc;
  184. }
  185. void KernelBuildInfo::KernelBuildInfoBuilder::SetProcessor(Processor processor) {
  186. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  187. kernel_build_info_->processor_ = processor;
  188. }
  189. std::shared_ptr<KernelBuildInfo> KernelBuildInfo::KernelBuildInfoBuilder::Build() { return kernel_build_info_; }
  190. void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsReshapeType(const std::vector<std::string> &input_reshape_type) {
  191. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  192. kernel_build_info_->input_reshape_type_ = input_reshape_type;
  193. }
  194. void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsValueDepend(const std::vector<std::string> &input_value_depend) {
  195. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  196. kernel_build_info_->input_value_depend_ = input_value_depend;
  197. }
  198. void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsReshapeType(
  199. const std::vector<std::string> &output_reshape_type) {
  200. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  201. kernel_build_info_->output_reshape_type_ = output_reshape_type;
  202. }
  203. void KernelBuildInfo::KernelBuildInfoBuilder::SetOpPattern(OpPattern pattern) {
  204. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  205. kernel_build_info_->op_pattern_ = pattern;
  206. }
  207. void KernelBuildInfo::KernelBuildInfoBuilder::SetInputFormat(const std::string &format, size_t index) {
  208. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  209. auto index_limit = kernel_build_info_->inputs_format_.size();
  210. if (index >= index_limit) {
  211. MS_LOG(EXCEPTION) << "Index of input format out of range! The value should be less than: " << index_limit
  212. << ", but got: " << index;
  213. }
  214. kernel_build_info_->inputs_format_[index] = format;
  215. }
  216. void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputFormat(const std::string &format, size_t index) {
  217. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  218. auto index_limit = kernel_build_info_->outputs_format_.size();
  219. if (index >= index_limit) {
  220. MS_LOG(EXCEPTION) << "Index of output format out of range! The value should be less than: " << index_limit
  221. << ", but got: " << index;
  222. }
  223. kernel_build_info_->outputs_format_[index] = format;
  224. }
  225. void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType(const std::string &input_reshape_type, size_t index) {
  226. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  227. auto index_limit = kernel_build_info_->input_reshape_type_.size();
  228. if (index >= index_limit) {
  229. MS_LOG(EXCEPTION) << "Index of input_reshape_type out of range! The value should be less than: " << index_limit
  230. << ", but got: " << index;
  231. }
  232. (void)std::copy(input_reshape_type.begin(), input_reshape_type.end(),
  233. std::back_inserter(kernel_build_info_->input_reshape_type_[index]));
  234. }
  235. void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType(const std::string &output_reshape_type,
  236. size_t index) {
  237. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  238. auto index_limit = kernel_build_info_->output_reshape_type_.size();
  239. if (index >= index_limit) {
  240. MS_LOG(EXCEPTION) << "Index of output_reshape_type out of range! The value should be less than: " << index_limit
  241. << ", but got: " << index;
  242. }
  243. (void)std::copy(output_reshape_type.begin(), output_reshape_type.end(),
  244. std::back_inserter(kernel_build_info_->output_reshape_type_[index]));
  245. }
  246. void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputDeviceType(const TypeId &output_device_type, size_t index) {
  247. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  248. auto index_limit = kernel_build_info_->outputs_device_type_.size();
  249. if (index >= index_limit) {
  250. MS_LOG(EXCEPTION) << "Index of output_device_type out of range! The value should be less than: " << index_limit
  251. << ", but got: " << index;
  252. }
  253. kernel_build_info_->outputs_device_type_[index] = output_device_type;
  254. }
  255. void KernelBuildInfo::KernelBuildInfoBuilder::SetInputDeviceType(const TypeId &input_device_type, size_t index) {
  256. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  257. auto index_limit = kernel_build_info_->inputs_device_type_.size();
  258. if (index >= index_limit) {
  259. MS_LOG(EXCEPTION) << "Index of input_device_type out of range! The value should be less than: " << index_limit
  260. << ", but got: " << index;
  261. }
  262. kernel_build_info_->inputs_device_type_[index] = input_device_type;
  263. }
  264. } // namespace kernel
  265. } // namespace mindspore