Browse Source

!5944 [bug]fix bug in cell copy & test case

Merge pull request !5944 from vlne-v1/bug-cell-copy
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
a2b329828d
2 changed files with 20 additions and 1 deletions
  1. +14
    -1
      mindspore/ccsrc/pybind_api/ir/cell_py.cc
  2. +6
    -0
      tests/ut/python/nn/test_cell.py

+ 14
- 1
mindspore/ccsrc/pybind_api/ir/cell_py.cc View File

@@ -45,6 +45,19 @@ REGISTER_PYBIND_DEFINE(Cell, ([](const py::module *m) {
.def("_del_attr", &Cell::DelAttr, "Delete Cell attr.")
.def(
"construct", []() { MS_LOG(EXCEPTION) << "we should define `construct` for all `cell`."; },
"construct");
"construct")
.def(py::pickle(
[](const Cell &cell) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(py::str(cell.name()));
},
[](const py::tuple &tup) { // __setstate__
if (tup.size() != 1) {
throw std::runtime_error("Invalid state!");
}
/* Create a new C++ instance */
Cell data(tup[0].cast<std::string>());
return data;
}));
}));
} // namespace mindspore

+ 6
- 0
tests/ut/python/nn/test_cell.py View File

@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
""" test cell """
import copy
import numpy as np
import pytest

@@ -200,6 +201,11 @@ def test_exceptions():
m.construct()


def test_cell_copy():
net = ConvNet()
copy.deepcopy(net)


def test_del():
""" test_del """
ta = Tensor(np.ones([2, 3]))


Loading…
Cancel
Save