You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.h 3.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_UTILS_H_
  17. #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_UTILS_H_
  18. #include <string>
  19. #include <memory>
  20. #include <vector>
  21. #include "common/graph_kernel/model/lite_graph.h"
  22. #include "common/graph_kernel/model/node.h"
  23. namespace mindspore::graphkernel::expanders {
  24. using inner::NodePtrList;
  25. using BaseInfoList = std::vector<inner::NodeBase>;
  26. class Validator;
  27. class OpDesc {
  28. public:
  29. inner::LiteGraphPtr Run(const BaseInfoList &inputs, const BaseInfoList &outputs, const inner::DAttrs &attrs,
  30. const std::string &processor);
  31. virtual ~OpDesc() = default;
  32. protected:
  33. virtual bool CheckInputs() { return true; }
  34. virtual NodePtrList Expand() = 0;
  35. bool CheckOutputs();
  36. inner::LiteGraph::GraphBuilder gb;
  37. std::string op_;
  38. BaseInfoList inputs_info_;
  39. BaseInfoList outputs_info_;
  40. inner::DAttrs attrs_;
  41. std::string processor_;
  42. std::vector<std::unique_ptr<Validator>> validators_;
  43. friend class OpExpanderFactory;
  44. friend class CheckAllFormatsSame;
  45. friend class CheckAttr;
  46. friend class SupportFormat;
  47. };
  48. class Validator {
  49. public:
  50. virtual bool Check(const OpDesc &e) = 0;
  51. };
  52. class CheckAllFormatsSame : public Validator {
  53. public:
  54. bool Check(const OpDesc &e) override {
  55. if (e.inputs_info_.empty()) return true;
  56. const auto &fmt_0 = e.inputs_info_[0].format;
  57. for (size_t i = 1; i < e.inputs_info_.size(); i++) {
  58. if (e.inputs_info_[i].format != fmt_0) {
  59. MS_LOG(INFO) << "Unmatched format for op " << e.op_;
  60. return false;
  61. }
  62. }
  63. return true;
  64. }
  65. };
  66. class CheckAttr : public Validator {
  67. public:
  68. CheckAttr(std::initializer_list<std::string> l) : attrs_(l) {}
  69. ~CheckAttr() = default;
  70. bool Check(const OpDesc &e) override {
  71. for (auto &a : attrs_) {
  72. if (e.attrs_.count(a) == 0) {
  73. MS_LOG(INFO) << "attr " << a << " does not exist. op " << e.op_;
  74. return false;
  75. }
  76. }
  77. return true;
  78. }
  79. private:
  80. std::vector<std::string> attrs_;
  81. };
  82. class SupportFormat : public Validator {
  83. public:
  84. void AddFormat(std::initializer_list<std::string> l) { formats_.emplace_back(l); }
  85. bool Check(const OpDesc &e) override {
  86. for (auto &formats : formats_) {
  87. if (formats.size() != e.inputs_info_.size()) {
  88. continue;
  89. }
  90. bool match = true;
  91. for (size_t i = 0; i < formats.size(); i++) {
  92. if (e.inputs_info_[i].format != formats[i]) {
  93. match = false;
  94. break;
  95. }
  96. }
  97. if (match) {
  98. return true;
  99. }
  100. }
  101. MS_LOG(INFO) << "unsupported format for op " << e.op_;
  102. return false;
  103. }
  104. private:
  105. std::vector<std::vector<std::string>> formats_;
  106. };
  107. std::vector<int64_t> GetAxisList(const ValuePtr &value);
  108. ShapeVector ExpandDimsInferShape(const ShapeVector &shape, const std::vector<int64_t> &axis);
  109. } // namespace mindspore::graphkernel::expanders
  110. #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_UTILS_H_