You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.h 4.1 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. /**
  2. * Copyright 2021-2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_EXPANDERS_UTILS_H_
  17. #define MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_EXPANDERS_UTILS_H_
  18. #include <string>
  19. #include <memory>
  20. #include <vector>
  21. #include <utility>
  22. #include "common/graph_kernel/model/lite_graph.h"
  23. #include "common/graph_kernel/model/node.h"
  24. namespace mindspore::graphkernel::expanders {
  25. using inner::NodePtr;
  26. using inner::NodePtrList;
  27. using BaseInfoList = std::vector<inner::NodeBase>;
  28. class Validator;
  29. class OpDesc {
  30. public:
  31. inner::LiteGraphPtr Run(const BaseInfoList &inputs, const BaseInfoList &outputs, const inner::DAttrs &attrs,
  32. const std::string &processor);
  33. const std::string &Op() const { return op_; }
  34. const BaseInfoList &InputsInfo() const { return inputs_info_; }
  35. const BaseInfoList &OuputsInfo() const { return outputs_info_; }
  36. inner::DAttrs Attrs() const { return attrs_; }
  37. const std::string &Processor() const { return processor_; }
  38. virtual ~OpDesc() = default;
  39. protected:
  40. virtual bool CheckInputs() { return true; }
  41. virtual NodePtrList Expand() = 0;
  42. bool CheckOutputs();
  43. inner::LiteGraph::GraphBuilder gb;
  44. std::string op_;
  45. BaseInfoList inputs_info_;
  46. BaseInfoList outputs_info_;
  47. inner::DAttrs attrs_;
  48. std::string processor_;
  49. std::vector<std::unique_ptr<Validator>> validators_;
  50. friend class OpExpanderFactory;
  51. };
  52. class Validator {
  53. public:
  54. virtual bool Check(const OpDesc &e) = 0;
  55. virtual ~Validator() = default;
  56. };
  57. class CheckAllFormatsSame : public Validator {
  58. public:
  59. bool Check(const OpDesc &e) override {
  60. const auto &inputs_info = e.InputsInfo();
  61. if (inputs_info.empty()) return true;
  62. const auto &fmt_0 = inputs_info[0].format;
  63. for (size_t i = 1; i < inputs_info.size(); i++) {
  64. if (inputs_info[i].format != fmt_0) {
  65. MS_LOG(INFO) << "Unmatched format for op " << e.Op();
  66. return false;
  67. }
  68. }
  69. return true;
  70. }
  71. };
  72. class CheckAttr : public Validator {
  73. public:
  74. CheckAttr(std::initializer_list<std::string> l) : attrs_(std::move(l)) {}
  75. virtual ~CheckAttr() = default;
  76. bool Check(const OpDesc &e) override {
  77. for (auto &a : attrs_) {
  78. if (e.Attrs().count(a) == 0) {
  79. MS_LOG(INFO) << "attr " << a << " does not exist. op " << e.Op();
  80. return false;
  81. }
  82. }
  83. return true;
  84. }
  85. private:
  86. std::vector<std::string> attrs_;
  87. };
  88. class SupportFormat : public Validator {
  89. public:
  90. void AddFormat(std::initializer_list<std::string> l) { (void)formats_.emplace_back(l); }
  91. bool Check(const OpDesc &e) override {
  92. for (auto &formats : formats_) {
  93. if (formats.size() != e.InputsInfo().size()) {
  94. continue;
  95. }
  96. bool match = true;
  97. for (size_t i = 0; i < formats.size(); i++) {
  98. if (e.InputsInfo()[i].format != formats[i]) {
  99. match = false;
  100. break;
  101. }
  102. }
  103. if (match) {
  104. return true;
  105. }
  106. }
  107. MS_LOG(INFO) << "unsupported format for op " << e.Op();
  108. return false;
  109. }
  110. virtual ~SupportFormat() = default;
  111. private:
  112. std::vector<std::vector<std::string>> formats_;
  113. };
  114. std::vector<int64_t> GetAxisList(const ValuePtr &value);
  115. ShapeVector ExpandDimsInferShape(const ShapeVector &shape, const std::vector<int64_t> &axis);
  116. NodePtr ReluExpand(const inner::LiteGraph::GraphBuilder &gb, const NodePtrList &inputs);
  117. NodePtr SigmoidExpand(const inner::LiteGraph::GraphBuilder &gb, const NodePtrList &inputs);
  118. } // namespace mindspore::graphkernel::expanders
  119. #endif // MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_EXPANDERS_UTILS_H_