Merge pull request !378 from h.farahat/multi_itr
@@ -225,11 +225,13 @@ void bindTensor(py::module *m) {
   (void)py::class_<DataType>(*m, "DataType")
     .def(py::init<std::string>())
     .def(py::self == py::self)
-    .def("__str__", &DataType::ToString);
+    .def("__str__", &DataType::ToString)
+    .def("__deepcopy__", [](py::object &t, py::dict memo) { return t; });
 }
 
 void bindTensorOps1(py::module *m) {
-  (void)py::class_<TensorOp, std::shared_ptr<TensorOp>>(*m, "TensorOp");
+  (void)py::class_<TensorOp, std::shared_ptr<TensorOp>>(*m, "TensorOp")
+    .def("__deepcopy__", [](py::object &t, py::dict memo) { return t; });
 
   (void)py::class_<NormalizeOp, TensorOp, std::shared_ptr<NormalizeOp>>(
     *m, "NormalizeOp", "Tensor operation to normalize an image. Takes mean and std.")
@@ -15,6 +15,8 @@
 """Built-in iterators.
 """
 from abc import abstractmethod
+import copy
+import weakref
 
 from mindspore._c_dataengine import DEPipeline
 from mindspore._c_dataengine import OpName
@@ -27,7 +29,9 @@ ITERATORS_LIST = list()
 
 def _cleanup():
     for itr in ITERATORS_LIST:
-        itr.release()
+        iter_ref = itr()
+        if iter_ref is not None:
+            iter_ref.release()
 
 
 def alter_tree(node):
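
Reviewer note: a self-contained sketch of the weakref pattern used in _cleanup (REGISTRY and Itr are illustrative names, not the module's API). The registry stores weakref.ref callables; dereferencing a dead reference yields None, so iterators that have already been garbage collected are skipped. On CPython this relies on reference counting reclaiming the iterator as soon as its last strong reference goes away.

    import weakref

    REGISTRY = []

    class Itr:
        """Hypothetical iterator with a release() hook."""
        def release(self):
            print("released")

    def cleanup():
        for ref in REGISTRY:
            obj = ref()          # dereference the weak reference
            if obj is not None:  # referent may already be garbage collected
                obj.release()

    it = Itr()
    REGISTRY.append(weakref.ref(it))
    cleanup()   # prints "released"
    del it
    cleanup()   # prints nothing; the weak reference is now dead
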
@@ -73,8 +77,10 @@ class Iterator:
     """
 
     def __init__(self, dataset):
-        ITERATORS_LIST.append(self)
-        self.dataset = alter_tree(dataset)
+        ITERATORS_LIST.append(weakref.ref(self))
+        # create a copy of tree and work on it.
+        self.dataset = copy.deepcopy(dataset)
+        self.dataset = alter_tree(self.dataset)
         if not self.__is_tree():
             raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)")
         self.depipeline = DEPipeline()
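
Reviewer note: a short sketch of why __init__ deep-copies the tree before alter_tree (Node and alter are hypothetical stand-ins). Each iterator rewrites its own copy, so creating a second iterator over the same dataset never sees the first one's modifications and the user's original tree stays untouched.

    import copy

    class Node:
        """Hypothetical pipeline node."""
        def __init__(self, name, children=()):
            self.name = name
            self.children = list(children)

    def alter(node):
        # Stand-in for alter_tree-style rewriting that mutates the node.
        node.name += "_altered"
        return node

    dataset = Node("map", [Node("source")])

    # Each "iterator" works on its own deep copy of the tree.
    first = alter(copy.deepcopy(dataset))
    second = alter(copy.deepcopy(dataset))
    assert dataset.name == "map"  # the original pipeline is unchanged
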