You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

primitive.cc 7.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "ir/primitive.h"
  17. #include <mutex>
  18. #include <utility>
  19. #include "ir/signature.h"
  20. #include "operator/ops.h"
  21. #include "./common.h"
  22. #include "pipeline/parse/python_adapter.h"
  23. #include "pipeline/parse/data_converter.h"
  24. #include "pybind11/pytypes.h"
  25. #include "utils/convert_utils.h"
  26. #include "pybind_api/api_register.h"
  27. #include "pybind_api/export_flags.h"
  28. namespace mindspore {
  29. using mindspore::abstract::AbstractFunction;
  30. abstract::AbstractBasePtr Primitive::ToPrimAbstract(const AnfNodePtr &anf_node) {
  31. auto prim_func = std::make_shared<abstract::PrimitiveAbstractClosure>(shared_from_base<Primitive>(), anf_node);
  32. return prim_func;
  33. }
  34. static py::function GetBpropFunctionByObj(py::object obj) {
  35. static const std::string get_bprop_fn = "get_bprop_fn";
  36. static const std::string ad_module = "mindspore.ops._grad";
  37. py::function fn = parse::python_adapter::GetPyFn(ad_module, get_bprop_fn)(obj);
  38. return fn;
  39. }
  40. py::function Primitive::GetBpropFunction() {
  41. auto fn = GetBpropFunctionByObj(py::str(name()));
  42. if (fn.is_none()) {
  43. MS_LOG(WARNING) << "Can't find bprop function for " << name();
  44. }
  45. return fn;
  46. }
  47. py::function Primitive::GetComputeFunction() {
  48. static const std::string module = "mindspore._extends.builtin_operations";
  49. py::module mod = py::module::import(common::SafeCStr(module));
  50. if (!py::hasattr(mod, common::SafeCStr(name()))) {
  51. PyErr_SetString(PyExc_NotImplementedError, common::SafeCStr(name()));
  52. // If raise AttributeError, user can't understand. This case need raise NotImplementedError.
  53. throw py::error_already_set();
  54. }
  55. py::object fn = mod.attr(common::SafeCStr(name()));
  56. return fn;
  57. }
  58. bool Primitive::operator==(const Value &other) const {
  59. if (other.isa<Primitive>()) {
  60. auto other_prim = static_cast<const Primitive &>(other);
  61. return *this == other_prim;
  62. } else {
  63. return false;
  64. }
  65. }
  66. bool Primitive::operator==(const Primitive &other) const {
  67. if (name() != other.name()) {
  68. return false;
  69. }
  70. if (attrs_.size() != other.attrs_.size()) {
  71. return false;
  72. }
  73. auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair<std::string, ValuePtr> &item) -> bool {
  74. if (item.second == nullptr) {
  75. return false;
  76. }
  77. auto iter = other.attrs_.find(item.first);
  78. if (iter == other.attrs_.end()) {
  79. return false;
  80. }
  81. return *item.second == *iter->second;
  82. });
  83. return all;
  84. }
  85. void Primitive::set_signatures(
  86. std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> signatures) {
  87. signatures_.clear();
  88. for (auto &signature : signatures) {
  89. std::string name;
  90. SignatureEnumRW rw;
  91. SignatureEnumKind kind;
  92. py::object default_value;
  93. SignatureEnumDType dtype;
  94. std::tie(name, rw, kind, default_value, dtype) = signature;
  95. signatures_.emplace_back(Signature(name, rw, kind, default_value, dtype));
  96. }
  97. }
  98. std::string Primitive::GetAttrsText() const {
  99. if (attrs_.empty()) {
  100. return "";
  101. }
  102. std::ostringstream oss;
  103. oss << "[";
  104. bool is_first = true;
  105. for (auto &attr : attrs_) {
  106. if (is_first) {
  107. is_first = false;
  108. } else {
  109. oss << ", ";
  110. }
  111. oss << attr.first << "=" << attr.second->DumpText();
  112. }
  113. oss << "]";
  114. return oss.str();
  115. }
  116. py::function PrimitivePy::GetBpropFunction() {
  117. static const char *const get_bprop_func_name = "get_bprop";
  118. if (py::hasattr(python_obj_, get_bprop_func_name)) {
  119. py::function fn = python_obj_.attr(get_bprop_func_name)().cast<py::function>();
  120. return fn;
  121. } else {
  122. auto fn = GetBpropFunctionByObj(python_obj_);
  123. if (fn.is_none()) {
  124. MS_LOG(WARNING) << "Can't find bprop function for " << name();
  125. }
  126. return fn;
  127. }
  128. }
  129. py::function PrimitivePy::GetComputeFunction() {
  130. static const char *const compute_func_name = "vm_impl";
  131. if (py::hasattr(python_obj_, compute_func_name)) {
  132. MS_LOG(INFO) << name() << " compute_func_name";
  133. py::function fn = python_obj_.attr(compute_func_name).cast<py::function>();
  134. return fn;
  135. }
  136. static const std::string vm_module = "mindspore.ops.vm_impl_registry";
  137. static const std::string get_vm_impl_fn = "get_vm_impl_fn";
  138. MS_LOG(INFO) << name() << ": get_vm_impl_fn";
  139. py::function get_fn = parse::python_adapter::GetPyFn(vm_module, get_vm_impl_fn);
  140. py::function vm_fn = get_fn(python_obj_);
  141. if (py::isinstance<py::none>(vm_fn)) {
  142. MS_LOG(DEBUG) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast<std::string>();
  143. vm_fn = Primitive::GetComputeFunction();
  144. }
  145. return vm_fn;
  146. }
  147. void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) {
  148. std::string attr_name = name;
  149. ValuePtr converted_ret = nullptr;
  150. if (py::isinstance<py::module>(obj)) {
  151. MS_LOG(EXCEPTION) << "AddPyAttr failed, obj should not be py::module";
  152. }
  153. bool converted = parse::ConvertData(obj, &converted_ret);
  154. if (!converted) {
  155. MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(obj));
  156. }
  157. (void)this->AddAttr(attr_name, converted_ret);
  158. }
  159. py::dict PrimitivePy::GetAttrDict() {
  160. py::dict attr_dict;
  161. for (auto &attr : attrs_) {
  162. attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second);
  163. }
  164. return attr_dict;
  165. }
  166. REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
  167. (void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
  168. .value("unknown", PrimType::kPrimTypeUnknown)
  169. .value("builtin", PrimType::kPrimTypeBuiltIn)
  170. .value("py_infer_shape", PrimType::kPrimTypePyInferShape)
  171. .value("user_custom", PrimType::kPrimTypeUserCustom);
  172. (void)py::class_<PrimitivePy, std::shared_ptr<PrimitivePy>>(*m, "Primitive_")
  173. .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_)
  174. .def(py::init<py::str &, py::object>())
  175. .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr")
  176. .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr")
  177. .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.")
  178. .def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.")
  179. .def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name.");
  180. }));
  181. } // namespace mindspore