You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

primitive_py.cc 14 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "pybind_api/ir/primitive_py.h"
  17. #include <mutex>
  18. #include "ir/signature.h"
  19. #include "pipeline/jit/parse/data_converter.h"
  20. #include "pipeline/jit/parse/python_adapter.h"
  21. #include "pybind11/pytypes.h"
  22. #include "pybind_api/api_register.h"
  23. #include "pybind_api/export_flags.h"
  24. #include "pybind_api/ir/base_ref_py.h"
  25. #include "utils/convert_utils_base.h"
  26. #include "utils/convert_utils_py.h"
  27. #include "utils/ms_context.h"
  28. #include "utils/primitive_utils.h"
  29. namespace mindspore {
  30. namespace {
  31. constexpr auto kBpropAttrName = "bprop";
  32. constexpr auto kCellHookAttrName = "cell_hook";
  33. constexpr auto kCellIDAttrName = "cell_id";
  34. void SyncData(const py::object &arg) {
  35. if (py::isinstance<py::tuple>(arg)) {
  36. py::tuple arg_list = py::cast<py::tuple>(arg);
  37. for (size_t i = 0; i < arg_list.size(); i++) {
  38. SyncData(arg_list[i]);
  39. }
  40. }
  41. if (py::isinstance<tensor::Tensor>(arg)) {
  42. auto tensor = py::cast<tensor::TensorPtr>(arg);
  43. (void)tensor->data_sync();
  44. }
  45. }
  46. } // namespace
  // Definition of the static gradient cache shared by all PrimitivePy objects.
  // RunHookFunction stores an entry under a cell id when no entry exists yet
  // and erases it after the hook has been run for that id;
  // CheckHookConsistency clears the whole map before reporting a mismatch.
  47. std::map<std::string, py::object> PrimitivePy::hook_grad_;
  48. void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) {
  49. signatures_ = signatures;
  50. set_has_signature(true);
  51. }
  52. py::function PrimitivePy::GetBpropFunction() {
  53. static const char *const get_bprop_func_name = "get_bprop";
  54. if (py::hasattr(python_obj_, get_bprop_func_name)) {
  55. py::function fn = python_obj_.attr(get_bprop_func_name)().cast<py::function>();
  56. return fn;
  57. } else {
  58. auto fn = GetBpropFunctionByObj(python_obj_);
  59. return fn;
  60. }
  61. }
  62. py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args) {
  63. py::tuple grads;
  64. if (!py::isinstance<py::tuple>(grads_obj)) {
  65. grads = py::make_tuple(grads_obj);
  66. } else {
  67. grads = py::cast<py::tuple>(grads_obj);
  68. }
  69. if (grads.size() != py_args.size() - 2) {
  70. MS_EXCEPTION(ValueError) << "For user define net bprop, the gradients number: " << grads.size()
  71. << " is not equal to the args number: " << py_args.size() - 2 << ".";
  72. }
  73. if (MsContext::GetInstance()->get_param<bool>(MS_CTX_CHECK_BPROP_FLAG)) {
  74. for (size_t i = 0; i < grads.size(); i++) {
  75. if (py::isinstance<tensor::Tensor>(py_args[i])) {
  76. if (!py::isinstance<tensor::Tensor>(grads[i])) {
  77. MS_EXCEPTION(ValueError) << "When user defines the net bprop,, the gradient of the " << i
  78. << "th arg should be Tensor, but got "
  79. << py::cast<std::string>(grads[i].attr("__class__").attr("__name__"))
  80. << ", and the value is " << py::cast<py::str>(grads[i]) << ".";
  81. }
  82. py::object arg_dtype = py_args[i].attr("dtype");
  83. py::object grad_dtype = grads[i].attr("dtype");
  84. py::tuple arg_shape = py_args[i].attr("shape");
  85. py::tuple grad_shape = grads[i].attr("shape");
  86. if (!grad_dtype.equal(arg_dtype)) {
  87. MS_EXCEPTION(TypeError) << "When user defines the net bprop, the gradient of the " << i
  88. << "th arg should have the same dtype as the " << i << "th arg, but the " << i
  89. << "th arg dtype is: " << py::cast<py::str>(arg_dtype)
  90. << ", the gradient dtype is: " << py::cast<py::str>(grad_dtype) << ".";
  91. }
  92. if (!grad_shape.equal(arg_shape)) {
  93. MS_EXCEPTION(ValueError) << "When user defines the net bprop, the gradient of the " << i
  94. << "th arg should have the same shape as the " << i << "th arg, but the " << i
  95. << "th arg shape is: " << py::cast<py::str>(arg_shape)
  96. << ", the gradient shape is: " << py::cast<py::str>(grad_shape) << ".";
  97. }
  98. }
  99. }
  100. }
  101. return grads;
  102. }
  103. void PrimitivePy::ConvertCTensorToPyTensor(const py::tuple &input_args, py::tuple *convert_args) const {
  104. MS_EXCEPTION_IF_NULL(convert_args);
  105. if (input_args.size() != (*convert_args).size()) {
  106. MS_LOG(EXCEPTION) << "The size of input_args: " << input_args.size()
  107. << " should be equal to the size of convert_args: " << (*convert_args).size();
  108. }
  109. for (size_t i = 0; i < input_args.size(); ++i) {
  110. (*convert_args)[i] = py::isinstance<tensor::Tensor>(input_args[i])
  111. ? parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE,
  112. parse::PYTHON_MOD_CONVERT_TO_MS_TENSOR, input_args[i])
  113. : input_args[i];
  114. }
  115. }
  116. void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out) const {
  117. if (py::isinstance<py::tuple>(expected_grad_out)) {
  118. if (!py::isinstance<py::tuple>(grad_out)) {
  119. hook_grad_.clear();
  120. MS_EXCEPTION(TypeError) << "The output gradient should be a tuple!";
  121. }
  122. auto actual_out_tuple = py::cast<py::tuple>(grad_out);
  123. auto expected_out_tuple = py::cast<py::tuple>(expected_grad_out);
  124. if (actual_out_tuple.size() != expected_out_tuple.size()) {
  125. hook_grad_.clear();
  126. MS_EXCEPTION(ValueError) << "The tuple size of output gradient should be " << expected_out_tuple.size()
  127. << ", but it is " << actual_out_tuple.size();
  128. }
  129. for (size_t i = 0; i < expected_out_tuple.size(); ++i) {
  130. CheckHookConsistency(actual_out_tuple[i], expected_out_tuple[i]);
  131. }
  132. }
  133. if (py::isinstance<tensor::Tensor>(expected_grad_out)) {
  134. if (!py::isinstance<tensor::Tensor>(grad_out)) {
  135. hook_grad_.clear();
  136. MS_EXCEPTION(TypeError) << "The output gradient should be a tensor!";
  137. }
  138. auto actual_out_tensor = py::cast<tensor::TensorPtr>(grad_out);
  139. auto expected_out_tensor = py::cast<tensor::TensorPtr>(expected_grad_out);
  140. MS_EXCEPTION_IF_NULL(actual_out_tensor);
  141. MS_EXCEPTION_IF_NULL(expected_out_tensor);
  142. if (actual_out_tensor->GetShapeAndDataTypeInfo() != expected_out_tensor->GetShapeAndDataTypeInfo()) {
  143. hook_grad_.clear();
  144. MS_EXCEPTION(ValueError) << "The output gradient is not consistent with the expected, it should be "
  145. << expected_out_tensor->GetShapeAndDataTypeInfo() << ", but it is "
  146. << actual_out_tensor->GetShapeAndDataTypeInfo();
  147. }
  148. }
  149. }
  150. BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
  151. py::tuple py_args = ConvertDatatoPyTuple(args);
  152. bool is_bprop = this->HasAttr(kBpropAttrName);
  153. if (is_bprop) {
  154. SyncData(py_args);
  155. py::tuple convert_args(py_args.size());
  156. ConvertCTensorToPyTensor(py_args, &convert_args);
  157. py::object grads_obj = hook_(*convert_args);
  158. py::tuple grads = check_bprop_out(grads_obj, py_args);
  159. return std::make_shared<PyObjectRef>(grads);
  160. }
  161. SyncData(py_args[2]);
  162. bool is_cell = this->HasAttr(kCellHookAttrName);
  163. py::object obj;
  164. if (is_cell) {
  165. auto cell_id = GetValue<std::string>(this->GetAttr(kCellIDAttrName));
  166. auto iter = hook_grad_.find(cell_id);
  167. if (iter != hook_grad_.end()) {
  168. py::tuple convert_args(2);
  169. py::tuple input_args(2);
  170. input_args[0] = iter->second;
  171. input_args[1] = py_args[2];
  172. ConvertCTensorToPyTensor(input_args, &convert_args);
  173. auto hook_args = py::tuple(3);
  174. hook_args[0] = cell_id;
  175. hook_args[1] = py::make_tuple(convert_args[0]);
  176. hook_args[2] = py::make_tuple(convert_args[1]);
  177. obj = hook_(*hook_args);
  178. if (py::isinstance<py::none>(obj)) {
  179. obj = py_args[2];
  180. }
  181. CheckHookConsistency(obj, py_args[2]);
  182. hook_grad_.erase(cell_id);
  183. } else {
  184. hook_grad_[cell_id] = py_args[2];
  185. obj = py_args[2];
  186. }
  187. } else {
  188. // Hook operator for execute variable hook function
  189. obj = hook_(py::make_tuple(py_args[2]));
  190. if (py::isinstance<py::none>(obj)) {
  191. obj = py_args[2];
  192. }
  193. CheckHookConsistency(obj, py_args[2]);
  194. }
  195. obj = py::make_tuple(obj);
  196. return std::make_shared<PyObjectRef>(obj);
  197. }
  198. py::function PrimitivePy::GetComputeFunction() const {
  199. static const char *const compute_func_name = "vm_impl";
  200. if (py::hasattr(python_obj_, compute_func_name)) {
  201. MS_LOG(INFO) << name() << " compute_func_name";
  202. py::function fn = python_obj_.attr(compute_func_name).cast<py::function>();
  203. return fn;
  204. }
  205. static const std::string vm_module = "mindspore.ops.vm_impl_registry";
  206. static const std::string get_vm_impl_fn = "get_vm_impl_fn";
  207. MS_LOG(INFO) << name() << ": get_vm_impl_fn";
  208. py::function get_fn = parse::python_adapter::GetPyFn(vm_module, get_vm_impl_fn);
  209. py::function vm_fn = get_fn(python_obj_);
  210. if (py::isinstance<py::none>(vm_fn)) {
  211. MS_LOG(INFO) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast<std::string>();
  212. vm_fn = mindspore::GetComputeFunction(Primitive::name());
  213. }
  214. return vm_fn;
  215. }
  216. void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) {
  217. std::string attr_name = name;
  218. ValuePtr converted_ret = nullptr;
  219. if (py::isinstance<py::module>(obj)) {
  220. MS_LOG(EXCEPTION) << "AddPyAttr failed, obj should not be py::module";
  221. }
  222. bool converted = parse::ConvertData(obj, &converted_ret);
  223. if (!converted) {
  224. MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj));
  225. }
  226. (void)this->AddAttr(attr_name, converted_ret);
  227. }
  228. py::dict PrimitivePy::GetAttrDict() {
  229. py::dict attr_dict;
  230. for (auto &attr : attrs_) {
  231. attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second);
  232. }
  233. return attr_dict;
  234. }
  235. void PrimitivePy::CopyHookFunction(const PrimitivePtr &primitive) {
  236. MS_EXCEPTION_IF_NULL(primitive);
  237. if (!primitive->isa<PrimitivePy>()) {
  238. MS_LOG(EXCEPTION) << "Cannot copy a primtive which is not python primitive hook function to python primitive!";
  239. }
  240. auto primitive_py = primitive->cast<PrimitivePyPtr>();
  241. MS_EXCEPTION_IF_NULL(primitive_py);
  242. this->set_hook(primitive_py->hook());
  243. }
  244. BaseRef PrimitivePy::RunComputeFunction(const VectorRef &args) const {
  245. auto py_args = ConvertDatatoPyTuple(args);
  246. auto result = this->RunPyComputeFunction(py_args);
  247. if (py::isinstance<py::none>(result)) {
  248. return std::make_shared<BaseRef>(nullptr);
  249. }
  250. return std::make_shared<PyObjectRef>(result);
  251. }
  252. py::object PrimitivePy::RunPyComputeFunction(const py::tuple &py_args) const {
  253. auto func = this->GetComputeFunction();
  254. if (py::isinstance<py::none>(func)) {
  255. return py::none();
  256. }
  257. auto result = func(*py_args);
  258. return result;
  259. }
  260. bool PrimitivePy::HasComputeFunction() const {
  261. auto func = GetComputeFunction();
  262. return !py::isinstance<py::none>(func);
  263. }
  264. PrimitivePtr PrimitivePy::Clone() {
  265. auto clone_fn = python_obj_.attr("_clone");
  266. py::object new_obj = clone_fn();
  267. auto cloned_prim = new_obj.cast<PrimitivePyPtr>();
  268. return cloned_prim;
  269. }
  270. py::dict PrimitivePy::RunInfer(const py::tuple &args) {
  271. if (!HasPyObj()) {
  272. MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
  273. }
  274. auto infer_fuc = python_obj_.attr(PY_PRIM_METHOD_INFER);
  275. return infer_fuc(*args);
  276. }
  277. void PrimitivePy::RunCheck(const py::tuple &args) {
  278. if (!HasPyObj()) {
  279. MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
  280. }
  281. auto check_func = python_obj_.attr(PY_PRIM_METHOD_CHECK);
  282. (void)check_func(*args);
  283. }
  284. py::object PrimitivePy::RunInferValue(const py::tuple &args) {
  285. if (!HasPyObj()) {
  286. MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
  287. }
  288. auto infer_value = python_obj_.attr(PY_PRIM_METHOD_INFER_VALUE);
  289. return infer_value(*args);
  290. }
  // Python bindings registered under the key `Primitive_`.
  // The lambda receives the target module and defines two things on it:
  //   * the enum `prim_type`, exposing five PrimType values;
  //   * the class `Primitive_`, backed by PrimitivePy and held by
  //     std::shared_ptr, constructible from (py::str, py::object), with a
  //     read-only PYTHON_PRIMITIVE_FLAG field and the attribute, type,
  //     signature, hook and instance-name methods listed below.
  // The values returned by py::enum_ / py::class_ are intentionally discarded.
  291. REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
  292. (void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
  293. .value("unknown", PrimType::kPrimTypeUnknown)
  294. .value("builtin", PrimType::kPrimTypeBuiltIn)
  295. .value("py_infer_shape", PrimType::kPrimTypePyInferShape)
  296. .value("user_custom", PrimType::kPrimTypeUserCustom)
  297. .value("py_infer_check", PrimType::kPrimTypePyInferCheck);
  298. (void)py::class_<PrimitivePy, std::shared_ptr<PrimitivePy>>(*m, "Primitive_")
  299. .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_)
  300. .def(py::init<py::str &, py::object>())
  301. .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr")
  302. .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr")
  303. .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.")
  304. .def("set_const_prim", &PrimitivePy::set_const_prim, "Set primitive is const.")
  305. .def("set_const_input_indexes", &PrimitivePy::set_const_input_indexes,
  306. "Set primitive const input indexes.")
  307. .def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.")
  308. .def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.")
  309. .def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name.");
  310. }));
  311. } // namespace mindspore