You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

primitive.h 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_IR_PRIMITIVE_H_
  17. #define MINDSPORE_CCSRC_IR_PRIMITIVE_H_
  18. #include <unordered_map>
  19. #include <vector>
  20. #include <memory>
  21. #include <string>
  22. #include <tuple>
  23. #include "pybind11/pybind11.h"
  24. #include "pipeline/static_analysis/abstract_value.h"
  25. #include "utils/misc.h"
  26. #include "utils/log_adapter.h"
  27. #include "ir/signature.h"
  28. #include "parallel/ops_info/operator_info.h"
  29. namespace py = pybind11;
  30. namespace mindspore {
  31. using abstract::AbstractBasePtr;
  32. using abstract::AbstractBasePtrList;
  33. // Supported meta type
  // Kind of primitive; selects how a Primitive is evaluated. Primitive::HasPyEvaluator(),
  // HasPyInferTensor() and IsCustomPrim() are the predicates built on these values.
  enum PrimType {
    kPrimTypeUnknown = 0,
    // First value of this enum. Anchored to this enum's own first enumerator rather than to an
    // enumerator of an unrelated enum, so the numbering below is self-contained
    // (kPrimTypeBuiltIn == 1).
    kPrimTypeBegin = kPrimTypeUnknown,
    kPrimTypeBuiltIn,        // Built-in primitive operator (the Primitive constructor's default)
    kPrimTypePyInferShape,   // Custom-defined operator; Primitive::HasPyEvaluator() is true for it
    kPrimTypePyInferTensor,  // Custom-defined operator; Primitive::HasPyInferTensor() is true for it
    kPrimTypeUserCustom      // User custom operator; IsCustomPrim() and HasPyEvaluator() are true for it
  };
  42. class Primitive : public Named {
  43. public:
  44. explicit Primitive(const std::string &name, const PrimType prim_type = kPrimTypeBuiltIn)
  45. : Named(name), signatures_(), prim_type_(prim_type) {}
  46. Primitive(const Primitive &prim)
  47. : Named(prim), attrs_(prim.attrs_), signatures_(prim.signatures_), prim_type_(prim.prim_type_) {}
  48. MS_DECLARE_PARENT(Primitive, Named);
  49. abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);
  50. std::string ToString() const override { return name(); }
  51. virtual py::function GetBpropFunction();
  52. virtual py::function GetComputeFunction();
  53. Primitive &AddAttr(const std::string &name, const ValuePtr &attr) {
  54. attrs_[name] = attr;
  55. return *this;
  56. }
  57. Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
  58. for (auto &attr : attrs) {
  59. attrs_[attr.first] = attr.second;
  60. }
  61. return *this;
  62. }
  63. void set_signatures(
  64. std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>>
  65. signatures);
  66. const std::vector<Signature> &signatures() const { return signatures_; }
  67. void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; }
  68. void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); }
  69. ValuePtr GetAttr(const std::string &attrName) const {
  70. auto iter = attrs_.find(attrName);
  71. return iter == attrs_.cend() ? nullptr : iter->second;
  72. }
  73. const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
  74. // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
  75. bool HasAttr() const { return !attrs_.empty(); }
  76. bool HasAttr(const std::string &attrName) const {
  77. auto iter = attrs_.find(attrName);
  78. return !(iter == attrs_.cend());
  79. }
  80. void set_prim_type(const PrimType t) { prim_type_ = t; }
  81. void set_instance_name(const std::string s) { instance_name_ = s; }
  82. bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; }
  83. bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; }
  84. bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; }
  85. PrimType prim_type() const { return prim_type_; }
  86. std::string instance_name() const { return instance_name_; }
  87. std::string GetAttrsText() const;
  88. bool operator==(const Value &other) const override;
  89. bool operator==(const Primitive &other) const;
  90. ~Primitive() override = default;
  91. protected:
  92. std::unordered_map<std::string, ValuePtr> attrs_;
  93. private:
  94. std::vector<Signature> signatures_;
  95. std::string instance_name_;
  96. PrimType prim_type_;
  97. };
  98. class PrimitivePy : public Primitive {
  99. public:
  100. PrimitivePy(const py::str &name, const py::object &python_obj) : Primitive(name), python_obj_(python_obj) {}
  101. ~PrimitivePy() override = default;
  102. MS_DECLARE_PARENT(PrimitivePy, Primitive);
  103. py::function GetBpropFunction() override;
  104. py::function GetComputeFunction() override;
  105. void AddPyAttr(const py::str &name, const py::object &obj);
  106. py::dict GetAttrDict();
  107. const bool parse_info_ = true;
  108. const py::object &GetPyObj() const { return python_obj_; }
  109. bool is_tuple_input_ = false;
  110. private:
  111. py::object python_obj_;
  112. };
  113. using PrimitivePyPtr = std::shared_ptr<PrimitivePy>;
  114. inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) {
  115. os << *p;
  116. return os;
  117. }
  118. struct PrimitiveEqual {
  119. bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
  120. MS_EXCEPTION_IF_NULL(t1);
  121. MS_EXCEPTION_IF_NULL(t2);
  122. return t1->name() == t2->name();
  123. }
  124. };
  125. struct PrimitiveHasher {
  126. std::size_t operator()(PrimitivePtr const &prim) const { return prim->Hash(); }
  127. };
  128. } // namespace mindspore
  129. #endif // MINDSPORE_CCSRC_IR_PRIMITIVE_H_