| @@ -27,6 +27,7 @@ import multiprocessing | |||
| import queue | |||
| from enum import Enum | |||
| from importlib import import_module | |||
| import sys | |||
| import threading | |||
| import copy | |||
| @@ -42,7 +43,8 @@ from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched | |||
| import mindspore.dataset.transforms.py_transforms as py_transforms | |||
| from . import samplers | |||
| from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp, Iterator, check_iterator_cleanup | |||
| from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp, Iterator, check_iterator_cleanup, \ | |||
| _set_iterator_cleanup | |||
| from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \ | |||
| check_rename, check_numpyslicesdataset, check_device_send, \ | |||
| check_take, check_project, check_imagefolderdataset, check_mnist_cifar_dataset, check_manifestdataset, \ | |||
| @@ -2040,7 +2042,8 @@ class _PythonCallable: | |||
| except multiprocessing.TimeoutError: | |||
| continue | |||
| except KeyboardInterrupt: | |||
| self.pool.terminate() | |||
| _set_iterator_cleanup() | |||
| self.pool.close() | |||
| self.pool.join() | |||
| raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt") | |||
| return (None,) | |||
| @@ -2048,6 +2051,18 @@ class _PythonCallable: | |||
| return self.py_callable(*args) | |||
| class _ExceptHookHandler: | |||
| def __init__(self, pool): | |||
| self.__pool = pool | |||
| sys.excepthook = self.__handler_exception | |||
| def __handler_exception(self, type, value, tb): | |||
| logger.error("Uncaught exception: ", exc_info=(type, value, tb)) | |||
| if self.__pool is not None: | |||
| _set_iterator_cleanup() | |||
| self.__pool.terminate() | |||
| class MapDataset(DatasetOp): | |||
| """ | |||
| The result of applying the Map operator to the input Dataset. | |||
| @@ -2124,6 +2139,7 @@ class MapDataset(DatasetOp): | |||
| callbacks = [callbacks] | |||
| self.callbacks = callbacks | |||
| self.hook = None | |||
| def get_args(self): | |||
| args = super().get_args() | |||
| @@ -2163,6 +2179,7 @@ class MapDataset(DatasetOp): | |||
| new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict) | |||
| new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict) | |||
| new_op.cache = copy.deepcopy(self.cache, memodict) | |||
| new_op.hook = copy.deepcopy(self.hook, memodict) | |||
| new_op.operations = self.operations | |||
| new_op.dataset_size = self.dataset_size | |||
| new_op.callbacks = self.callbacks | |||
| @@ -2203,10 +2220,11 @@ class MapDataset(DatasetOp): | |||
| # CPP ops remain the same | |||
| iter_specific_operations.append(op) | |||
| self.operations = iter_specific_operations | |||
| self.hook = _ExceptHookHandler(self.process_pool) | |||
| def __del__(self): | |||
| if hasattr(self, 'process_pool') and self.process_pool is not None: | |||
| self.process_pool.terminate() | |||
| self.process_pool.close() | |||
| class FilterDataset(DatasetOp): | |||