|
|
@@ -38,7 +38,7 @@ from mindspore._c_expression import typing |
|
|
|
|
|
|
|
|
from mindspore import log as logger |
|
|
from mindspore import log as logger |
|
|
from . import samplers |
|
|
from . import samplers |
|
|
from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp, Iterator |
|
|
|
|
|
|
|
|
from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp, Iterator, check_iterator_cleanup |
|
|
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \ |
|
|
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \ |
|
|
check_rename, check_numpyslicesdataset, check_device_send, \ |
|
|
check_rename, check_numpyslicesdataset, check_device_send, \ |
|
|
check_take, check_project, check_imagefolderdataset, check_mnist_cifar_dataset, check_manifestdataset, \ |
|
|
check_take, check_project, check_imagefolderdataset, check_mnist_cifar_dataset, check_manifestdataset, \ |
|
|
@@ -55,7 +55,6 @@ try: |
|
|
except ModuleNotFoundError: |
|
|
except ModuleNotFoundError: |
|
|
context = None |
|
|
context = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Shuffle(str, Enum): |
|
|
class Shuffle(str, Enum): |
|
|
GLOBAL: str = "global" |
|
|
GLOBAL: str = "global" |
|
|
FILES: str = "file" |
|
|
FILES: str = "file" |
|
|
@@ -2012,18 +2011,19 @@ class _PythonCallable: |
|
|
|
|
|
|
|
|
def __call__(self, *args): |
|
|
def __call__(self, *args): |
|
|
if self.pool is not None: |
|
|
if self.pool is not None: |
|
|
try: |
|
|
|
|
|
# This call will send the tensors along with Python callable index to the process pool. |
|
|
|
|
|
# Block, yield GIL. Current thread will reacquire GIL once result is returned. |
|
|
|
|
|
result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, *args]) |
|
|
|
|
|
return result.get(60) |
|
|
|
|
|
except multiprocessing.TimeoutError: |
|
|
|
|
|
# Ensure c++ pyfunc threads exit normally if python sub-process is killed unnormally. |
|
|
|
|
|
return (None,) |
|
|
|
|
|
except KeyboardInterrupt: |
|
|
|
|
|
self.pool.terminate() |
|
|
|
|
|
self.pool.join() |
|
|
|
|
|
raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt") |
|
|
|
|
|
|
|
|
# This call will send the tensors along with Python callable index to the process pool. |
|
|
|
|
|
# Block, yield GIL. Current thread will reacquire GIL once result is returned. |
|
|
|
|
|
result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, *args]) |
|
|
|
|
|
while check_iterator_cleanup() is False: |
|
|
|
|
|
try: |
|
|
|
|
|
return result.get(30) |
|
|
|
|
|
except multiprocessing.TimeoutError: |
|
|
|
|
|
continue |
|
|
|
|
|
except KeyboardInterrupt: |
|
|
|
|
|
self.pool.terminate() |
|
|
|
|
|
self.pool.join() |
|
|
|
|
|
raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt") |
|
|
|
|
|
return (None,) |
|
|
# Invoke original Python callable in master process in case the pool is gone. |
|
|
# Invoke original Python callable in master process in case the pool is gone. |
|
|
return self.py_callable(*args) |
|
|
return self.py_callable(*args) |
|
|
|
|
|
|
|
|
@@ -3395,7 +3395,14 @@ class _GeneratorWorkerMp(multiprocessing.Process): |
|
|
""" |
|
|
""" |
|
|
Get function for worker result queue. Block with timeout. |
|
|
Get function for worker result queue. Block with timeout. |
|
|
""" |
|
|
""" |
|
|
return self.res_queue.get(timeout=10) |
|
|
|
|
|
|
|
|
while check_iterator_cleanup() is False: |
|
|
|
|
|
try: |
|
|
|
|
|
return self.res_queue.get(timeout=10) |
|
|
|
|
|
except multiprocessing.TimeoutError: |
|
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
|
|
raise Exception("Generator worker process timeout") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __del__(self): |
|
|
def __del__(self): |
|
|
self.terminate() |
|
|
self.terminate() |
|
|
|