PyFunc / generator workers: keep polling for results until iterator cleanup is flagged

Summary
- py_func_op.cc: log at INFO when a PyFunc result element is None, and reword the
  kTimeOut status message.
- datasets.py: _PythonCallable.__call__ submits the call to the pool once, then waits in
  30 s slices while check_iterator_cleanup() is false; it returns (None,) once cleanup is
  flagged. The _GeneratorWorkerMp result getter polls its queue the same way.
- iterators.py: add the module-level _ITERATOR_CLEANUP flag. _cleanup() sets it; class
  Iterator clears it right after registering itself in ITERATORS_LIST.

Review notes
- The generator result getter now also catches queue.Empty: the standard-library queues
  report an expired get(timeout=...) with queue.Empty, so catching only
  multiprocessing.TimeoutError would never retry. `queue` is imported locally because the
  module's import block is not part of this patch.
- `check_iterator_cleanup() is False` was replaced by `not check_iterator_cleanup()`, and the
  needless `global` statement in the read-only accessor was dropped.
- NOTE(review): indentation and blank context lines were rebuilt from a whitespace-collapsed
  copy of this patch; verify them against the tree before applying. The index hashes are
  those of the original patch.
- NOTE(review): the context line `if (!py::isinstance(ret_py_ele)) {` may have lost a
  template argument during extraction; confirm against the real file.

diff --git a/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.cc
index 65204417ed..846cc22f51 100644
--- a/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.cc
@@ -61,6 +61,7 @@ Status PyFuncOp::Compute(const TensorRow &input, TensorRow *output) {
           py::object ret_py_ele = ret_py_tuple[i];
           // Object is none if pyfunc timeout
           if (ret_py_ele.is_none()) {
+            MS_LOG(INFO) << "PyFunc execute time out";
             goto TimeoutError;
           }
           if (!py::isinstance(ret_py_ele)) {
@@ -92,7 +93,7 @@ ShapeMisMatch:
   goto ComputeReturn;
 
 TimeoutError:
-  ret = Status(StatusCode::kTimeOut, "PyFunc timeout");
+  ret = Status(StatusCode::kTimeOut, "PyFunc execute time out");
   goto ComputeReturn;
 }
 
diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index ec1cb187ef..8f1d603e92 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -38,7 +38,7 @@ from mindspore._c_expression import typing
 
 from mindspore import log as logger
 from . import samplers
-from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp, Iterator
+from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp, Iterator, check_iterator_cleanup
 from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
     check_rename, check_numpyslicesdataset, check_device_send, \
     check_take, check_project, check_imagefolderdataset, check_mnist_cifar_dataset, check_manifestdataset, \
@@ -55,7 +55,6 @@ try:
 except ModuleNotFoundError:
     context = None
 
-
 class Shuffle(str, Enum):
     GLOBAL: str = "global"
     FILES: str = "file"
@@ -2012,18 +2011,21 @@ class _PythonCallable:
 
     def __call__(self, *args):
         if self.pool is not None:
-            try:
-                # This call will send the tensors along with Python callable index to the process pool.
-                # Block, yield GIL. Current thread will reacquire GIL once result is returned.
-                result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, *args])
-                return result.get(60)
-            except multiprocessing.TimeoutError:
-                # Ensure c++ pyfunc threads exit normally if python sub-process is killed unnormally.
-                return (None,)
-            except KeyboardInterrupt:
-                self.pool.terminate()
-                self.pool.join()
-                raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt")
+            # This call will send the tensors along with Python callable index to the process pool.
+            # Block, yield GIL. Current thread will reacquire GIL once result is returned.
+            result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, *args])
+            # Wait in 30 s slices so the loop can stop once iterator cleanup is flagged.
+            while not check_iterator_cleanup():
+                try:
+                    return result.get(30)
+                except multiprocessing.TimeoutError:
+                    continue
+                except KeyboardInterrupt:
+                    self.pool.terminate()
+                    self.pool.join()
+                    raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt")
+            # Cleanup flagged: (None,) makes the C++ PyFuncOp take its timeout path.
+            return (None,)
         # Invoke original Python callable in master process in case the pool is gone.
         return self.py_callable(*args)
 
@@ -3395,7 +3397,18 @@ class _GeneratorWorkerMp(multiprocessing.Process):
         """
         Get function for worker result queue. Block with timeout.
         """
-        return self.res_queue.get(timeout=10)
+        # Imported locally so this hunk does not depend on the module's import block.
+        import queue
+
+        while not check_iterator_cleanup():
+            try:
+                return self.res_queue.get(timeout=10)
+            except (multiprocessing.TimeoutError, queue.Empty):
+                # Standard-library queues report an expired timeout with queue.Empty; keep polling.
+                continue
+
+        raise Exception("Generator worker process timeout")
+
 
     def __del__(self):
         self.terminate()
diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py
index 6642a0396b..2bc97ec8b5 100644
--- a/mindspore/dataset/engine/iterators.py
+++ b/mindspore/dataset/engine/iterators.py
@@ -27,16 +27,33 @@ from mindspore import log as logger
 from . import datasets as de
 
 
+# Set by _cleanup(); cleared when class Iterator registers a new iterator.
+_ITERATOR_CLEANUP = False
+
+def _set_iterator_cleanup():
+    """Flag that iterator cleanup has started."""
+    global _ITERATOR_CLEANUP
+    _ITERATOR_CLEANUP = True
+
+def _unset_iterator_cleanup():
+    """Clear the iterator cleanup flag."""
+    global _ITERATOR_CLEANUP
+    _ITERATOR_CLEANUP = False
+
+def check_iterator_cleanup():
+    """Return the current value of the iterator cleanup flag."""
+    return _ITERATOR_CLEANUP
+
 ITERATORS_LIST = list()
 
 def _cleanup():
     """Release all the Iterator."""
+    _set_iterator_cleanup()
     for itr_ref in ITERATORS_LIST:
         itr = itr_ref()
         if itr is not None:
             itr.release()
 
-
 def alter_tree(node):
     """Traversing the Python dataset tree/graph to perform some alteration to some specific nodes."""
     if not node.children:
@@ -71,6 +88,7 @@ class Iterator:
         self.num_epochs = num_epochs
         self.output_numpy = output_numpy
         ITERATORS_LIST.append(weakref.ref(self))
+        _unset_iterator_cleanup()
         # create a copy of tree and work on it.
         self.dataset = copy.deepcopy(dataset)
         self.ori_dataset = dataset