| @@ -27,6 +27,7 @@ import multiprocessing | |||||
| import queue | import queue | ||||
| from enum import Enum | from enum import Enum | ||||
| from importlib import import_module | from importlib import import_module | ||||
| import sys | |||||
| import threading | import threading | ||||
| import copy | import copy | ||||
| @@ -42,7 +43,8 @@ from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched | |||||
| import mindspore.dataset.transforms.py_transforms as py_transforms | import mindspore.dataset.transforms.py_transforms as py_transforms | ||||
| from . import samplers | from . import samplers | ||||
| from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp, Iterator, check_iterator_cleanup | |||||
| from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp, Iterator, check_iterator_cleanup, \ | |||||
| _set_iterator_cleanup | |||||
| from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \ | from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \ | ||||
| check_rename, check_numpyslicesdataset, check_device_send, \ | check_rename, check_numpyslicesdataset, check_device_send, \ | ||||
| check_take, check_project, check_imagefolderdataset, check_mnist_cifar_dataset, check_manifestdataset, \ | check_take, check_project, check_imagefolderdataset, check_mnist_cifar_dataset, check_manifestdataset, \ | ||||
| @@ -2040,7 +2042,8 @@ class _PythonCallable: | |||||
| except multiprocessing.TimeoutError: | except multiprocessing.TimeoutError: | ||||
| continue | continue | ||||
| except KeyboardInterrupt: | except KeyboardInterrupt: | ||||
| self.pool.terminate() | |||||
| _set_iterator_cleanup() | |||||
| self.pool.close() | |||||
| self.pool.join() | self.pool.join() | ||||
| raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt") | raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt") | ||||
| return (None,) | return (None,) | ||||
| @@ -2048,6 +2051,18 @@ class _PythonCallable: | |||||
| return self.py_callable(*args) | return self.py_callable(*args) | ||||
| class _ExceptHookHandler: | |||||
| def __init__(self, pool): | |||||
| self.__pool = pool | |||||
| sys.excepthook = self.__handler_exception | |||||
| def __handler_exception(self, type, value, tb): | |||||
| logger.error("Uncaught exception: ", exc_info=(type, value, tb)) | |||||
| if self.__pool is not None: | |||||
| _set_iterator_cleanup() | |||||
| self.__pool.terminate() | |||||
| class MapDataset(DatasetOp): | class MapDataset(DatasetOp): | ||||
| """ | """ | ||||
| The result of applying the Map operator to the input Dataset. | The result of applying the Map operator to the input Dataset. | ||||
| @@ -2124,6 +2139,7 @@ class MapDataset(DatasetOp): | |||||
| callbacks = [callbacks] | callbacks = [callbacks] | ||||
| self.callbacks = callbacks | self.callbacks = callbacks | ||||
| self.hook = None | |||||
| def get_args(self): | def get_args(self): | ||||
| args = super().get_args() | args = super().get_args() | ||||
| @@ -2163,6 +2179,7 @@ class MapDataset(DatasetOp): | |||||
| new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict) | new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict) | ||||
| new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict) | new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict) | ||||
| new_op.cache = copy.deepcopy(self.cache, memodict) | new_op.cache = copy.deepcopy(self.cache, memodict) | ||||
| new_op.hook = copy.deepcopy(self.hook, memodict) | |||||
| new_op.operations = self.operations | new_op.operations = self.operations | ||||
| new_op.dataset_size = self.dataset_size | new_op.dataset_size = self.dataset_size | ||||
| new_op.callbacks = self.callbacks | new_op.callbacks = self.callbacks | ||||
| @@ -2203,10 +2220,11 @@ class MapDataset(DatasetOp): | |||||
| # CPP ops remain the same | # CPP ops remain the same | ||||
| iter_specific_operations.append(op) | iter_specific_operations.append(op) | ||||
| self.operations = iter_specific_operations | self.operations = iter_specific_operations | ||||
| self.hook = _ExceptHookHandler(self.process_pool) | |||||
| def __del__(self): | def __del__(self): | ||||
| if hasattr(self, 'process_pool') and self.process_pool is not None: | if hasattr(self, 'process_pool') and self.process_pool is not None: | ||||
| self.process_pool.terminate() | |||||
| self.process_pool.close() | |||||
| class FilterDataset(DatasetOp): | class FilterDataset(DatasetOp): | ||||