You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

kernel_build_info.cc 7.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "kernel/kernel_build_info.h"
  17. #include <algorithm>
  18. #include "utils/log_adapter.h"
  19. #include "debug/anf_ir_dump.h"
  20. namespace mindspore {
  21. namespace kernel {
  22. std::string KernelBuildInfo::GetInputFormat(size_t input_index) const {
  23. if (input_index >= inputs_format_.size()) {
  24. MS_LOG(ERROR) << "The index [" << input_index << "] is exceed the number of input node";
  25. return kInvalidFormat;
  26. }
  27. return inputs_format_[input_index];
  28. }
  29. std::string KernelBuildInfo::GetOutputFormat(size_t output_index) const {
  30. if (output_index >= outputs_format_.size()) {
  31. MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of input node";
  32. return kInvalidFormat;
  33. }
  34. return outputs_format_[output_index];
  35. }
  36. TypeId KernelBuildInfo::GetInputDeviceType(size_t input_index) const {
  37. if (input_index >= inputs_device_type_.size()) {
  38. MS_LOG(ERROR) << "The index [" << input_index << "] is exceed the number of input";
  39. return TypeId::kNumberTypeEnd;
  40. }
  41. return inputs_device_type_[input_index];
  42. }
  43. TypeId KernelBuildInfo::GetOutputDeviceType(size_t output_index) const {
  44. if (output_index >= outputs_device_type_.size()) {
  45. MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of output";
  46. return TypeId::kNumberTypeEnd;
  47. }
  48. return outputs_device_type_[output_index];
  49. }
  50. std::vector<std::string> KernelBuildInfo::GetAllInputFormats() const { return inputs_format_; }
  51. std::vector<std::string> KernelBuildInfo::GetAllOutputFormats() const { return outputs_format_; }
  52. std::vector<TypeId> KernelBuildInfo::GetAllInputDeviceTypes() const { return inputs_device_type_; }
  53. std::vector<TypeId> KernelBuildInfo::GetAllOutputDeviceTypes() const { return outputs_device_type_; }
  54. size_t KernelBuildInfo::GetInputNum() const { return inputs_format_.size(); }
  55. size_t KernelBuildInfo::GetOutputNum() const { return outputs_format_.size(); }
  56. std::vector<Axis> KernelBuildInfo::GetInputReshapeType(size_t input_index) const {
  57. if (input_index >= input_reshape_type_.size()) {
  58. MS_LOG(EXCEPTION) << "The index [" << input_index << "] is exceed the number of input node size "
  59. << input_reshape_type_.size();
  60. }
  61. return input_reshape_type_[input_index];
  62. }
  63. std::vector<Axis> KernelBuildInfo::GetOutputReshapeType(size_t output_index) const {
  64. if (output_index >= output_reshape_type_.size()) {
  65. MS_LOG(EXCEPTION) << "The index [" << output_index << "] is exceed the number of output node size "
  66. << output_reshape_type_.size();
  67. }
  68. return output_reshape_type_[output_index];
  69. }
  70. std::string KernelBuildInfo::ToString() const {
  71. std::ostringstream output_buffer;
  72. output_buffer << "(";
  73. for (size_t index = 0; index < GetInputNum(); ++index) {
  74. if (index != 0) {
  75. output_buffer << ", ";
  76. }
  77. output_buffer << "<" << ToShortString(GetInputDeviceType(index)) << "x" << GetInputFormat(index) << ">";
  78. }
  79. output_buffer << ") -> (";
  80. for (size_t index = 0; index < GetOutputNum(); ++index) {
  81. if (index != 0) {
  82. output_buffer << ", ";
  83. }
  84. output_buffer << "<" << ToShortString(GetOutputDeviceType(index)) << "x" << GetOutputFormat(index) << ">";
  85. }
  86. output_buffer << ")";
  87. return output_buffer.str();
  88. }
  89. bool KernelBuildInfo::operator==(const KernelBuildInfo &other) const {
  90. if (kernel_type_ != other.kernel_type_ || fusion_type_ != other.fusion_type_ || processor_ != other.processor_) {
  91. return false;
  92. }
  93. if (inputs_format_ != other.inputs_format_ || outputs_format_ != other.outputs_format_) {
  94. if (op_pattern_ != kFormatAgnosticPattern) {
  95. return false;
  96. } else {
  97. MS_LOG(INFO) << "this kernel build info:" << this->ToString()
  98. << ", other kernel build info: " << other.ToString();
  99. }
  100. }
  101. return !(inputs_device_type_ != other.inputs_device_type_ || outputs_device_type_ != other.outputs_device_type_);
  102. }
  103. bool KernelBuildInfo::IsInputDefaultPadding() const { return input_reshape_type_.empty(); }
  104. bool KernelBuildInfo::IsOutputDefaultPadding() const { return output_reshape_type_.empty(); }
  105. bool KernelBuildInfo::operator!=(const KernelBuildInfo &other) const { return !((*this) == other); }
  106. void KernelBuildInfo::KernelBuildInfoBuilder::SetKernelType(const KernelType &kernel_type) {
  107. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  108. kernel_build_info_->kernel_type_ = kernel_type;
  109. }
  110. void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsFormat(const std::vector<std::string> &inputs_format) {
  111. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  112. kernel_build_info_->inputs_format_ = inputs_format;
  113. }
  114. void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsFormat(const std::vector<std::string> &outputs_format) {
  115. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  116. kernel_build_info_->outputs_format_ = outputs_format;
  117. }
  118. void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsDeviceType(const std::vector<TypeId> &inputs_device_type) {
  119. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  120. kernel_build_info_->inputs_device_type_ = inputs_device_type;
  121. }
  122. void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsDeviceType(const std::vector<TypeId> &outputs_device_type) {
  123. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  124. kernel_build_info_->outputs_device_type_ = outputs_device_type;
  125. }
  126. void KernelBuildInfo::KernelBuildInfoBuilder::SetFusionType(FusionType fusion_type) {
  127. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  128. kernel_build_info_->fusion_type_ = fusion_type;
  129. }
  130. void KernelBuildInfo::KernelBuildInfoBuilder::SetProcessor(Processor processor) {
  131. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  132. kernel_build_info_->processor_ = processor;
  133. }
  134. std::shared_ptr<KernelBuildInfo> KernelBuildInfo::KernelBuildInfoBuilder::Build() { return kernel_build_info_; }
  135. void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType(
  136. const std::vector<std::vector<Axis>> &input_reshape_type) {
  137. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  138. kernel_build_info_->input_reshape_type_ = input_reshape_type;
  139. }
  140. void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType(
  141. const std::vector<std::vector<Axis>> &output_reshape_type) {
  142. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  143. kernel_build_info_->output_reshape_type_ = output_reshape_type;
  144. }
  145. void KernelBuildInfo::KernelBuildInfoBuilder::SetOpPattern(OpPattern pattern) {
  146. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  147. kernel_build_info_->op_pattern_ = pattern;
  148. }
  149. void KernelBuildInfo::KernelBuildInfoBuilder::SetInputFormat(const std::string &format, size_t index) {
  150. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  151. if (index >= kernel_build_info_->inputs_format_.size()) {
  152. MS_LOG(EXCEPTION) << "index outof range!";
  153. }
  154. kernel_build_info_->inputs_format_[index] = format;
  155. }
  156. void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputFormat(const std::string &format, size_t index) {
  157. MS_EXCEPTION_IF_NULL(kernel_build_info_);
  158. if (index >= kernel_build_info_->outputs_format_.size()) {
  159. MS_LOG(EXCEPTION) << "index outof range!";
  160. }
  161. kernel_build_info_->outputs_format_[index] = format;
  162. }
  163. } // namespace kernel
  164. } // namespace mindspore