You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

primitive_py.cc 26 kB

5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "pybind_api/ir/primitive_py.h"
#include <algorithm>
#include <map>
#include <mutex>
#include <utility>
#include "ir/signature.h"
#include "pipeline/jit/parse/data_converter.h"
#include "include/common/utils/python_adapter.h"
#include "pybind11/pytypes.h"
#include "pybind_api/api_register.h"
#include "pybind_api/export_flags.h"
#include "pybind_api/ir/base_ref_py.h"
#include "utils/convert_utils_base.h"
#include "include/common/utils/convert_utils_py.h"
#include "utils/ms_context.h"
#include "include/common/utils/primitive_utils.h"
#include "utils/check_convert_utils.h"
#include "pipeline/pynative/pynative_execute.h"
  33. namespace mindspore {
  34. namespace {
  35. constexpr auto kBpropAttrName = "bprop";
  36. constexpr auto kCellHookAttrName = "cell_hook";
  37. constexpr auto kCellIDAttrName = "cell_id";
  38. std::map<std::string, std::string> kOpAttrNameReplaceMap = {
  39. {"data_format", "format"},
  40. };
  41. void SyncData(const py::object &arg) {
  42. if (py::isinstance<py::tuple>(arg)) {
  43. py::tuple arg_list = py::cast<py::tuple>(arg);
  44. for (size_t i = 0; i < arg_list.size(); i++) {
  45. SyncData(arg_list[i]);
  46. }
  47. }
  48. if (py::isinstance<tensor::Tensor>(arg)) {
  49. auto tensor = py::cast<tensor::TensorPtr>(arg);
  50. tensor->data_sync();
  51. }
  52. }
  53. void ConvertCTensorToPyTensor(const py::tuple &input_args, py::tuple *convert_args) {
  54. MS_EXCEPTION_IF_NULL(convert_args);
  55. if (input_args.size() != (*convert_args).size()) {
  56. MS_LOG(EXCEPTION) << "The size of input_args: " << input_args.size()
  57. << " should be equal to the size of convert_args: " << (*convert_args).size();
  58. }
  59. for (size_t i = 0; i < input_args.size(); ++i) {
  60. if (py::isinstance<tensor::Tensor>(input_args[i])) {
  61. (*convert_args)[i] =
  62. python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_CONVERT_TO_MS_TENSOR, input_args[i]);
  63. } else if (py::isinstance<py::tuple>(input_args[i])) {
  64. auto tuple_inp_arg = py::cast<py::tuple>(input_args[i]);
  65. py::tuple convert_tuple_arg(tuple_inp_arg.size());
  66. ConvertCTensorToPyTensor(tuple_inp_arg, &convert_tuple_arg);
  67. (*convert_args)[i] = convert_tuple_arg;
  68. } else {
  69. (*convert_args)[i] = input_args[i];
  70. }
  71. }
  72. }
  73. py::tuple ConstructCellHookFnArgs(const std::string &cell_id, const py::object &grad_input,
  74. const py::object &grad_output) {
  75. constexpr size_t grad_input_index = 1;
  76. constexpr size_t grad_output_index = 2;
  77. constexpr size_t input_args_nums = 3;
  78. // Convert c++ object to python object.
  79. py::tuple c_grad_args(input_args_nums - 1);
  80. c_grad_args[0] = grad_input;
  81. c_grad_args[1] = grad_output;
  82. py::tuple py_grad_args(input_args_nums - 1);
  83. ConvertCTensorToPyTensor(c_grad_args, &py_grad_args);
  84. // Get tuple args of cell hook function.
  85. py::tuple hook_fn_args(input_args_nums);
  86. hook_fn_args[0] = cell_id;
  87. if (!py::isinstance<py::tuple>(py_grad_args[0])) {
  88. hook_fn_args[grad_input_index] = py::make_tuple(py_grad_args[0]);
  89. } else {
  90. hook_fn_args[grad_input_index] = py_grad_args[0];
  91. }
  92. if (!py::isinstance<py::tuple>(py_grad_args[1])) {
  93. hook_fn_args[grad_output_index] = py::make_tuple(py_grad_args[1]);
  94. } else {
  95. hook_fn_args[grad_output_index] = py_grad_args[1];
  96. }
  97. return hook_fn_args;
  98. }
  99. } // namespace
// Gradients stored by the first bprop_cut of a cell hook, keyed by cell id. RunCellHookFunction consumes
// and erases an entry on the second bprop_cut. The whole map is cleared by ClearHookRes and whenever
// CheckHookConsistency reports a mismatch.
std::map<std::string, py::object> PrimitivePy::hook_grad_;
// Creates a primitive named `name` that is not backed by a Python object (python_obj_ is None).
PrimitivePy::PrimitivePy(const std::string &name) : Primitive(name, false), python_obj_(py::none()) {}
// Creates a primitive bound to `python_obj`. It copies into this primitive the state accumulated on
// `adapter`: name, signatures, attributes, primitive type, const flags, const input indexes, backward
// hooks and instance name.
// NOTE(review): adapter is dereferenced in the member-initializer list without a null check; callers
// must pass a non-null adapter.
PrimitivePy::PrimitivePy(const py::object &python_obj, const PrimitivePyAdapterPtr &adapter)
    : Primitive(adapter->name_, false), python_obj_(python_obj), adapter_(adapter) {
  MS_LOG(DEBUG) << "New primitive:" << adapter->name_;
  set_signatures(adapter->signatures_);
  (void)Primitive::SetAttrs(adapter->attrs_);
  Primitive::set_prim_type(adapter->prim_type_);
  Primitive::set_const_prim(adapter->is_const_prim_);
  Primitive::set_const_input_indexes(adapter->const_input_indexes_);
  // Re-register hooks that were added on the Python side before this primitive existed.
  for (const auto &elem : adapter->backward_hook_fn_) {
    AddBackwardHookFn(elem.first, elem.second);
  }
  set_instance_name(adapter->instance_name_);
}
  115. PrimitivePy::~PrimitivePy() {}
  116. void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) {
  117. signatures_ = signatures;
  118. set_has_signature(!signatures.empty());
  119. }
  120. py::function PrimitivePy::GetBpropFunction() {
  121. static const char *const get_bprop_func_name = "get_bprop";
  122. if (py::hasattr(python_obj_, get_bprop_func_name)) {
  123. py::function fn = python_obj_.attr(get_bprop_func_name)().cast<py::function>();
  124. return fn;
  125. } else {
  126. auto fn = GetBpropFunctionByObj(python_obj_);
  127. return fn;
  128. }
  129. }
  130. py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args, const std::string &bprop_cls_name) {
  131. py::tuple grads;
  132. if (py::isinstance<py::none>(grads_obj)) {
  133. MS_EXCEPTION(TypeError) << "The 'grads_obj' is none.";
  134. } else if (!py::isinstance<py::tuple>(grads_obj)) {
  135. grads = py::make_tuple(grads_obj);
  136. } else {
  137. grads = py::cast<py::tuple>(grads_obj);
  138. }
  139. if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_CHECK_BPROP_FLAG)) {
  140. return grads;
  141. }
  142. constexpr int filter_args_size = 2;
  143. if (grads.size() != py_args.size() - filter_args_size) {
  144. MS_EXCEPTION(TypeError) << "For user defined method 'bprop' of net '" << bprop_cls_name
  145. << "', the number of return values(gradients) should be equal to the number of input "
  146. "arguments except 'out' and 'dout', which is: "
  147. << (py_args.size() - filter_args_size) << ", but got:" << grads.size() << ".";
  148. }
  149. for (size_t i = 0; i < grads.size(); i++) {
  150. if (py::isinstance<tensor::Tensor>(py_args[i])) {
  151. if (!py::isinstance<tensor::Tensor>(grads[i])) {
  152. MS_EXCEPTION(ValueError) << "For user defined method 'bprop' of net '" << bprop_cls_name << "', the " << i
  153. << "th return value(gradient of the " << i << "th argument) should be Tensor, but got "
  154. << py::cast<std::string>(grads[i].attr("__class__").attr("__name__"))
  155. << ", and the value is " << py::cast<py::str>(grads[i]) << ".";
  156. }
  157. py::object arg_dtype = py_args[i].attr("dtype");
  158. py::object grad_dtype = grads[i].attr("dtype");
  159. py::tuple arg_shape = py_args[i].attr("shape");
  160. py::tuple grad_shape = grads[i].attr("shape");
  161. if (!grad_dtype.equal(arg_dtype)) {
  162. MS_EXCEPTION(TypeError) << "For user defined method 'bprop' of net '" << bprop_cls_name << "', the " << i
  163. << "th return value(gradient of the " << i
  164. << "th argument) should have the same dtype as the " << i
  165. << "th argument, which is:" << py::cast<py::str>(arg_dtype)
  166. << ", but got: " << py::cast<py::str>(grad_dtype) << ".";
  167. }
  168. if (!grad_shape.equal(arg_shape)) {
  169. MS_EXCEPTION(ValueError) << "For user defined method 'bprop' of net '" << bprop_cls_name << "', the " << i
  170. << "th return value(gradient of the " << i
  171. << "th argument) should have the same shape as the " << i
  172. << "th argument, which is:" << py::cast<py::str>(arg_shape)
  173. << ", but got: " << py::cast<py::str>(grad_shape) << ".";
  174. }
  175. }
  176. }
  177. return grads;
  178. }
  179. void PrimitivePy::AddBpropCutPrim(const PrimitivePyPtr &bprop_cut_prim) {
  180. MS_EXCEPTION_IF_NULL(bprop_cut_prim);
  181. bprop_cut_prims_.emplace_back(bprop_cut_prim);
  182. }
  183. void PrimitivePy::AddBackwardHookFn(const int &key, const py::function &backward_hook_fn) {
  184. backward_hook_fn_[key] = backward_hook_fn;
  185. for (const auto &elem : bprop_cut_prims_) {
  186. PrimitivePyPtr bprop_cut_prim = elem.lock();
  187. if (bprop_cut_prim != nullptr) {
  188. bprop_cut_prim->AddBackwardHookFn(key, backward_hook_fn);
  189. }
  190. }
  191. }
  192. void PrimitivePy::RemoveBackwardHookFn(const int &key) {
  193. auto iter = backward_hook_fn_.find(key);
  194. if (iter != backward_hook_fn_.end()) {
  195. backward_hook_fn_.erase(key);
  196. }
  197. // Remove hook_fn for bprop cut prim on grad graph.
  198. for (const auto &elem : bprop_cut_prims_) {
  199. PrimitivePyPtr bprop_cut_prim = elem.lock();
  200. if (bprop_cut_prim != nullptr) {
  201. bprop_cut_prim->RemoveBackwardHookFn(key);
  202. }
  203. }
  204. }
  205. void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out,
  206. const py::object &code_obj, const py::object &co_name) const {
  207. if (py::isinstance<py::tuple>(expected_grad_out)) {
  208. if (!py::isinstance<py::tuple>(grad_out)) {
  209. hook_grad_.clear();
  210. MS_EXCEPTION(TypeError) << "The output gradient should be a tuple!";
  211. }
  212. auto actual_out_tuple = py::cast<py::tuple>(grad_out);
  213. auto expected_out_tuple = py::cast<py::tuple>(expected_grad_out);
  214. if (actual_out_tuple.size() != expected_out_tuple.size()) {
  215. hook_grad_.clear();
  216. MS_EXCEPTION(ValueError) << "The tuple size of output gradient should be " << expected_out_tuple.size()
  217. << ", but it is " << actual_out_tuple.size();
  218. }
  219. for (size_t i = 0; i < expected_out_tuple.size(); ++i) {
  220. CheckHookConsistency(actual_out_tuple[i], expected_out_tuple[i], code_obj, co_name);
  221. }
  222. }
  223. if (py::isinstance<tensor::Tensor>(expected_grad_out)) {
  224. if (!py::isinstance<tensor::Tensor>(grad_out)) {
  225. hook_grad_.clear();
  226. MS_EXCEPTION(TypeError) << "The output type of:" << py::str(co_name) << " should be a tensor but got "
  227. << py::cast<std::string>(grad_out.attr("__class__").attr("__name__")) << ".";
  228. }
  229. auto actual_out_tensor = py::cast<tensor::TensorPtr>(grad_out);
  230. auto expected_out_tensor = py::cast<tensor::TensorPtr>(expected_grad_out);
  231. MS_EXCEPTION_IF_NULL(actual_out_tensor);
  232. MS_EXCEPTION_IF_NULL(expected_out_tensor);
  233. if (actual_out_tensor->GetShapeAndDataTypeInfo() != expected_out_tensor->GetShapeAndDataTypeInfo()) {
  234. hook_grad_.clear();
  235. MS_EXCEPTION(ValueError) << "The output type of " << py::str(co_name)
  236. << " is not consistent with the expected, it should be "
  237. << expected_out_tensor->GetShapeAndDataTypeInfo() << ", but got "
  238. << actual_out_tensor->GetShapeAndDataTypeInfo();
  239. }
  240. }
  241. }
  242. BaseRef PrimitivePy::RunCellBpropFunction(const py::tuple &py_args) const {
  243. if (backward_hook_fn_.size() > 1) {
  244. MS_LOG(EXCEPTION) << "Multiple registration of bprop function is not supported.";
  245. }
  246. SyncData(py_args);
  247. py::tuple converted_args(py_args.size());
  248. ConvertCTensorToPyTensor(py_args, &converted_args);
  249. constexpr size_t non_inp_args_size = 2; // out and dout.
  250. auto inp_args_size = py_args.size() - non_inp_args_size;
  251. py::tuple input_args(inp_args_size);
  252. for (size_t i = 0; i < inp_args_size; ++i) {
  253. input_args[i] = py_args[i];
  254. }
  255. // Run bprop function.
  256. auto inst = pynative::PynativeExecutor::GetInstance();
  257. MS_EXCEPTION_IF_NULL(inst);
  258. try {
  259. MS_LOG(DEBUG) << "Run bprop function start.";
  260. py::tuple grads;
  261. for (const auto &elem : backward_hook_fn_) {
  262. inst->NewGraph(elem.second, input_args.cast<py::args>());
  263. py::object grads_obj = elem.second(*converted_args);
  264. grads = check_bprop_out(grads_obj, py_args, bprop_cls_name_);
  265. inst->EndGraph(elem.second, grads_obj, input_args.cast<py::args>());
  266. }
  267. MS_LOG(DEBUG) << "Run bprop function end.";
  268. return std::make_shared<PyObjectRef>(grads);
  269. } catch (std::exception &bt) {
  270. inst->ClearRes();
  271. std::rethrow_exception(std::current_exception());
  272. }
  273. }
// Runs the backward hooks registered on a cell. A cell hook is served by two bprop_cut calls that share
// this primitive's "cell_id" attribute:
//   - first call (no hook_grad_ entry for cell_id yet): stores the incoming gradient in hook_grad_ and
//     passes it through unchanged;
//   - second call: runs every registered hook as hook(cell_id, grad_input, grad_output), then erases the
//     stored entry. grad_input is the stored gradient; grad_output is the current gradient.
// A hook may return a replacement gradient; returning None keeps the current one. The replacement is
// checked with CheckHookConsistency against the gradient this call received. Hooks decorated with
// '@ms_function' are rejected.
//   py_args: the bprop_cut arguments; only the last element (the incoming gradient) is read here.
// Returns the resulting gradient, wrapped in a one-element tuple if it is not already a tuple.
BaseRef PrimitivePy::RunCellHookFunction(const py::tuple &py_args) const {
  // Get the gradient passed to current bprop cut op.
  // NOTE(review): assumes py_args is non-empty; args_size - 1 would wrap around for an empty tuple.
  const auto args_size = py_args.size();
  py::object grad_output = py_args[args_size - 1];
  // Get the cell id.
  auto cell_id = GetValue<std::string>(this->GetAttr(kCellIDAttrName));
  auto iter = hook_grad_.find(cell_id);
  if (iter != hook_grad_.end()) {
    // Second bprop_cut call for this cell: run the hooks on the cell's output gradient.
    for (const auto &elem : backward_hook_fn_) {
      py::object code_obj = py::getattr(elem.second, "__code__");
      py::object co_name = py::getattr(code_obj, "co_name");
      // Hooks wrapped by '@ms_function' report the code name "staging_specialize"; reject them.
      if (std::string(py::str(co_name)) == "staging_specialize") {
        py::object name_obj = py::getattr(elem.second, "__name__");
        MS_LOG(EXCEPTION) << "Decorating hook function " << py::str(name_obj)
                          << " with '@ms_function' is not supported.";
      }
      SyncData(grad_output);
      // iter->second is the gradient stored by the first bprop_cut call.
      py::tuple hook_fn_args = ConstructCellHookFnArgs(cell_id, iter->second, grad_output);
      py::object ret = elem.second(*hook_fn_args);
      // A non-None result replaces the gradient seen by later hooks and returned to the caller.
      if (!py::isinstance<py::none>(ret)) {
        grad_output = ret;
      }
      // Always validated against the gradient originally passed in, not the previous hook's result.
      CheckHookConsistency(grad_output, py_args[args_size - 1], code_obj, co_name);
    }
    // This cell's hook round is complete; drop the stored gradient.
    (void)hook_grad_.erase(cell_id);
  } else {
    // First bprop_cut call for this cell: remember the cell's input gradient for the second call.
    SyncData(grad_output);
    hook_grad_[cell_id] = grad_output;
  }
  if (!py::isinstance<py::tuple>(grad_output)) {
    grad_output = py::make_tuple(grad_output);
  }
  return std::make_shared<PyObjectRef>(grad_output);
}
  310. BaseRef PrimitivePy::RunVariableHookFunction(const py::tuple &py_args) const {
  311. constexpr size_t grad_output_index = 2;
  312. py::object grad_output = py_args[grad_output_index];
  313. for (const auto &elem : backward_hook_fn_) {
  314. py::object code_obj = py::getattr(elem.second, "__code__");
  315. py::object co_name = py::getattr(code_obj, "co_name");
  316. if (std::string(py::str(co_name)) == "staging_specialize") {
  317. py::object name_obj = py::getattr(elem.second, "__name__");
  318. MS_LOG(EXCEPTION) << "Decorating hook function " << py::str(name_obj) << " with '@ms_function' is not supported.";
  319. }
  320. SyncData(grad_output);
  321. py::object ret = elem.second(py::make_tuple(grad_output));
  322. if (!py::isinstance<py::none>(ret)) {
  323. grad_output = ret;
  324. }
  325. CheckHookConsistency(grad_output, py_args[grad_output_index], code_obj, co_name);
  326. }
  327. grad_output = py::make_tuple(grad_output);
  328. return std::make_shared<PyObjectRef>(grad_output);
  329. }
  330. BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
  331. py::tuple py_args = ConvertDatatoPyTuple(args);
  332. bool is_bprop = this->HasAttr(kBpropAttrName);
  333. if (is_bprop) {
  334. return RunCellBpropFunction(py_args);
  335. }
  336. bool is_cell = this->HasAttr(kCellHookAttrName);
  337. if (is_cell) {
  338. return RunCellHookFunction(py_args);
  339. }
  340. return RunVariableHookFunction(py_args);
  341. }
// Returns the function used to compute this primitive, in this order of preference:
//   1. the Python object's own 'vm_impl' attribute, if present;
//   2. the implementation returned by mindspore.ops.vm_impl_registry.get_vm_impl_fn(python_obj_);
//   3. if that lookup yields None, mindspore::GetComputeFunction(name()).
// NOTE(review): vm_fn is declared as py::function. Recent pybind11 releases type-check implicit
// object -> py::function conversions (callable check), so a None result from get_fn might throw before
// the is-None test below runs. Confirm against the pybind11 version in use; HasComputeFunction relies on
// a None result here.
py::function PrimitivePy::GetComputeFunction() const {
  static const char *const compute_func_name = "vm_impl";
  if (py::hasattr(python_obj_, compute_func_name)) {
    MS_LOG(DEBUG) << name() << " compute_func_name";
    py::function fn = python_obj_.attr(compute_func_name).cast<py::function>();
    return fn;
  }
  static const std::string vm_module = "mindspore.ops.vm_impl_registry";
  static const std::string get_vm_impl_fn = "get_vm_impl_fn";
  MS_LOG(DEBUG) << name() << ": get_vm_impl_fn";
  py::function get_fn = python_adapter::GetPyFn(vm_module, get_vm_impl_fn);
  py::function vm_fn = get_fn(python_obj_);
  if (py::isinstance<py::none>(vm_fn)) {
    MS_LOG(DEBUG) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast<std::string>();
    // Fall back to the compute function looked up by primitive name (namespace-level helper).
    vm_fn = mindspore::GetComputeFunction(Primitive::name());
  }
  return vm_fn;
}
  360. py::dict PrimitivePy::GetAttrDict() {
  361. py::dict attr_dict;
  362. for (auto &attr : attrs_) {
  363. attr_dict[py::str(attr.first)] = ValueToPyData(attr.second);
  364. }
  365. return attr_dict;
  366. }
  367. void PrimitivePy::CopyHookFunction(const PrimitivePyPtr &primitive_py) {
  368. MS_EXCEPTION_IF_NULL(primitive_py);
  369. const auto &backward_hook_fn = primitive_py->backward_hook_fn();
  370. for (const auto &elem : backward_hook_fn) {
  371. AddBackwardHookFn(elem.first, elem.second);
  372. }
  373. if (primitive_py->HasAttr(kBpropAttrName)) {
  374. set_bprop_cls_name(primitive_py->bprop_cls_name_);
  375. (void)this->AddAttr(kBpropAttrName, primitive_py->GetAttr(kBpropAttrName));
  376. }
  377. }
  378. BaseRef PrimitivePy::RunComputeFunction(const VectorRef &args) const {
  379. auto py_args = ConvertDatatoPyTuple(args);
  380. auto result = this->RunPyComputeFunction(py_args);
  381. if (py::isinstance<py::none>(result)) {
  382. return std::make_shared<BaseRef>(nullptr);
  383. }
  384. return std::make_shared<PyObjectRef>(result);
  385. }
  386. py::object PrimitivePy::RunPyComputeFunction(const py::tuple &py_args) const {
  387. auto func = this->GetComputeFunction();
  388. if (py::isinstance<py::none>(func)) {
  389. return py::none();
  390. }
  391. auto result = func(*py_args);
  392. return result;
  393. }
  394. bool PrimitivePy::HasComputeFunction() const {
  395. auto func = GetComputeFunction();
  396. return !py::isinstance<py::none>(func);
  397. }
  398. PrimitivePtr PrimitivePy::Clone() {
  399. auto clone_fn = python_obj_.attr("_clone");
  400. py::object obj_adapter = clone_fn();
  401. auto prim_adapter = obj_adapter.cast<PrimitivePyAdapterPtr>();
  402. auto prim = std::make_shared<PrimitivePy>(obj_adapter, prim_adapter);
  403. prim_adapter->set_attached_primitive(prim);
  404. return prim;
  405. }
  406. py::dict PrimitivePy::RunInfer(const py::tuple &args) {
  407. if (!HasPyObj()) {
  408. MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
  409. }
  410. // Python obj could be replaced as None, so it will losed the original info when throw exception in python.
  411. if (!py::hasattr(python_obj_, PY_PRIM_METHOD_INFER)) {
  412. MS_LOG(EXCEPTION) << "prim:" << ToString() << " has no attr:" << PY_PRIM_METHOD_INFER;
  413. }
  414. auto infer_fuc = python_obj_.attr(PY_PRIM_METHOD_INFER);
  415. return infer_fuc(*args);
  416. }
  417. void PrimitivePy::RunCheck(const py::tuple &args) {
  418. if (!HasPyObj()) {
  419. MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
  420. }
  421. // Python obj could be replaced as None, so it will losed the original info when throw exception in python.
  422. if (!py::hasattr(python_obj_, PY_PRIM_METHOD_CHECK)) {
  423. MS_LOG(EXCEPTION) << "prim:" << ToString() << " has no attr:" << PY_PRIM_METHOD_CHECK;
  424. }
  425. auto check_func = python_obj_.attr(PY_PRIM_METHOD_CHECK);
  426. (void)check_func(*args);
  427. }
  428. py::object PrimitivePy::RunInferValue(const py::tuple &args) {
  429. if (!HasPyObj()) {
  430. MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
  431. }
  432. // Python obj could be replaced as None, so it will losed the original info when throw exception in python.
  433. if (!py::hasattr(python_obj_, PY_PRIM_METHOD_INFER_VALUE)) {
  434. MS_LOG(EXCEPTION) << "prim:" << ToString() << " has no attr:" << PY_PRIM_METHOD_INFER_VALUE;
  435. }
  436. auto infer_value = python_obj_.attr(PY_PRIM_METHOD_INFER_VALUE);
  437. return infer_value(*args);
  438. }
  439. void PrimitivePy::ClearHookRes() { hook_grad_.clear(); }
// Creates an adapter for a primitive named `name`. Attributes, flags, signatures and hooks are added
// later through the adapter's setters.
PrimitivePyAdapter::PrimitivePyAdapter(const py::str &name) : name_(name) {}
  441. void PrimitivePyAdapter::AddPyAttr(const py::str &name, const py::object &obj) {
  442. std::string attr_name = name;
  443. ValuePtr converted_ret = nullptr;
  444. if (py::isinstance<py::module>(obj)) {
  445. MS_LOG(EXCEPTION) << "Call 'add_attr' to add attribute to primitive failed,"
  446. << " not support py::module to be attribute value; primitive name: " << this->name_
  447. << ", attribute name: " << attr_name << " attribute value: " << py::str(obj);
  448. }
  449. bool converted = parse::ConvertData(obj, &converted_ret);
  450. if (!converted) {
  451. MS_LOG(EXCEPTION) << "Call 'add_attr' to add attribute to primitive failed,"
  452. << " convert python obj to MindSpore obj failed; primitive name: " << this->name_
  453. << ", attribute name:" << attr_name << ", attribute value:" << py::str(obj)
  454. << ", attribute type:" << py::cast<std::string>(obj.attr("__class__").attr("__name__"));
  455. }
  456. if (kOpAttrNameReplaceMap.find(attr_name) != kOpAttrNameReplaceMap.end()) {
  457. attr_name = kOpAttrNameReplaceMap[attr_name];
  458. }
  459. (void)CheckAndConvertUtils::ConvertAttrValueToInt(this->name_, name, &converted_ret);
  460. if (attr_name == "primitive_target") {
  461. MS_EXCEPTION_IF_NULL(converted_ret);
  462. if (!converted_ret->isa<StringImm>()) {
  463. MS_LOG(EXCEPTION) << "Call 'add_attr' to add attribute to primitive '" << this->name_
  464. << "' failed, value of attribute 'primitive_target' must be CPU|GPU|Ascend but got "
  465. << py::str(obj);
  466. }
  467. auto target = GetValue<std::string>(converted_ret);
  468. if (!target.empty() && target != kCPUDevice && target != kGPUDevice && target != kAscendDevice &&
  469. target != "Device") {
  470. MS_LOG(EXCEPTION) << "Call 'add_attr' to add attribute to primitive '" << this->name_
  471. << "' failed, value of attribute 'primitive_target' must be CPU|GPU|Ascend but got "
  472. << py::str(obj);
  473. }
  474. }
  475. attrs_[attr_name] = converted_ret;
  476. auto prim = attached_primitive_.lock();
  477. if (prim != nullptr) {
  478. (void)prim->AddAttr(attr_name, converted_ret);
  479. }
  480. }
  481. void PrimitivePyAdapter::DelPyAttr(const py::str &name) {
  482. (void)attrs_.erase(name);
  483. auto prim = attached_primitive_.lock();
  484. if (prim != nullptr) {
  485. (void)prim->DelAttr(name);
  486. }
  487. }
  488. py::dict PrimitivePyAdapter::GetAttrDict() {
  489. auto prim = attached_primitive_.lock();
  490. if (prim != nullptr) {
  491. return prim->GetAttrDict();
  492. }
  493. py::dict attr_dict;
  494. for (auto &attr : attrs_) {
  495. attr_dict[py::str(attr.first)] = ValueToPyData(attr.second);
  496. }
  497. return attr_dict;
  498. }
  499. void PrimitivePyAdapter::set_prim_type(const PrimType t) {
  500. prim_type_ = t;
  501. auto prim = attached_primitive_.lock();
  502. if (prim != nullptr) {
  503. prim->set_prim_type(t);
  504. }
  505. }
  506. void PrimitivePyAdapter::set_const_prim(bool is_const_prim) {
  507. is_const_prim_ = is_const_prim;
  508. auto prim = attached_primitive_.lock();
  509. if (prim != nullptr) {
  510. prim->set_const_prim(is_const_prim);
  511. }
  512. }
  513. void PrimitivePyAdapter::set_const_input_indexes(const std::vector<size_t> &const_input_indexes) {
  514. const_input_indexes_ = const_input_indexes;
  515. auto prim = attached_primitive_.lock();
  516. if (prim != nullptr) {
  517. prim->set_const_input_indexes(const_input_indexes);
  518. }
  519. }
  520. void PrimitivePyAdapter::set_signatures(const std::vector<Signature> &signatures) {
  521. signatures_ = signatures;
  522. auto prim = attached_primitive_.lock();
  523. if (prim != nullptr) {
  524. prim->set_signatures(signatures);
  525. }
  526. }
  527. int PrimitivePyAdapter::AddBackwardHookFn(const py::function &backward_hook_fn) {
  528. ++backward_hook_fn_key_;
  529. backward_hook_fn_[backward_hook_fn_key_] = backward_hook_fn;
  530. auto prim = attached_primitive_.lock();
  531. if (prim != nullptr) {
  532. prim->AddBackwardHookFn(backward_hook_fn_key_, backward_hook_fn);
  533. }
  534. return backward_hook_fn_key_;
  535. }
  536. void PrimitivePyAdapter::RemoveBackwardHookFn(int key) {
  537. auto iter = backward_hook_fn_.find(key);
  538. if (iter != backward_hook_fn_.end()) {
  539. backward_hook_fn_.erase(iter);
  540. }
  541. auto prim = attached_primitive_.lock();
  542. if (prim != nullptr) {
  543. prim->RemoveBackwardHookFn(key);
  544. }
  545. }
  546. void PrimitivePyAdapter::set_instance_name(const std::string &s) {
  547. instance_name_ = s;
  548. auto prim = attached_primitive_.lock();
  549. if (prim != nullptr) {
  550. prim->set_instance_name(s);
  551. }
  552. }
  553. void PrimitivePyAdapter::set_attached_primitive(const PrimitivePyPtr &prim) {
  554. if (attached_primitive_.lock() != nullptr) {
  555. MS_LOG(EXCEPTION) << "PrimitivePyAdapter can't attach to multi Primitive.";
  556. }
  557. MS_EXCEPTION_IF_NULL(prim);
  558. attached_primitive_ = prim;
  559. }
// Registers this file's Python bindings in module `m`:
//   - enum `prim_type`, mirroring PrimType;
//   - class `Primitive_`, which exposes PrimitivePyAdapter (held by std::shared_ptr) to Python. Its setters
//     update the adapter and forward the change to the attached PrimitivePy, if any.
REGISTER_PYBIND_DEFINE(
  Primitive_, ([](const py::module *m) {
    (void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
      .value("unknown", PrimType::kPrimTypeUnknown)
      .value("builtin", PrimType::kPrimTypeBuiltIn)
      .value("py_infer_shape", PrimType::kPrimTypePyInfer)
      .value("user_custom", PrimType::kPrimTypeUserCustom)
      .value("py_infer_check", PrimType::kPrimTypePyCheck);
    (void)py::class_<PrimitivePyAdapter, std::shared_ptr<PrimitivePyAdapter>>(*m, "Primitive_")
      // Read-only attribute named by PYTHON_PRIMITIVE_FLAG, backed by PrimitivePyAdapter::parse_info_.
      .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePyAdapter::parse_info_)
      .def(py::init<py::str &>())
      .def("add_attr", &PrimitivePyAdapter::AddPyAttr, "add primitive attr")
      .def("del_attr", &PrimitivePyAdapter::DelPyAttr, "del primitive attr")
      .def("get_attr_dict", &PrimitivePyAdapter::GetAttrDict, "get primitive attr")
      .def("set_prim_type", &PrimitivePyAdapter::set_prim_type, "Set primitive type.")
      .def("set_const_prim", &PrimitivePyAdapter::set_const_prim, "Set primitive is const.")
      .def("set_const_input_indexes", &PrimitivePyAdapter::set_const_input_indexes,
           "Set primitive const input indexes.")
      .def("set_signatures", &PrimitivePyAdapter::set_signatures, "Set primitive inputs signature.")
      // Returns the integer key that remove_backward_hook_fn expects.
      .def("add_backward_hook_fn", &PrimitivePyAdapter::AddBackwardHookFn, "Add primitive backward hook function.")
      .def("remove_backward_hook_fn", &PrimitivePyAdapter::RemoveBackwardHookFn,
           "Remove primitive backward hook function.")
      .def("set_instance_name", &PrimitivePyAdapter::set_instance_name, "Set primitive instance name.");
  }));
  584. } // namespace mindspore