You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

primitive_py.cc 21 kB

5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "pybind_api/ir/primitive_py.h"
  17. #include <mutex>
  18. #include <map>
  19. #include <utility>
  20. #include "ir/signature.h"
  21. #include "pipeline/jit/parse/data_converter.h"
  22. #include "pipeline/jit/parse/python_adapter.h"
  23. #include "pybind11/pytypes.h"
  24. #include "pybind_api/api_register.h"
  25. #include "pybind_api/export_flags.h"
  26. #include "pybind_api/ir/base_ref_py.h"
  27. #include "utils/convert_utils_base.h"
  28. #include "utils/convert_utils_py.h"
  29. #include "utils/ms_context.h"
  30. #include "utils/primitive_utils.h"
  31. #include "utils/check_convert_utils.h"
  32. #include "pipeline/jit/resource.h"
  33. #include "pipeline/pynative/pynative_execute.h"
  34. namespace mindspore {
  35. namespace {
  // Attribute names that RunHookFunction inspects to decide which hook implementation to run.
  constexpr const char *kBpropAttrName{"bprop"};
  constexpr const char *kCellHookAttrName{"cell_hook"};
  // Attribute whose string value keys PrimitivePy::hook_grad_ in RunCellHookFunction.
  constexpr const char *kCellIDAttrName{"cell_id"};
  // Python attribute names that AddPyAttr stores under a different C++ attribute name.
  std::map<std::string, std::string> kOpAttrNameReplaceMap{{"data_format", "format"}};
  42. void SyncData(const py::object &arg) {
  43. if (py::isinstance<py::tuple>(arg)) {
  44. py::tuple arg_list = py::cast<py::tuple>(arg);
  45. for (size_t i = 0; i < arg_list.size(); i++) {
  46. SyncData(arg_list[i]);
  47. }
  48. }
  49. if (py::isinstance<tensor::Tensor>(arg)) {
  50. auto tensor = py::cast<tensor::TensorPtr>(arg);
  51. tensor->data_sync();
  52. }
  53. }
  54. } // namespace
  55. std::map<std::string, py::object> PrimitivePy::hook_grad_;
  56. PrimitivePy::PrimitivePy(const std::string &name) : Primitive(name, false), python_obj_(py::none()) {}
  57. PrimitivePy::PrimitivePy(const py::object &python_obj, const PrimitivePyAdapterPtr &adapter)
  58. : Primitive(adapter->name_, false), python_obj_(python_obj), adapter_(adapter) {
  59. MS_LOG(DEBUG) << "New primitive:" << adapter->name_;
  60. set_signatures(adapter->signatures_);
  61. (void)Primitive::SetAttrs(adapter->attrs_);
  62. Primitive::set_prim_type(adapter->prim_type_);
  63. Primitive::set_const_prim(adapter->is_const_prim_);
  64. Primitive::set_const_input_indexes(adapter->const_input_indexes_);
  65. set_hook(adapter->hook_);
  66. set_instance_name(adapter->instance_name_);
  67. }
  68. PrimitivePy::~PrimitivePy() { MS_LOG(DEBUG) << "Release:" << ToString(); }
  69. void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) {
  70. signatures_ = signatures;
  71. set_has_signature(!signatures.empty());
  72. }
  73. py::function PrimitivePy::GetBpropFunction() {
  74. static const char *const get_bprop_func_name = "get_bprop";
  75. if (py::hasattr(python_obj_, get_bprop_func_name)) {
  76. py::function fn = python_obj_.attr(get_bprop_func_name)().cast<py::function>();
  77. return fn;
  78. } else {
  79. auto fn = GetBpropFunctionByObj(python_obj_);
  80. return fn;
  81. }
  82. }
  83. py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args) {
  84. py::tuple grads;
  85. if (!py::isinstance<py::tuple>(grads_obj)) {
  86. grads = py::make_tuple(grads_obj);
  87. } else {
  88. grads = py::cast<py::tuple>(grads_obj);
  89. }
  90. constexpr int filter_args_size = 2;
  91. if (grads.size() != py_args.size() - filter_args_size) {
  92. MS_EXCEPTION(TypeError) << "For user define net bprop, the gradients number: " << grads.size()
  93. << " is not equal to the args number: " << (py_args.size() - filter_args_size) << ".";
  94. }
  95. if (MsContext::GetInstance()->get_param<bool>(MS_CTX_CHECK_BPROP_FLAG)) {
  96. for (size_t i = 0; i < grads.size(); i++) {
  97. if (py::isinstance<tensor::Tensor>(py_args[i])) {
  98. if (!py::isinstance<tensor::Tensor>(grads[i])) {
  99. MS_EXCEPTION(ValueError) << "When user defines the net bprop,, the gradient of the " << i
  100. << "th arg should be Tensor, but got "
  101. << py::cast<std::string>(grads[i].attr("__class__").attr("__name__"))
  102. << ", and the value is " << py::cast<py::str>(grads[i]) << ".";
  103. }
  104. py::object arg_dtype = py_args[i].attr("dtype");
  105. py::object grad_dtype = grads[i].attr("dtype");
  106. py::tuple arg_shape = py_args[i].attr("shape");
  107. py::tuple grad_shape = grads[i].attr("shape");
  108. if (!grad_dtype.equal(arg_dtype)) {
  109. MS_EXCEPTION(TypeError) << "When user defines the net bprop, the gradient of the " << i
  110. << "th arg should have the same dtype as the " << i << "th arg, but the " << i
  111. << "th arg dtype is: " << py::cast<py::str>(arg_dtype)
  112. << ", the gradient dtype is: " << py::cast<py::str>(grad_dtype) << ".";
  113. }
  114. if (!grad_shape.equal(arg_shape)) {
  115. MS_EXCEPTION(ValueError) << "When user defines the net bprop, the gradient of the " << i
  116. << "th arg should have the same shape as the " << i << "th arg, but the " << i
  117. << "th arg shape is: " << py::cast<py::str>(arg_shape)
  118. << ", the gradient shape is: " << py::cast<py::str>(grad_shape) << ".";
  119. }
  120. }
  121. }
  122. }
  123. return grads;
  124. }
  125. void PrimitivePy::ConvertCTensorToPyTensor(const py::tuple &input_args, py::tuple *convert_args) const {
  126. MS_EXCEPTION_IF_NULL(convert_args);
  127. if (input_args.size() != (*convert_args).size()) {
  128. MS_LOG(EXCEPTION) << "The size of input_args: " << input_args.size()
  129. << " should be equal to the size of convert_args: " << (*convert_args).size();
  130. }
  131. for (size_t i = 0; i < input_args.size(); ++i) {
  132. (*convert_args)[i] = py::isinstance<tensor::Tensor>(input_args[i])
  133. ? parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE,
  134. parse::PYTHON_MOD_CONVERT_TO_MS_TENSOR, input_args[i])
  135. : input_args[i];
  136. }
  137. }
  138. void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out) const {
  139. if (py::isinstance<py::tuple>(expected_grad_out)) {
  140. if (!py::isinstance<py::tuple>(grad_out)) {
  141. hook_grad_.clear();
  142. MS_EXCEPTION(TypeError) << "The output gradient should be a tuple!";
  143. }
  144. auto actual_out_tuple = py::cast<py::tuple>(grad_out);
  145. auto expected_out_tuple = py::cast<py::tuple>(expected_grad_out);
  146. if (actual_out_tuple.size() != expected_out_tuple.size()) {
  147. hook_grad_.clear();
  148. MS_EXCEPTION(ValueError) << "The tuple size of output gradient should be " << expected_out_tuple.size()
  149. << ", but it is " << actual_out_tuple.size();
  150. }
  151. for (size_t i = 0; i < expected_out_tuple.size(); ++i) {
  152. CheckHookConsistency(actual_out_tuple[i], expected_out_tuple[i]);
  153. }
  154. }
  155. if (py::isinstance<tensor::Tensor>(expected_grad_out)) {
  156. if (!py::isinstance<tensor::Tensor>(grad_out)) {
  157. hook_grad_.clear();
  158. MS_EXCEPTION(TypeError) << "The output gradient should be a tensor!";
  159. }
  160. auto actual_out_tensor = py::cast<tensor::TensorPtr>(grad_out);
  161. auto expected_out_tensor = py::cast<tensor::TensorPtr>(expected_grad_out);
  162. MS_EXCEPTION_IF_NULL(actual_out_tensor);
  163. MS_EXCEPTION_IF_NULL(expected_out_tensor);
  164. if (actual_out_tensor->GetShapeAndDataTypeInfo() != expected_out_tensor->GetShapeAndDataTypeInfo()) {
  165. hook_grad_.clear();
  166. MS_EXCEPTION(ValueError) << "The output gradient is not consistent with the expected, it should be "
  167. << expected_out_tensor->GetShapeAndDataTypeInfo() << ", but it is "
  168. << actual_out_tensor->GetShapeAndDataTypeInfo();
  169. }
  170. }
  171. }
  172. BaseRef PrimitivePy::RunCellBpropFunction(const py::tuple &py_args) const {
  173. SyncData(py_args);
  174. auto size = py_args.size();
  175. constexpr size_t grad_param_nums = 2;
  176. py::tuple input_args(size - grad_param_nums);
  177. for (size_t i = 0; i < size - grad_param_nums; ++i) {
  178. input_args[i] = py_args[i];
  179. }
  180. py::tuple convert_args(py_args.size());
  181. ConvertCTensorToPyTensor(py_args, &convert_args);
  182. auto inst = pynative::PynativeExecutor::GetInstance();
  183. MS_EXCEPTION_IF_NULL(inst);
  184. try {
  185. MS_LOG(DEBUG) << "Run bprop function start";
  186. inst->NewGraph(hook_, input_args.cast<py::args>());
  187. py::object grads_obj = hook_(*convert_args);
  188. py::tuple grads = check_bprop_out(grads_obj, py_args);
  189. inst->EndGraph(hook_, grads_obj, input_args.cast<py::args>());
  190. MS_LOG(DEBUG) << "Run bprop function end";
  191. return std::make_shared<PyObjectRef>(grads);
  192. } catch (std::exception &bt) {
  193. inst->ClearRes();
  194. std::rethrow_exception(std::current_exception());
  195. }
  196. }
  197. BaseRef PrimitivePy::RunCellHookFunction(const py::tuple &py_args) const {
  198. constexpr size_t grad_input_index = 1;
  199. constexpr size_t grad_output_index = 2;
  200. constexpr size_t input_param_nums = 3;
  201. SyncData(py_args[grad_output_index]);
  202. py::object obj;
  203. auto cell_id = GetValue<std::string>(this->GetAttr(kCellIDAttrName));
  204. auto iter = hook_grad_.find(cell_id);
  205. if (iter != hook_grad_.end()) {
  206. py::tuple convert_args(input_param_nums - 1);
  207. py::tuple input_args(input_param_nums - 1);
  208. input_args[0] = iter->second;
  209. input_args[1] = py_args[grad_output_index];
  210. ConvertCTensorToPyTensor(input_args, &convert_args);
  211. auto hook_args = py::tuple(input_param_nums);
  212. hook_args[0] = cell_id;
  213. hook_args[grad_input_index] = py::make_tuple(convert_args[0]);
  214. hook_args[grad_output_index] = py::make_tuple(convert_args[1]);
  215. obj = hook_(*hook_args);
  216. if (py::isinstance<py::none>(obj)) {
  217. obj = py_args[grad_output_index];
  218. }
  219. CheckHookConsistency(obj, py_args[grad_output_index]);
  220. (void)hook_grad_.erase(cell_id);
  221. } else {
  222. hook_grad_[cell_id] = py_args[grad_output_index];
  223. obj = py_args[grad_output_index];
  224. }
  225. obj = py::make_tuple(obj);
  226. return std::make_shared<PyObjectRef>(obj);
  227. }
  228. BaseRef PrimitivePy::RunVariableHookFunction(const py::tuple &py_args) const {
  229. constexpr size_t grad_output_index = 2;
  230. SyncData(py_args[grad_output_index]);
  231. py::object obj = hook_(py::make_tuple(py_args[grad_output_index]));
  232. if (py::isinstance<py::none>(obj)) {
  233. obj = py_args[grad_output_index];
  234. }
  235. CheckHookConsistency(obj, py_args[grad_output_index]);
  236. obj = py::make_tuple(obj);
  237. return std::make_shared<PyObjectRef>(obj);
  238. }
  239. BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
  240. py::tuple py_args = ConvertDatatoPyTuple(args);
  241. bool is_bprop = this->HasAttr(kBpropAttrName);
  242. if (is_bprop) {
  243. return RunCellBpropFunction(py_args);
  244. }
  245. bool is_cell = this->HasAttr(kCellHookAttrName);
  246. if (is_cell) {
  247. return RunCellHookFunction(py_args);
  248. }
  249. return RunVariableHookFunction(py_args);
  250. }
  251. py::function PrimitivePy::GetComputeFunction() const {
  252. static const char *const compute_func_name = "vm_impl";
  253. if (py::hasattr(python_obj_, compute_func_name)) {
  254. MS_LOG(DEBUG) << name() << " compute_func_name";
  255. py::function fn = python_obj_.attr(compute_func_name).cast<py::function>();
  256. return fn;
  257. }
  258. static const std::string vm_module = "mindspore.ops.vm_impl_registry";
  259. static const std::string get_vm_impl_fn = "get_vm_impl_fn";
  260. MS_LOG(DEBUG) << name() << ": get_vm_impl_fn";
  261. py::function get_fn = parse::python_adapter::GetPyFn(vm_module, get_vm_impl_fn);
  262. py::function vm_fn = get_fn(python_obj_);
  263. if (py::isinstance<py::none>(vm_fn)) {
  264. MS_LOG(DEBUG) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast<std::string>();
  265. vm_fn = mindspore::GetComputeFunction(Primitive::name());
  266. }
  267. return vm_fn;
  268. }
  269. py::dict PrimitivePy::GetAttrDict() {
  270. py::dict attr_dict;
  271. for (auto &attr : attrs_) {
  272. attr_dict[py::str(attr.first)] = ValueToPyData(attr.second);
  273. }
  274. return attr_dict;
  275. }
  276. void PrimitivePy::CopyHookFunction(const PrimitivePtr &primitive) {
  277. MS_EXCEPTION_IF_NULL(primitive);
  278. if (!primitive->isa<PrimitivePy>()) {
  279. MS_LOG(EXCEPTION) << "Cannot copy a primtive which is not python primitive hook function to python primitive!";
  280. }
  281. auto primitive_py = primitive->cast<PrimitivePyPtr>();
  282. MS_EXCEPTION_IF_NULL(primitive_py);
  283. this->set_hook(primitive_py->hook());
  284. if (primitive_py->HasAttr(kBpropAttrName)) {
  285. (void)this->AddAttr(kBpropAttrName, primitive_py->GetAttr(kBpropAttrName));
  286. }
  287. }
  288. BaseRef PrimitivePy::RunComputeFunction(const VectorRef &args) const {
  289. auto py_args = ConvertDatatoPyTuple(args);
  290. auto result = this->RunPyComputeFunction(py_args);
  291. if (py::isinstance<py::none>(result)) {
  292. return std::make_shared<BaseRef>(nullptr);
  293. }
  294. return std::make_shared<PyObjectRef>(result);
  295. }
  296. py::object PrimitivePy::RunPyComputeFunction(const py::tuple &py_args) const {
  297. auto func = this->GetComputeFunction();
  298. if (py::isinstance<py::none>(func)) {
  299. return py::none();
  300. }
  301. auto result = func(*py_args);
  302. return result;
  303. }
  304. bool PrimitivePy::HasComputeFunction() const {
  305. auto func = GetComputeFunction();
  306. return !py::isinstance<py::none>(func);
  307. }
  308. PrimitivePtr PrimitivePy::Clone() {
  309. auto clone_fn = python_obj_.attr("_clone");
  310. py::object obj_adapter = clone_fn();
  311. auto prim_adapter = obj_adapter.cast<PrimitivePyAdapterPtr>();
  312. auto prim = std::make_shared<PrimitivePy>(obj_adapter, prim_adapter);
  313. prim_adapter->set_attached_primitive(prim);
  314. return prim;
  315. }
  316. py::dict PrimitivePy::RunInfer(const py::tuple &args) {
  317. if (!HasPyObj()) {
  318. MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
  319. }
  320. // Python obj could be replaced as None, so it will losed the original info when throw exception in python.
  321. if (!py::hasattr(python_obj_, PY_PRIM_METHOD_INFER)) {
  322. MS_LOG(EXCEPTION) << "prim:" << ToString() << " has no attr:" << PY_PRIM_METHOD_INFER;
  323. }
  324. auto infer_fuc = python_obj_.attr(PY_PRIM_METHOD_INFER);
  325. return infer_fuc(*args);
  326. }
  327. void PrimitivePy::RunCheck(const py::tuple &args) {
  328. if (!HasPyObj()) {
  329. MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
  330. }
  331. // Python obj could be replaced as None, so it will losed the original info when throw exception in python.
  332. if (!py::hasattr(python_obj_, PY_PRIM_METHOD_CHECK)) {
  333. MS_LOG(EXCEPTION) << "prim:" << ToString() << " has no attr:" << PY_PRIM_METHOD_CHECK;
  334. }
  335. auto check_func = python_obj_.attr(PY_PRIM_METHOD_CHECK);
  336. (void)check_func(*args);
  337. }
  338. py::object PrimitivePy::RunInferValue(const py::tuple &args) {
  339. if (!HasPyObj()) {
  340. MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
  341. }
  342. // Python obj could be replaced as None, so it will losed the original info when throw exception in python.
  343. if (!py::hasattr(python_obj_, PY_PRIM_METHOD_INFER_VALUE)) {
  344. MS_LOG(EXCEPTION) << "prim:" << ToString() << " has no attr:" << PY_PRIM_METHOD_INFER_VALUE;
  345. }
  346. auto infer_value = python_obj_.attr(PY_PRIM_METHOD_INFER_VALUE);
  347. return infer_value(*args);
  348. }
  349. PrimitivePyAdapter::PrimitivePyAdapter(const py::str &name) : name_(name) {}
  350. void PrimitivePyAdapter::AddPyAttr(const py::str &name, const py::object &obj) {
  351. std::string attr_name = name;
  352. ValuePtr converted_ret = nullptr;
  353. if (py::isinstance<py::module>(obj)) {
  354. MS_LOG(EXCEPTION) << "AddPyAttr failed, obj should not be py::module";
  355. }
  356. bool converted = parse::ConvertData(obj, &converted_ret);
  357. if (!converted) {
  358. MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj));
  359. }
  360. if (kOpAttrNameReplaceMap.find(attr_name) != kOpAttrNameReplaceMap.end()) {
  361. attr_name = kOpAttrNameReplaceMap[attr_name];
  362. }
  363. (void)CheckAndConvertUtils::ConvertAttrValueToInt(name_, name, &converted_ret);
  364. attrs_[attr_name] = converted_ret;
  365. auto prim = attached_primitive_.lock();
  366. if (prim != nullptr) {
  367. (void)prim->AddAttr(attr_name, converted_ret);
  368. }
  369. if (attr_name == "primitive_target") {
  370. MS_EXCEPTION_IF_NULL(converted_ret);
  371. if (!converted_ret->isa<StringImm>()) {
  372. MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
  373. }
  374. auto target = GetValue<std::string>(converted_ret);
  375. if (target != kCPUDevice && target != kGPUDevice) {
  376. auto context_ptr = MsContext::GetInstance();
  377. MS_EXCEPTION_IF_NULL(context_ptr);
  378. context_ptr->set_param<bool>(MS_CTX_ALREADY_SET_ENABLE_MINDRT, true);
  379. }
  380. }
  381. }
  382. void PrimitivePyAdapter::DelPyAttr(const py::str &name) {
  383. (void)attrs_.erase(name);
  384. auto prim = attached_primitive_.lock();
  385. if (prim != nullptr) {
  386. (void)prim->DelAttr(name);
  387. }
  388. }
  389. py::dict PrimitivePyAdapter::GetAttrDict() {
  390. auto prim = attached_primitive_.lock();
  391. if (prim != nullptr) {
  392. return prim->GetAttrDict();
  393. }
  394. py::dict attr_dict;
  395. for (auto &attr : attrs_) {
  396. attr_dict[py::str(attr.first)] = ValueToPyData(attr.second);
  397. }
  398. return attr_dict;
  399. }
  400. void PrimitivePyAdapter::set_prim_type(const PrimType t) {
  401. prim_type_ = t;
  402. auto prim = attached_primitive_.lock();
  403. if (prim != nullptr) {
  404. prim->set_prim_type(t);
  405. }
  406. }
  407. void PrimitivePyAdapter::set_const_prim(bool is_const_prim) {
  408. is_const_prim_ = is_const_prim;
  409. auto prim = attached_primitive_.lock();
  410. if (prim != nullptr) {
  411. prim->set_const_prim(is_const_prim);
  412. }
  413. }
  414. void PrimitivePyAdapter::set_const_input_indexes(const std::vector<size_t> &const_input_indexes) {
  415. const_input_indexes_ = const_input_indexes;
  416. auto prim = attached_primitive_.lock();
  417. if (prim != nullptr) {
  418. prim->set_const_input_indexes(const_input_indexes);
  419. }
  420. }
  421. void PrimitivePyAdapter::set_signatures(const std::vector<Signature> &signatures) {
  422. signatures_ = signatures;
  423. auto prim = attached_primitive_.lock();
  424. if (prim != nullptr) {
  425. prim->set_signatures(signatures);
  426. }
  427. }
  428. void PrimitivePyAdapter::set_hook(const py::function &hook) {
  429. hook_ = hook;
  430. auto prim = attached_primitive_.lock();
  431. if (prim != nullptr) {
  432. prim->set_hook(hook);
  433. }
  434. }
  435. void PrimitivePyAdapter::set_instance_name(const std::string &s) {
  436. instance_name_ = s;
  437. auto prim = attached_primitive_.lock();
  438. if (prim != nullptr) {
  439. prim->set_instance_name(s);
  440. }
  441. }
  442. void PrimitivePyAdapter::set_attached_primitive(const PrimitivePyPtr &prim) {
  443. if (attached_primitive_.lock() != nullptr) {
  444. MS_LOG(EXCEPTION) << "PrimitivePyAdapter can't attach to multi Primitive.";
  445. }
  446. MS_EXCEPTION_IF_NULL(prim);
  447. attached_primitive_ = prim;
  448. }
  449. REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
  450. (void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
  451. .value("unknown", PrimType::kPrimTypeUnknown)
  452. .value("builtin", PrimType::kPrimTypeBuiltIn)
  453. .value("py_infer_shape", PrimType::kPrimTypePyInfer)
  454. .value("user_custom", PrimType::kPrimTypeUserCustom)
  455. .value("py_infer_check", PrimType::kPrimTypePyCheck);
  456. (void)py::class_<PrimitivePyAdapter, std::shared_ptr<PrimitivePyAdapter>>(*m, "Primitive_")
  457. .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePyAdapter::parse_info_)
  458. .def(py::init<py::str &>())
  459. .def("add_attr", &PrimitivePyAdapter::AddPyAttr, "add primitive attr")
  460. .def("del_attr", &PrimitivePyAdapter::DelPyAttr, "del primitive attr")
  461. .def("get_attr_dict", &PrimitivePyAdapter::GetAttrDict, "get primitive attr")
  462. .def("set_prim_type", &PrimitivePyAdapter::set_prim_type, "Set primitive type.")
  463. .def("set_const_prim", &PrimitivePyAdapter::set_const_prim, "Set primitive is const.")
  464. .def("set_const_input_indexes", &PrimitivePyAdapter::set_const_input_indexes,
  465. "Set primitive const input indexes.")
  466. .def("set_signatures", &PrimitivePyAdapter::set_signatures,
  467. "Set primitive inputs signature.")
  468. .def("register_hook", &PrimitivePyAdapter::set_hook, "Set primitive hook function.")
  469. .def("set_instance_name", &PrimitivePyAdapter::set_instance_name,
  470. "Set primitive instance name.");
  471. }));
  472. } // namespace mindspore