/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "pybind_api/ir/primitive_py.h"

// NOTE(review): the angle-bracket header names of the next includes were lost in
// extraction (the file contained three bare "#include" tokens). <map>, <memory>
// and <utility> are reconstructed from the standard-library names used below —
// confirm against the upstream file.
#include <map>
#include <memory>
#include <utility>

#include "ir/signature.h"
#include "pipeline/jit/parse/data_converter.h"
#include "include/common/utils/python_adapter.h"
#include "pybind11/pytypes.h"
#include "pybind_api/api_register.h"
#include "pybind_api/export_flags.h"
#include "pybind_api/ir/base_ref_py.h"
#include "utils/convert_utils_base.h"
#include "include/common/utils/convert_utils_py.h"
#include "utils/ms_context.h"
#include "include/common/utils/primitive_utils.h"
#include "utils/check_convert_utils.h"
#include "pipeline/pynative/pynative_execute.h"

namespace mindspore {
namespace {
constexpr auto kBpropAttrName = "bprop";
constexpr auto kCellHookAttrName = "cell_hook";
constexpr auto kCellIDAttrName = "cell_id";
// Maps a Python-side attribute name to the canonical C++-side primitive attribute name.
std::map<std::string, std::string> kOpAttrNameReplaceMap = {
  {"data_format", "format"},
};

// Recursively synchronize device data to host for a tensor, or for every
// element of a (possibly nested) tuple of tensors.
void SyncData(const py::object &arg) {
  if (py::isinstance<py::tuple>(arg)) {
    py::tuple arg_list = py::cast<py::tuple>(arg);
    for (size_t i = 0; i < arg_list.size(); i++) {
      SyncData(arg_list[i]);
    }
  }
  if (py::isinstance<tensor::Tensor>(arg)) {
    auto tensor = py::cast<tensor::TensorPtr>(arg);
    tensor->data_sync();
  }
}

// Convert every C++ tensor in `input_args` (recursing into tuples) to a Python
// mindspore.Tensor, writing results into the pre-sized `convert_args` tuple.
// Non-tensor elements are passed through unchanged.
void ConvertCTensorToPyTensor(const py::tuple &input_args, py::tuple *convert_args) {
  MS_EXCEPTION_IF_NULL(convert_args);
  if (input_args.size() != (*convert_args).size()) {
    MS_LOG(EXCEPTION) << "The size of input_args: " << input_args.size()
                      << " should be equal to the size of convert_args: " << (*convert_args).size();
  }
  for (size_t i = 0; i < input_args.size(); ++i) {
    if (py::isinstance<tensor::Tensor>(input_args[i])) {
      (*convert_args)[i] =
        python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_CONVERT_TO_MS_TENSOR, input_args[i]);
    } else if (py::isinstance<py::tuple>(input_args[i])) {
      auto tuple_inp_arg = py::cast<py::tuple>(input_args[i]);
      py::tuple convert_tuple_arg(tuple_inp_arg.size());
      ConvertCTensorToPyTensor(tuple_inp_arg, &convert_tuple_arg);
      (*convert_args)[i] = convert_tuple_arg;
    } else {
      (*convert_args)[i] = input_args[i];
    }
  }
}

// Build the argument tuple (cell_id, grad_input, grad_output) passed to a cell
// hook function. Gradients are converted to Python tensors and, when they are
// not already tuples, wrapped into single-element tuples.
py::tuple ConstructCellHookFnArgs(const std::string &cell_id, const py::object &grad_input,
                                  const py::object &grad_output) {
  constexpr size_t grad_input_index = 1;
  constexpr size_t grad_output_index = 2;
  constexpr size_t input_args_nums = 3;
  // Convert c++ object to python object.
  py::tuple c_grad_args(input_args_nums - 1);
  c_grad_args[0] = grad_input;
  c_grad_args[1] = grad_output;
  py::tuple py_grad_args(input_args_nums - 1);
  ConvertCTensorToPyTensor(c_grad_args, &py_grad_args);
  // Get tuple args of cell hook function.
  py::tuple hook_fn_args(input_args_nums);
  hook_fn_args[0] = cell_id;
  if (!py::isinstance<py::tuple>(py_grad_args[0])) {
    hook_fn_args[grad_input_index] = py::make_tuple(py_grad_args[0]);
  } else {
    hook_fn_args[grad_input_index] = py_grad_args[0];
  }
  if (!py::isinstance<py::tuple>(py_grad_args[1])) {
    hook_fn_args[grad_output_index] = py::make_tuple(py_grad_args[1]);
  } else {
    hook_fn_args[grad_output_index] = py_grad_args[1];
  }
  return hook_fn_args;
}
}  // namespace

// Static storage for gradients captured by cell hooks, keyed by cell id.
std::map<std::string, py::object> PrimitivePy::hook_grad_;

PrimitivePy::PrimitivePy(const std::string &name) : Primitive(name, false), python_obj_(py::none()) {}

// Construct a PrimitivePy from its Python object and adapter, copying the
// adapter's signatures, attributes, type flags and registered hook functions.
PrimitivePy::PrimitivePy(const py::object &python_obj, const PrimitivePyAdapterPtr &adapter)
    : Primitive(adapter->name_, false), python_obj_(python_obj), adapter_(adapter) {
  MS_LOG(DEBUG) << "New primitive:" << adapter->name_;
  set_signatures(adapter->signatures_);
  (void)Primitive::SetAttrs(adapter->attrs_);
  Primitive::set_prim_type(adapter->prim_type_);
  Primitive::set_const_prim(adapter->is_const_prim_);
  Primitive::set_const_input_indexes(adapter->const_input_indexes_);
  for (const auto &elem : adapter->backward_hook_fn_) {
    AddBackwardHookFn(elem.first, elem.second);
  }
  set_instance_name(adapter->instance_name_);
}

PrimitivePy::~PrimitivePy() {}

void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) {
  signatures_ = signatures;
  set_has_signature(!signatures.empty());
}

// Return the primitive's bprop function: prefer a `get_bprop` attribute on the
// Python object, otherwise look it up in the registry by object.
py::function PrimitivePy::GetBpropFunction() {
  static const char *const get_bprop_func_name = "get_bprop";
  if (py::hasattr(python_obj_, get_bprop_func_name)) {
    py::function fn = python_obj_.attr(get_bprop_func_name)().cast<py::function>();
    return fn;
  } else {
    auto fn = GetBpropFunctionByObj(python_obj_);
    return fn;
  }
}

// Normalize and validate the return value of a user-defined 'bprop'.
// Non-tuple results are wrapped; when MS_CTX_CHECK_BPROP_FLAG is set, the
// number, dtype and shape of gradients are checked against the inputs.
py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args, const std::string &bprop_cls_name) {
  py::tuple grads;
  if (py::isinstance<py::none>(grads_obj)) {
    MS_EXCEPTION(TypeError) << "The 'grads_obj' is none.";
  } else if (!py::isinstance<py::tuple>(grads_obj)) {
    grads = py::make_tuple(grads_obj);
  } else {
    grads = py::cast<py::tuple>(grads_obj);
  }
  if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_CHECK_BPROP_FLAG)) {
    return grads;
  }
  constexpr int filter_args_size = 2;  // 'out' and 'dout' are not differentiated inputs.
  if (grads.size() != py_args.size() - filter_args_size) {
    MS_EXCEPTION(TypeError) << "For user defined method 'bprop' of net '" << bprop_cls_name
                            << "', the number of return values(gradients) should be equal to the number of input "
                               "arguments except 'out' and 'dout', which is: "
                            << (py_args.size() - filter_args_size) << ", but got:" << grads.size() << ".";
  }
  for (size_t i = 0; i < grads.size(); i++) {
    if (py::isinstance<tensor::Tensor>(py_args[i])) {
      if (!py::isinstance<tensor::Tensor>(grads[i])) {
        MS_EXCEPTION(ValueError) << "For user defined method 'bprop' of net '" << bprop_cls_name << "', the " << i
                                 << "th return value(gradient of the " << i
                                 << "th argument) should be Tensor, but got "
                                 << py::cast<std::string>(grads[i].attr("__class__").attr("__name__"))
                                 << ", and the value is " << py::cast<std::string>(py::str(grads[i])) << ".";
      }
      py::object arg_dtype = py_args[i].attr("dtype");
      py::object grad_dtype = grads[i].attr("dtype");
      py::tuple arg_shape = py_args[i].attr("shape");
      py::tuple grad_shape = grads[i].attr("shape");
      if (!grad_dtype.equal(arg_dtype)) {
        MS_EXCEPTION(TypeError) << "For user defined method 'bprop' of net '" << bprop_cls_name << "', the " << i
                                << "th return value(gradient of the " << i
                                << "th argument) should have the same dtype as the " << i
                                << "th argument, which is:" << py::cast<std::string>(py::str(arg_dtype))
                                << ", but got: " << py::cast<std::string>(py::str(grad_dtype)) << ".";
      }
      if (!grad_shape.equal(arg_shape)) {
        MS_EXCEPTION(ValueError) << "For user defined method 'bprop' of net '" << bprop_cls_name << "', the " << i
                                 << "th return value(gradient of the " << i
                                 << "th argument) should have the same shape as the " << i
                                 << "th argument, which is:" << py::cast<std::string>(py::str(arg_shape))
                                 << ", but got: " << py::cast<std::string>(py::str(grad_shape)) << ".";
      }
    }
  }
  return grads;
}

void PrimitivePy::AddBpropCutPrim(const PrimitivePyPtr &bprop_cut_prim) {
  MS_EXCEPTION_IF_NULL(bprop_cut_prim);
  bprop_cut_prims_.emplace_back(bprop_cut_prim);
}

// Register a backward hook function under `key`, and propagate the
// registration to every still-alive attached bprop_cut primitive.
void PrimitivePy::AddBackwardHookFn(const int &key, const py::function &backward_hook_fn) {
  backward_hook_fn_[key] = backward_hook_fn;
  for (const auto &elem : bprop_cut_prims_) {
    PrimitivePyPtr bprop_cut_prim = elem.lock();
    if (bprop_cut_prim != nullptr) {
      bprop_cut_prim->AddBackwardHookFn(key, backward_hook_fn);
    }
  }
}

void PrimitivePy::RemoveBackwardHookFn(const int &key) {
  auto iter = backward_hook_fn_.find(key);
  if (iter != backward_hook_fn_.end()) {
    backward_hook_fn_.erase(key);
  }
  // Remove hook_fn for bprop cut prim on grad graph.
  for (const auto &elem : bprop_cut_prims_) {
    PrimitivePyPtr bprop_cut_prim = elem.lock();
    if (bprop_cut_prim != nullptr) {
      bprop_cut_prim->RemoveBackwardHookFn(key);
    }
  }
}

// Verify that a hook's returned gradient matches the expected gradient in
// structure (tuple arity, recursively) and in tensor shape/dtype. Clears the
// static hook-gradient cache before raising so stale state is not kept.
void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out,
                                       const py::object &code_obj, const py::object &co_name) const {
  if (py::isinstance<py::tuple>(expected_grad_out)) {
    if (!py::isinstance<py::tuple>(grad_out)) {
      hook_grad_.clear();
      MS_EXCEPTION(TypeError) << "The output gradient should be a tuple!";
    }
    auto actual_out_tuple = py::cast<py::tuple>(grad_out);
    auto expected_out_tuple = py::cast<py::tuple>(expected_grad_out);
    if (actual_out_tuple.size() != expected_out_tuple.size()) {
      hook_grad_.clear();
      MS_EXCEPTION(ValueError) << "The tuple size of output gradient should be " << expected_out_tuple.size()
                               << ", but it is " << actual_out_tuple.size();
    }
    for (size_t i = 0; i < expected_out_tuple.size(); ++i) {
      CheckHookConsistency(actual_out_tuple[i], expected_out_tuple[i], code_obj, co_name);
    }
  }
  if (py::isinstance<tensor::Tensor>(expected_grad_out)) {
    if (!py::isinstance<tensor::Tensor>(grad_out)) {
      hook_grad_.clear();
      MS_EXCEPTION(TypeError) << "The output type of:" << py::str(co_name) << " should be a tensor but got "
                              << py::cast<std::string>(grad_out.attr("__class__").attr("__name__")) << ".";
    }
    auto actual_out_tensor = py::cast<tensor::TensorPtr>(grad_out);
    auto expected_out_tensor = py::cast<tensor::TensorPtr>(expected_grad_out);
    MS_EXCEPTION_IF_NULL(actual_out_tensor);
    MS_EXCEPTION_IF_NULL(expected_out_tensor);
    if (actual_out_tensor->GetShapeAndDataTypeInfo() != expected_out_tensor->GetShapeAndDataTypeInfo()) {
      hook_grad_.clear();
      MS_EXCEPTION(ValueError) << "The output type of " << py::str(co_name)
                               << " is not consistent with the expected, it should be "
                               << expected_out_tensor->GetShapeAndDataTypeInfo() << ", but got "
                               << actual_out_tensor->GetShapeAndDataTypeInfo();
    }
  }
}

// Run a cell's user-defined bprop function. The trailing two elements of
// `py_args` are 'out' and 'dout'; the rest are the cell inputs. The call is
// recorded on the PyNative executor so the custom bprop is captured in the
// gradient graph.
BaseRef PrimitivePy::RunCellBpropFunction(const py::tuple &py_args) const {
  if (backward_hook_fn_.size() > 1) {
    MS_LOG(EXCEPTION) << "Multiple registration of bprop function is not supported.";
  }
  SyncData(py_args);
  py::tuple converted_args(py_args.size());
  ConvertCTensorToPyTensor(py_args, &converted_args);
  constexpr size_t non_inp_args_size = 2;  // out and dout.
  auto inp_args_size = py_args.size() - non_inp_args_size;
  py::tuple input_args(inp_args_size);
  for (size_t i = 0; i < inp_args_size; ++i) {
    input_args[i] = py_args[i];
  }
  // Run bprop function.
  auto inst = pynative::PynativeExecutor::GetInstance();
  MS_EXCEPTION_IF_NULL(inst);
  try {
    MS_LOG(DEBUG) << "Run bprop function start.";
    py::tuple grads;
    for (const auto &elem : backward_hook_fn_) {
      inst->NewGraph(elem.second, input_args.cast<py::args>());
      py::object grads_obj = elem.second(*converted_args);
      grads = check_bprop_out(grads_obj, py_args, bprop_cls_name_);
      inst->EndGraph(elem.second, grads_obj, input_args.cast<py::args>());
    }
    MS_LOG(DEBUG) << "Run bprop function end.";
    return std::make_shared<PyObjectRef>(grads);
  } catch (std::exception &bt) {
    inst->ClearRes();
    std::rethrow_exception(std::current_exception());
  }
}

// Run the cell hook. The hook is driven by two bprop_cut ops: the first call
// (cell id not yet cached) captures the cell's input gradient; the second call
// invokes the user hook with (cell_id, cached grad, current grad) and may
// replace the gradient with the hook's return value.
BaseRef PrimitivePy::RunCellHookFunction(const py::tuple &py_args) const {
  // Get the gradient passed to current bprop cut op.
  const auto args_size = py_args.size();
  py::object grad_output = py_args[args_size - 1];
  // Get the cell id.
  auto cell_id = GetValue<std::string>(this->GetAttr(kCellIDAttrName));
  auto iter = hook_grad_.find(cell_id);
  if (iter != hook_grad_.end()) {
    // The second bprop_cut used to hook output gradient of cell.
    for (const auto &elem : backward_hook_fn_) {
      py::object code_obj = py::getattr(elem.second, "__code__");
      py::object co_name = py::getattr(code_obj, "co_name");
      if (std::string(py::str(co_name)) == "staging_specialize") {
        py::object name_obj = py::getattr(elem.second, "__name__");
        MS_LOG(EXCEPTION) << "Decorating hook function " << py::str(name_obj)
                          << " with '@ms_function' is not supported.";
      }
      SyncData(grad_output);
      py::tuple hook_fn_args = ConstructCellHookFnArgs(cell_id, iter->second, grad_output);
      py::object ret = elem.second(*hook_fn_args);
      if (!py::isinstance<py::none>(ret)) {
        grad_output = ret;
      }
      CheckHookConsistency(grad_output, py_args[args_size - 1], code_obj, co_name);
    }
    (void)hook_grad_.erase(cell_id);
  } else {
    // The first bprop_cut used to hook input gradient of cell.
    SyncData(grad_output);
    hook_grad_[cell_id] = grad_output;
  }
  if (!py::isinstance<py::tuple>(grad_output)) {
    grad_output = py::make_tuple(grad_output);
  }
  return std::make_shared<PyObjectRef>(grad_output);
}

// Run variable (tensor) hooks on the incoming gradient. Each registered hook
// may replace the gradient by returning a non-None value; consistency with the
// original gradient is checked after every hook.
BaseRef PrimitivePy::RunVariableHookFunction(const py::tuple &py_args) const {
  constexpr size_t grad_output_index = 2;
  py::object grad_output = py_args[grad_output_index];
  for (const auto &elem : backward_hook_fn_) {
    py::object code_obj = py::getattr(elem.second, "__code__");
    py::object co_name = py::getattr(code_obj, "co_name");
    if (std::string(py::str(co_name)) == "staging_specialize") {
      py::object name_obj = py::getattr(elem.second, "__name__");
      MS_LOG(EXCEPTION) << "Decorating hook function " << py::str(name_obj)
                        << " with '@ms_function' is not supported.";
    }
    SyncData(grad_output);
    py::object ret = elem.second(py::make_tuple(grad_output));
    if (!py::isinstance<py::none>(ret)) {
      grad_output = ret;
    }
    CheckHookConsistency(grad_output, py_args[grad_output_index], code_obj, co_name);
  }
  grad_output = py::make_tuple(grad_output);
  return std::make_shared<PyObjectRef>(grad_output);
}

// Dispatch to the proper hook runner based on the primitive's attributes:
// user bprop, cell hook, or plain variable hook.
BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
  py::tuple py_args = ConvertDatatoPyTuple(args);
  bool is_bprop = this->HasAttr(kBpropAttrName);
  if (is_bprop) {
    return RunCellBpropFunction(py_args);
  }
  bool is_cell = this->HasAttr(kCellHookAttrName);
  if (is_cell) {
    return RunCellHookFunction(py_args);
  }
  return RunVariableHookFunction(py_args);
}

// Return the VM compute function for this primitive: an attached `vm_impl`
// attribute if present, otherwise a lookup in the vm_impl registry, with a
// final fallback to the C++ compute function table.
py::function PrimitivePy::GetComputeFunction() const {
  static const char *const compute_func_name = "vm_impl";
  if (py::hasattr(python_obj_, compute_func_name)) {
    MS_LOG(DEBUG) << name() << " compute_func_name";
    py::function fn = python_obj_.attr(compute_func_name).cast<py::function>();
    return fn;
  }
  static const std::string vm_module = "mindspore.ops.vm_impl_registry";
  static const std::string get_vm_impl_fn = "get_vm_impl_fn";
  MS_LOG(DEBUG) << name() << ": get_vm_impl_fn";
  py::function get_fn = python_adapter::GetPyFn(vm_module, get_vm_impl_fn);
  py::function vm_fn = get_fn(python_obj_);
  if (py::isinstance<py::none>(vm_fn)) {
    MS_LOG(DEBUG) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast<std::string>();
    vm_fn = mindspore::GetComputeFunction(Primitive::name());
  }
  return vm_fn;
}

py::dict PrimitivePy::GetAttrDict() {
  py::dict attr_dict;
  for (auto &attr : attrs_) {
    attr_dict[py::str(attr.first)] = ValueToPyData(attr.second);
  }
  return attr_dict;
}

// Copy backward hooks (and the bprop attribute, if any) from another
// PrimitivePy onto this one.
void PrimitivePy::CopyHookFunction(const PrimitivePyPtr &primitive_py) {
  MS_EXCEPTION_IF_NULL(primitive_py);
  const auto &backward_hook_fn = primitive_py->backward_hook_fn();
  for (const auto &elem : backward_hook_fn) {
    AddBackwardHookFn(elem.first, elem.second);
  }
  if (primitive_py->HasAttr(kBpropAttrName)) {
    set_bprop_cls_name(primitive_py->bprop_cls_name_);
    (void)this->AddAttr(kBpropAttrName, primitive_py->GetAttr(kBpropAttrName));
  }
}

BaseRef PrimitivePy::RunComputeFunction(const VectorRef &args) const {
  auto py_args = ConvertDatatoPyTuple(args);
  auto result = this->RunPyComputeFunction(py_args);
  if (py::isinstance<py::none>(result)) {
    return std::make_shared<BaseRef>(nullptr);
  }
  return std::make_shared<PyObjectRef>(result);
}

py::object PrimitivePy::RunPyComputeFunction(const py::tuple &py_args) const {
  auto func = this->GetComputeFunction();
  if (py::isinstance<py::none>(func)) {
    return py::none();
  }
  auto result = func(*py_args);
  return result;
}

bool PrimitivePy::HasComputeFunction() const {
  auto func = GetComputeFunction();
  return !py::isinstance<py::none>(func);
}

// Clone this primitive by invoking the Python-side `_clone` and attaching the
// new PrimitivePy to the cloned adapter.
PrimitivePtr PrimitivePy::Clone() {
  auto clone_fn = python_obj_.attr("_clone");
  py::object obj_adapter = clone_fn();
  auto prim_adapter = obj_adapter.cast<PrimitivePyAdapterPtr>();
  auto prim = std::make_shared<PrimitivePy>(obj_adapter, prim_adapter);
  prim_adapter->set_attached_primitive(prim);
  return prim;
}

py::dict PrimitivePy::RunInfer(const py::tuple &args) {
  if (!HasPyObj()) {
    MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
  }
  // Python obj could be replaced as None, so it will losed the original info when throw exception in python.
  if (!py::hasattr(python_obj_, PY_PRIM_METHOD_INFER)) {
    MS_LOG(EXCEPTION) << "prim:" << ToString() << " has no attr:" << PY_PRIM_METHOD_INFER;
  }
  auto infer_fuc = python_obj_.attr(PY_PRIM_METHOD_INFER);
  return infer_fuc(*args);
}

void PrimitivePy::RunCheck(const py::tuple &args) {
  if (!HasPyObj()) {
    MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
  }
  // Python obj could be replaced as None, so it will losed the original info when throw exception in python.
  if (!py::hasattr(python_obj_, PY_PRIM_METHOD_CHECK)) {
    MS_LOG(EXCEPTION) << "prim:" << ToString() << " has no attr:" << PY_PRIM_METHOD_CHECK;
  }
  auto check_func = python_obj_.attr(PY_PRIM_METHOD_CHECK);
  (void)check_func(*args);
}

py::object PrimitivePy::RunInferValue(const py::tuple &args) {
  if (!HasPyObj()) {
    MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
  }
  // Python obj could be replaced as None, so it will losed the original info when throw exception in python.
  if (!py::hasattr(python_obj_, PY_PRIM_METHOD_INFER_VALUE)) {
    MS_LOG(EXCEPTION) << "prim:" << ToString() << " has no attr:" << PY_PRIM_METHOD_INFER_VALUE;
  }
  auto infer_value = python_obj_.attr(PY_PRIM_METHOD_INFER_VALUE);
  return infer_value(*args);
}

void PrimitivePy::ClearHookRes() { hook_grad_.clear(); }

PrimitivePyAdapter::PrimitivePyAdapter(const py::str &name) : name_(name) {}

// Convert a Python attribute value to a MindSpore value and store it on the
// adapter (and on the attached primitive, if any). Applies the attribute-name
// replacement map and validates the special 'primitive_target' attribute.
void PrimitivePyAdapter::AddPyAttr(const py::str &name, const py::object &obj) {
  std::string attr_name = name;
  ValuePtr converted_ret = nullptr;
  if (py::isinstance<py::module>(obj)) {
    MS_LOG(EXCEPTION) << "Call 'add_attr' to add attribute to primitive failed,"
                      << " not support py::module to be attribute value; primitive name: " << this->name_
                      << ", attribute name: " << attr_name << " attribute value: " << py::str(obj);
  }
  bool converted = parse::ConvertData(obj, &converted_ret);
  if (!converted) {
    MS_LOG(EXCEPTION) << "Call 'add_attr' to add attribute to primitive failed,"
                      << " convert python obj to MindSpore obj failed; primitive name: " << this->name_
                      << ", attribute name:" << attr_name << ", attribute value:" << py::str(obj)
                      << ", attribute type:" << py::cast<std::string>(obj.attr("__class__").attr("__name__"));
  }
  if (kOpAttrNameReplaceMap.find(attr_name) != kOpAttrNameReplaceMap.end()) {
    attr_name = kOpAttrNameReplaceMap[attr_name];
  }
  (void)CheckAndConvertUtils::ConvertAttrValueToInt(this->name_, name, &converted_ret);
  if (attr_name == "primitive_target") {
    MS_EXCEPTION_IF_NULL(converted_ret);
    if (!converted_ret->isa<StringImm>()) {
      MS_LOG(EXCEPTION) << "Call 'add_attr' to add attribute to primitive '" << this->name_
                        << "' failed, value of attribute 'primitive_target' must be CPU|GPU|Ascend but got "
                        << py::str(obj);
    }
    auto target = GetValue<std::string>(converted_ret);
    if (!target.empty() && target != kCPUDevice && target != kGPUDevice && target != kAscendDevice &&
        target != "Device") {
      MS_LOG(EXCEPTION) << "Call 'add_attr' to add attribute to primitive '" << this->name_
                        << "' failed, value of attribute 'primitive_target' must be CPU|GPU|Ascend but got "
                        << py::str(obj);
    }
  }
  attrs_[attr_name] = converted_ret;
  auto prim = attached_primitive_.lock();
  if (prim != nullptr) {
    (void)prim->AddAttr(attr_name, converted_ret);
  }
}

void PrimitivePyAdapter::DelPyAttr(const py::str &name) {
  (void)attrs_.erase(name);
  auto prim = attached_primitive_.lock();
  if (prim != nullptr) {
    (void)prim->DelAttr(name);
  }
}

py::dict PrimitivePyAdapter::GetAttrDict() {
  // Prefer the attached primitive's attributes, which are the live ones.
  auto prim = attached_primitive_.lock();
  if (prim != nullptr) {
    return prim->GetAttrDict();
  }
  py::dict attr_dict;
  for (auto &attr : attrs_) {
    attr_dict[py::str(attr.first)] = ValueToPyData(attr.second);
  }
  return attr_dict;
}

void PrimitivePyAdapter::set_prim_type(const PrimType t) {
  prim_type_ = t;
  auto prim = attached_primitive_.lock();
  if (prim != nullptr) {
    prim->set_prim_type(t);
  }
}

void PrimitivePyAdapter::set_const_prim(bool is_const_prim) {
  is_const_prim_ = is_const_prim;
  auto prim = attached_primitive_.lock();
  if (prim != nullptr) {
    prim->set_const_prim(is_const_prim);
  }
}

void PrimitivePyAdapter::set_const_input_indexes(const std::vector<size_t> &const_input_indexes) {
  const_input_indexes_ = const_input_indexes;
  auto prim = attached_primitive_.lock();
  if (prim != nullptr) {
    prim->set_const_input_indexes(const_input_indexes);
  }
}

void PrimitivePyAdapter::set_signatures(const std::vector<Signature> &signatures) {
  signatures_ = signatures;
  auto prim = attached_primitive_.lock();
  if (prim != nullptr) {
    prim->set_signatures(signatures);
  }
}

// Register a backward hook on the adapter; returns the key callers use to
// remove it later. Forwarded to the attached primitive when present.
int PrimitivePyAdapter::AddBackwardHookFn(const py::function &backward_hook_fn) {
  ++backward_hook_fn_key_;
  backward_hook_fn_[backward_hook_fn_key_] = backward_hook_fn;
  auto prim = attached_primitive_.lock();
  if (prim != nullptr) {
    prim->AddBackwardHookFn(backward_hook_fn_key_, backward_hook_fn);
  }
  return backward_hook_fn_key_;
}

void PrimitivePyAdapter::RemoveBackwardHookFn(int key) {
  auto iter = backward_hook_fn_.find(key);
  if (iter != backward_hook_fn_.end()) {
    backward_hook_fn_.erase(iter);
  }
  auto prim = attached_primitive_.lock();
  if (prim != nullptr) {
    prim->RemoveBackwardHookFn(key);
  }
}

void PrimitivePyAdapter::set_instance_name(const std::string &s) {
  instance_name_ = s;
  auto prim = attached_primitive_.lock();
  if (prim != nullptr) {
    prim->set_instance_name(s);
  }
}

void PrimitivePyAdapter::set_attached_primitive(const PrimitivePyPtr &prim) {
  if (attached_primitive_.lock() != nullptr) {
    MS_LOG(EXCEPTION) << "PrimitivePyAdapter can't attach to multi Primitive.";
  }
  MS_EXCEPTION_IF_NULL(prim);
  attached_primitive_ = prim;
}

REGISTER_PYBIND_DEFINE(
  Primitive_, ([](const py::module *m) {
    (void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
      .value("unknown", PrimType::kPrimTypeUnknown)
      .value("builtin", PrimType::kPrimTypeBuiltIn)
      .value("py_infer_shape", PrimType::kPrimTypePyInfer)
      .value("user_custom", PrimType::kPrimTypeUserCustom)
      .value("py_infer_check", PrimType::kPrimTypePyCheck);
    (void)py::class_<PrimitivePyAdapter, std::shared_ptr<PrimitivePyAdapter>>(*m, "Primitive_")
      .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePyAdapter::parse_info_)
      .def(py::init<py::str &>())
      .def("add_attr", &PrimitivePyAdapter::AddPyAttr, "add primitive attr")
      .def("del_attr", &PrimitivePyAdapter::DelPyAttr, "del primitive attr")
      .def("get_attr_dict", &PrimitivePyAdapter::GetAttrDict, "get primitive attr")
      .def("set_prim_type", &PrimitivePyAdapter::set_prim_type, "Set primitive type.")
      .def("set_const_prim", &PrimitivePyAdapter::set_const_prim, "Set primitive is const.")
      .def("set_const_input_indexes", &PrimitivePyAdapter::set_const_input_indexes,
           "Set primitive const input indexes.")
      .def("set_signatures", &PrimitivePyAdapter::set_signatures, "Set primitive inputs signature.")
      .def("add_backward_hook_fn", &PrimitivePyAdapter::AddBackwardHookFn, "Add primitive backward hook function.")
      .def("remove_backward_hook_fn", &PrimitivePyAdapter::RemoveBackwardHookFn,
           "Remove primitive backward hook function.")
      .def("set_instance_name", &PrimitivePyAdapter::set_instance_name, "Set primitive instance name.");
  }));
}  // namespace mindspore