You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

op_adapter_base.h 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef TRANSFORM_OP_ADAPTER_BASE_H_
  17. #define TRANSFORM_OP_ADAPTER_BASE_H_
  #include <functional>
  #include <memory>
  #include <sstream>
  #include <string>
  #include <type_traits>
  #include <unordered_map>
  #include <utility>
  #include <vector>
  24. #include "transform/util.h"
  25. #include "ir/anf.h"
  26. #include "ir/primitive.h"
  27. #include "ir/value.h"
  28. #include "transform/types.h"
  29. #ifdef ENABLE_GE
  30. #ifdef OPEN_SOURCE
  31. #include "graph/types.h"
  32. #endif
  33. #endif
  34. #include "graph/operator_reg.h"
  35. #ifdef OPEN_SOURCE
  36. #include "ge/client/ge_api.h"
  37. #else
  38. #include "external/ge/ge_api.h"
  39. #endif
  40. #include "graph/tensor.h"
  41. #include "transform/all_ops.h"
  42. namespace ge {
  43. class CustomOperator : public Operator {
  44. public:
  45. CustomOperator(const string &name, const string &type) : Operator(name, type) {}
  46. ~CustomOperator() override{};
  47. void CustomInputRegister(const string &name) { Operator::InputRegister(name); }
  48. void CustomOutputRegister(const string &name) { Operator::OutputRegister(name); }
  49. void CustomInferFuncRegister(const std::function<graphStatus(Operator &)> &func) {
  50. Operator::InferFuncRegister(func);
  51. }
  52. };
  53. } // namespace ge
  54. namespace mindspore {
  55. namespace transform {
  56. using CusOperatorPtr = std::shared_ptr<ge::CustomOperator>;
  57. using CustomOperator = ge::CustomOperator;
  58. struct OutHandler {
  59. OperatorPtr op;
  60. std::string out;
  61. OutHandler() : op(nullptr), out("") {}
  62. OutHandler(const OperatorPtr &op, const std::string out) : op(op), out(out) {}
  63. };
  64. struct ControlEdge {
  65. OperatorPtr src_op;
  66. OperatorPtr dest_op;
  67. };
  68. using AttrFunc = std::function<void(OperatorPtr, ValuePtr)>;
  69. using OutputFunc = std::function<OutHandler(OperatorPtr)>;
  70. using InputOpFunc = std::function<void(OperatorPtr, OperatorPtr)>;
  71. using InputHandleFunc = std::function<void(OperatorPtr, OutHandler)>;
  72. using CreateDynInputOpFunc = std::function<void(OperatorPtr, unsigned int)>;
  73. using DynInputOpFunc = std::function<void(OperatorPtr, unsigned int, OperatorPtr)>;
  74. using DynInputHandleFunc = std::function<void(OperatorPtr, unsigned int, OutHandler)>;
  75. using UpdateOutputDescFunc = std::function<void(OperatorPtr, GeTensorDesc)>;
  76. using CreateDynOutputOpFunc = std::function<void(OperatorPtr, unsigned int)>;
  77. struct AttrDesc {
  78. std::string name;
  79. AttrFunc set_attr;
  80. };
  81. struct InputDesc {
  82. std::string name;
  83. InputOpFunc set_op;
  84. InputHandleFunc set_handle;
  85. UpdateOutputDescFunc update_input_desc;
  86. };
  87. struct DynInputDesc {
  88. std::string name;
  89. CreateDynInputOpFunc create_dyn_input;
  90. DynInputOpFunc set_op;
  91. DynInputHandleFunc set_handle;
  92. };
  93. struct OutputDesc {
  94. std::string name;
  95. UpdateOutputDescFunc update_out_desc;
  96. };
  97. struct DynOutputDesc {
  98. std::string name;
  99. CreateDynOutputOpFunc create_dyn_output;
  100. };
  101. class BaseOpAdapter {
  102. public:
  103. virtual ~BaseOpAdapter() {}
  104. virtual OperatorPtr generate(const AnfNodePtr &anf) = 0;
  105. virtual OperatorPtr generate(const std::string &type) { return std::make_shared<ge::Operator>(type); }
  106. virtual int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) = 0;
  107. virtual int setInput(const OperatorPtr &op, int index, const OutHandler &handle) = 0;
  108. virtual int setInput(const OperatorPtr &op, int index,
  109. const std::shared_ptr<std::vector<OutHandler>> &handler_vec) = 0;
  110. virtual int setAttr(const OperatorPtr &op, const std::string &attrKey, const ValuePtr &attrValue) = 0;
  111. virtual int setAttr(const OperatorPtr &op, const PrimitivePtr &prim) = 0;
  112. virtual int setAttr(const OperatorPtr &op, const AnfNodePtr &node) = 0;
  113. virtual std::unordered_map<std::string, ValuePtr> GetExtraAttr() = 0;
  114. template <typename T, typename _ = typename std::enable_if<!std::is_base_of<Value, T>::value>::type>
  115. int setAttr(const OperatorPtr &op, const std::string &attrKey, const std::shared_ptr<T> &attrValue) {
  116. return setAttr(op, attrKey, MakeValue(attrValue));
  117. }
  118. template <typename T, typename _ = typename std::enable_if<!is_shared_ptr<T>::value>::type>
  119. int setAttr(const OperatorPtr &op, const std::string &attrKey, const T &attrValue) {
  120. return setAttr(op, attrKey, MakeValue(attrValue));
  121. }
  122. virtual OutHandler getOutput(const OperatorPtr &op, int index) = 0;
  123. virtual void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
  124. const AnfNodePtr &node) = 0;
  125. virtual const std::unordered_map<int, InputDesc> &getInputMap() = 0;
  126. virtual const std::unordered_map<unsigned int, AttrDesc> &getInputAttrMap() = 0;
  127. virtual const std::unordered_map<int, DynInputDesc> &getDynInputMap() = 0;
  128. virtual const std::unordered_map<int, OutputDesc> &getOutputMap() = 0;
  129. void AddAttrToDrawGraph(const std::string &attr_str) { attrs_vec_.push_back(attr_str); }
  130. const std::vector<std::string> &GetAttrsFromDrawGraph() const { return attrs_vec_; }
  131. void clearAttrVect() { attrs_vec_.clear(); }
  132. private:
  133. std::vector<std::string> attrs_vec_;
  134. };
  135. using OpAdapterPtr = std::shared_ptr<BaseOpAdapter>;
  // Attribute kinds. Values are spelled out explicitly; they match the implicit 0..8 numbering of
  // the original declaration order.
  enum AttrType {
    ATTR_INT = 0,
    ATTR_FLOAT = 1,
    ATTR_DOUBLE = 2,
    ATTR_STRING = 3,
    ATTR_TENSOR = 4,
    ATTR_BOOL = 5,
    ATTR_LIST_INT = 6,
    ATTR_LIST_ANY_INT = 7,
    ATTR_ENUM = 8
  };
  // Empty tag types. They carry no data.
  // NOTE(review): presumably used as template/overload selectors elsewhere — confirm at use sites.
  struct GeEnum {
  };
  struct TFType {
  };
  struct GEType {
  };
  // AnyTraits<T>::type maps a C++ type to its "Any" representation. By default the mapping is the
  // identity; the only exception declared here is plain `int`, which is widened to int64_t.
  template <typename T>
  struct AnyTraits {
    using type = T;
  };

  template <>
  struct AnyTraits<int> {
    using type = int64_t;
  };
  159. using ExtraAttr = std::unordered_map<std::string, ValuePtr>;
  160. } // namespace transform
  161. } // namespace mindspore
  162. #endif // TRANSFORM_OP_ADAPTER_BASE_H_