| @@ -17,6 +17,7 @@ | |||||
| from abc import abstractmethod | from abc import abstractmethod | ||||
| import copy | import copy | ||||
| import weakref | import weakref | ||||
| from importlib import import_module | |||||
| from mindspore._c_dataengine import DEPipeline | from mindspore._c_dataengine import DEPipeline | ||||
| from mindspore._c_dataengine import OpName | from mindspore._c_dataengine import OpName | ||||
| @@ -24,14 +25,29 @@ from mindspore._c_dataengine import OpName | |||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| from . import datasets as de | from . import datasets as de | ||||
| try: | |||||
| context = import_module("mindspore.context") | |||||
| except ModuleNotFoundError: | |||||
| context = None | |||||
| ITERATORS_LIST = list() | ITERATORS_LIST = list() | ||||
| def _cleanup(): | def _cleanup(): | ||||
| """Release all the Iterator.""" | |||||
| for itr_ref in ITERATORS_LIST: | for itr_ref in ITERATORS_LIST: | ||||
| itr = itr_ref() | |||||
| if itr is not None: | |||||
| itr.release() | |||||
| if context: | |||||
| device_type = context.get_context("device_target") | |||||
| if device_type == "GPU": | |||||
| itr_ref.release() | |||||
| else: | |||||
| itr = itr_ref() | |||||
| if itr is not None: | |||||
| itr.release() | |||||
| else: | |||||
| itr = itr_ref() | |||||
| if itr is not None: | |||||
| itr.release() | |||||
| def alter_tree(node): | def alter_tree(node): | ||||
| @@ -85,7 +101,14 @@ class Iterator: | |||||
| """ | """ | ||||
| def __init__(self, dataset): | def __init__(self, dataset): | ||||
| ITERATORS_LIST.append(weakref.ref(self)) | |||||
| if context: | |||||
| device_type = context.get_context("device_target") | |||||
| if device_type == "GPU": | |||||
| ITERATORS_LIST.append(self) | |||||
| else: | |||||
| ITERATORS_LIST.append(weakref.ref(self)) | |||||
| else: | |||||
| ITERATORS_LIST.append(weakref.ref(self)) | |||||
| # create a copy of tree and work on it. | # create a copy of tree and work on it. | ||||
| self.dataset = copy.deepcopy(dataset) | self.dataset = copy.deepcopy(dataset) | ||||
| self.dataset = alter_tree(self.dataset) | self.dataset = alter_tree(self.dataset) | ||||