|
|
|
@@ -17,6 +17,7 @@ |
|
|
|
from abc import abstractmethod |
|
|
|
import copy |
|
|
|
import weakref |
|
|
|
from importlib import import_module |
|
|
|
|
|
|
|
from mindspore._c_dataengine import DEPipeline |
|
|
|
from mindspore._c_dataengine import OpName |
|
|
|
@@ -24,14 +25,29 @@ from mindspore._c_dataengine import OpName |
|
|
|
from mindspore import log as logger |
|
|
|
from . import datasets as de |
|
|
|
|
|
|
|
try: |
|
|
|
context = import_module("mindspore.context") |
|
|
|
except ModuleNotFoundError: |
|
|
|
context = None |
|
|
|
|
|
|
|
ITERATORS_LIST = list() |
|
|
|
|
|
|
|
|
|
|
|
def _cleanup(): |
|
|
|
"""Release all the Iterator.""" |
|
|
|
for itr_ref in ITERATORS_LIST: |
|
|
|
itr = itr_ref() |
|
|
|
if itr is not None: |
|
|
|
itr.release() |
|
|
|
if context: |
|
|
|
device_type = context.get_context("device_target") |
|
|
|
if device_type == "GPU": |
|
|
|
itr_ref.release() |
|
|
|
else: |
|
|
|
itr = itr_ref() |
|
|
|
if itr is not None: |
|
|
|
itr.release() |
|
|
|
else: |
|
|
|
itr = itr_ref() |
|
|
|
if itr is not None: |
|
|
|
itr.release() |
|
|
|
|
|
|
|
|
|
|
|
def alter_tree(node): |
|
|
|
@@ -85,7 +101,14 @@ class Iterator: |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, dataset): |
|
|
|
ITERATORS_LIST.append(weakref.ref(self)) |
|
|
|
if context: |
|
|
|
device_type = context.get_context("device_target") |
|
|
|
if device_type == "GPU": |
|
|
|
ITERATORS_LIST.append(self) |
|
|
|
else: |
|
|
|
ITERATORS_LIST.append(weakref.ref(self)) |
|
|
|
else: |
|
|
|
ITERATORS_LIST.append(weakref.ref(self)) |
|
|
|
# create a copy of tree and work on it. |
|
|
|
self.dataset = copy.deepcopy(dataset) |
|
|
|
self.dataset = alter_tree(self.dataset) |
|
|
|
|