You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

primitive_py.cc 8.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "utils/primitive_py.h"
  17. #include <mutex>
  18. #include "ir/signature.h"
  19. #include "./common.h"
  20. #include "pipeline/jit/parse/python_adapter.h"
  21. #include "pipeline/jit/parse/data_converter.h"
  22. #include "pybind11/pytypes.h"
  23. #include "utils/convert_utils_base.h"
  24. #include "utils/primitive_utils.h"
  25. #include "utils/base_ref_py.h"
  26. #include "pybind_api/api_register.h"
  27. #include "pybind_api/export_flags.h"
  28. namespace mindspore {
  29. namespace {
  30. constexpr auto kBpropAttrName = "bprop";
  31. constexpr auto kCellHookAttrName = "cell_hook";
  32. constexpr auto kCellIDAttrName = "cell_id";
  33. void SyncData(const py::object &arg) {
  34. if (py::isinstance<py::tuple>(arg)) {
  35. py::tuple arg_list = py::cast<py::tuple>(arg);
  36. for (size_t i = 0; i < arg_list.size(); i++) {
  37. SyncData(arg_list[i]);
  38. }
  39. }
  40. if (py::isinstance<tensor::Tensor>(arg)) {
  41. auto tensor = py::cast<tensor::TensorPtr>(arg);
  42. (void)tensor->data_sync();
  43. }
  44. }
  45. } // namespace
  46. std::map<std::string, py::object> PrimitivePy::hook_grad_;
  47. static ValuePtr PyArgToValue(const py::object &arg) {
  48. if (py::isinstance<SignatureEnumKind>(arg) &&
  49. py::cast<SignatureEnumKind>(arg) == SignatureEnumKind::kKindEmptyDefaultValue) {
  50. return nullptr;
  51. }
  52. return parse::data_converter::PyDataToValue(arg);
  53. }
  54. void PrimitivePy::set_signatures(
  55. std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> signatures) {
  56. signatures_.clear();
  57. for (auto &signature : signatures) {
  58. auto [name, rw, kind, arg_default, dtype] = signature;
  59. auto default_value = PyArgToValue(arg_default);
  60. signatures_.emplace_back(name, rw, kind, default_value, dtype);
  61. }
  62. set_has_signature(true);
  63. }
  64. py::function PrimitivePy::GetBpropFunction() {
  65. static const char *const get_bprop_func_name = "get_bprop";
  66. if (py::hasattr(python_obj_, get_bprop_func_name)) {
  67. py::function fn = python_obj_.attr(get_bprop_func_name)().cast<py::function>();
  68. return fn;
  69. } else {
  70. auto fn = GetBpropFunctionByObj(python_obj_);
  71. return fn;
  72. }
  73. }
  74. BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
  75. auto py_args = ConvertDatatoPyTuple(args);
  76. py::object obj;
  77. bool is_bprop = this->HasAttr(kBpropAttrName);
  78. if (is_bprop) {
  79. SyncData(py_args);
  80. obj = hook_(*py_args);
  81. return std::make_shared<PyObjectRef>(obj);
  82. }
  83. SyncData(py_args[2]);
  84. bool is_cell = this->HasAttr(kCellHookAttrName);
  85. if (is_cell) {
  86. auto cell_id = GetValue<std::string>(this->GetAttr(kCellIDAttrName));
  87. auto iter = hook_grad_.find(cell_id);
  88. if (iter != hook_grad_.end()) {
  89. auto hook_args = py::tuple(3);
  90. hook_args[0] = cell_id;
  91. hook_args[1] = py::make_tuple(iter->second);
  92. hook_args[2] = py::make_tuple(py_args[2]);
  93. obj = hook_(*hook_args);
  94. if (py::isinstance<py::none>(obj)) {
  95. obj = py_args[2];
  96. }
  97. hook_grad_.erase(cell_id);
  98. } else {
  99. hook_grad_[cell_id] = py_args[2];
  100. obj = py_args[2];
  101. }
  102. } else {
  103. // Hook operator for execute variable hook function
  104. obj = hook_(py::make_tuple(py_args[2]));
  105. if (py::isinstance<py::none>(obj)) {
  106. obj = py_args[2];
  107. }
  108. }
  109. obj = py::make_tuple(obj);
  110. return std::make_shared<PyObjectRef>(obj);
  111. }
  112. py::function PrimitivePy::GetComputeFunction() const {
  113. static const char *const compute_func_name = "vm_impl";
  114. if (py::hasattr(python_obj_, compute_func_name)) {
  115. MS_LOG(INFO) << name() << " compute_func_name";
  116. py::function fn = python_obj_.attr(compute_func_name).cast<py::function>();
  117. return fn;
  118. }
  119. static const std::string vm_module = "mindspore.ops.vm_impl_registry";
  120. static const std::string get_vm_impl_fn = "get_vm_impl_fn";
  121. MS_LOG(INFO) << name() << ": get_vm_impl_fn";
  122. py::function get_fn = parse::python_adapter::GetPyFn(vm_module, get_vm_impl_fn);
  123. py::function vm_fn = get_fn(python_obj_);
  124. if (py::isinstance<py::none>(vm_fn)) {
  125. MS_LOG(WARNING) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast<std::string>();
  126. vm_fn = mindspore::GetComputeFunction(Primitive::name());
  127. }
  128. return vm_fn;
  129. }
  130. void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) {
  131. std::string attr_name = name;
  132. ValuePtr converted_ret = nullptr;
  133. if (py::isinstance<py::module>(obj)) {
  134. MS_LOG(EXCEPTION) << "AddPyAttr failed, obj should not be py::module";
  135. }
  136. bool converted = parse::ConvertData(obj, &converted_ret);
  137. if (!converted) {
  138. MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj));
  139. }
  140. (void)this->AddAttr(attr_name, converted_ret);
  141. }
  142. py::dict PrimitivePy::GetAttrDict() {
  143. py::dict attr_dict;
  144. for (auto &attr : attrs_) {
  145. attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second);
  146. }
  147. return attr_dict;
  148. }
  149. void PrimitivePy::CopyHookFunction(const PrimitivePtr &primitive) {
  150. MS_EXCEPTION_IF_NULL(primitive);
  151. if (!primitive->isa<PrimitivePy>()) {
  152. MS_LOG(EXCEPTION) << "Cannot copy a primtive which is not python primitive hook function to python primitive!";
  153. }
  154. auto primitive_py = primitive->cast<PrimitivePyPtr>();
  155. MS_EXCEPTION_IF_NULL(primitive_py);
  156. this->set_hook(primitive_py->hook());
  157. }
  158. BaseRef PrimitivePy::RunComputeFunction(const VectorRef &args) const {
  159. auto py_args = ConvertDatatoPyTuple(args);
  160. auto result = this->RunPyComputeFunction(py_args);
  161. if (py::isinstance<py::none>(result)) {
  162. return std::make_shared<BaseRef>(nullptr);
  163. }
  164. return std::make_shared<PyObjectRef>(result);
  165. }
  166. py::object PrimitivePy::RunPyComputeFunction(const py::tuple &py_args) const {
  167. auto func = this->GetComputeFunction();
  168. if (py::isinstance<py::none>(func)) {
  169. return py::none();
  170. }
  171. auto result = func(*py_args);
  172. return result;
  173. }
  174. bool PrimitivePy::HasComputeFunction() const {
  175. auto func = GetComputeFunction();
  176. if (py::isinstance<py::none>(func)) {
  177. return false;
  178. }
  179. return true;
  180. }
  181. PrimitivePtr PrimitivePy::Clone() {
  182. auto clone_fn = python_obj_.attr("_clone");
  183. py::object new_obj = clone_fn();
  184. auto cloned_prim = new_obj.cast<PrimitivePyPtr>();
  185. return cloned_prim;
  186. }
  187. py::dict PrimitivePy::RunInfer(const py::tuple &args) {
  188. if (!HasPyObj()) {
  189. MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
  190. }
  191. auto infer_fuc = python_obj_.attr("__infer__");
  192. return infer_fuc(*args);
  193. }
  194. REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
  195. (void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
  196. .value("unknown", PrimType::kPrimTypeUnknown)
  197. .value("builtin", PrimType::kPrimTypeBuiltIn)
  198. .value("py_infer_shape", PrimType::kPrimTypePyInferShape)
  199. .value("user_custom", PrimType::kPrimTypeUserCustom);
  200. (void)py::class_<PrimitivePy, std::shared_ptr<PrimitivePy>>(*m, "Primitive_")
  201. .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_)
  202. .def(py::init<py::str &, py::object>())
  203. .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr")
  204. .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr")
  205. .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.")
  206. .def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.")
  207. .def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.")
  208. .def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name.");
  209. }));
  210. } // namespace mindspore