|
|
|
@@ -28,12 +28,14 @@ |
|
|
|
#include "utils/convert_utils_py.h" |
|
|
|
#include "utils/ms_context.h" |
|
|
|
#include "utils/primitive_utils.h" |
|
|
|
#include "pipeline/jit/resource.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace { |
|
|
|
constexpr auto kBpropAttrName = "bprop"; |
|
|
|
constexpr auto kCellHookAttrName = "cell_hook"; |
|
|
|
constexpr auto kCellIDAttrName = "cell_id"; |
|
|
|
|
|
|
|
void SyncData(const py::object &arg) { |
|
|
|
if (py::isinstance<py::tuple>(arg)) { |
|
|
|
py::tuple arg_list = py::cast<py::tuple>(arg); |
|
|
|
@@ -49,6 +51,12 @@ void SyncData(const py::object &arg) { |
|
|
|
} // namespace |
|
|
|
std::map<std::string, py::object> PrimitivePy::hook_grad_; |
|
|
|
|
|
|
|
PrimitivePy::PrimitivePy(const py::str &name, const py::object &python_obj) |
|
|
|
: Primitive(name, false), python_obj_(python_obj), signatures_() { |
|
|
|
pipeline::Resource::RecordPrimitivePy(this); |
|
|
|
} |
|
|
|
PrimitivePy::~PrimitivePy() { pipeline::Resource::ErasePrimitivePy(this); } |
|
|
|
void PrimitivePy::SetPyObj(const py::object &obj) { python_obj_ = obj; } |
|
|
|
void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) { |
|
|
|
signatures_ = signatures; |
|
|
|
set_has_signature(true); |
|
|
|
|