|
|
|
@@ -13,30 +13,30 @@ |
|
|
|
* See the License for the specific language governing permissions and |
|
|
|
* limitations under the License. |
|
|
|
*/ |
|
|
|
#include "ir/param_value.h" |
|
|
|
#include "ir/param_info.h" |
|
|
|
#include "pybind11/pybind11.h" |
|
|
|
#include "pybind_api/api_register.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace py = pybind11; |
|
|
|
|
|
|
|
REGISTER_PYBIND_DEFINE(ParamValue, ([](const py::module *m) { |
|
|
|
(void)py::class_<ParamValue, ParamValuePtr>(*m, "ParamValue") |
|
|
|
REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) { |
|
|
|
(void)py::class_<ParamInfo, ParamValuePtr>(*m, "ParamInfo") |
|
|
|
.def(py::init()) |
|
|
|
.def("clone", &ParamValue::Clone) |
|
|
|
.def_property("name", &ParamValue::name, &ParamValue::set_name) |
|
|
|
.def_property("requires_grad", &ParamValue::requires_grad, &ParamValue::set_requires_grad) |
|
|
|
.def_property("layerwise_parallel", &ParamValue::layerwise_parallel, |
|
|
|
&ParamValue::set_layerwise_parallel) |
|
|
|
.def("clone", &ParamInfo::Clone) |
|
|
|
.def_property("name", &ParamInfo::name, &ParamInfo::set_name) |
|
|
|
.def_property("requires_grad", &ParamInfo::requires_grad, &ParamInfo::set_requires_grad) |
|
|
|
.def_property("layerwise_parallel", &ParamInfo::layerwise_parallel, |
|
|
|
&ParamInfo::set_layerwise_parallel) |
|
|
|
.def(py::pickle( |
|
|
|
[](const ParamValue &p) { // __getstate__ |
|
|
|
[](const ParamInfo &p) { // __getstate__ |
|
|
|
return py::make_tuple(p.name(), p.requires_grad(), p.layerwise_parallel()); |
|
|
|
}, |
|
|
|
[](const py::tuple &t) { // __setstate__ |
|
|
|
if (t.size() != 6) { |
|
|
|
std::runtime_error("Invalid state for ParamValue!"); |
|
|
|
std::runtime_error("Invalid state for ParamInfo!"); |
|
|
|
} |
|
|
|
ParamValuePtr p = std::make_shared<ParamValue>(); |
|
|
|
ParamValuePtr p = std::make_shared<ParamInfo>(); |
|
|
|
p->set_name(t[1].cast<std::string>()); |
|
|
|
p->set_requires_grad(t[2].cast<bool>()); |
|
|
|
p->set_layerwise_parallel(t[3].cast<bool>()); |