From f926c0abe442a55a719c851b19181420228e57ea Mon Sep 17 00:00:00 2001 From: Eric Date: Thu, 3 Dec 2020 15:01:00 -0500 Subject: [PATCH] Added back hook for iterator closing --- mindspore/dataset/engine/datasets.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 83c1749bf5..d8646825f0 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -1925,6 +1925,8 @@ class BatchDataset(Dataset): new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict) new_op.hook = copy.deepcopy(self.hook, memodict) new_op.pad_info = copy.deepcopy(self.pad_info, memodict) + if hasattr(self, "__total_batch__"): + new_op.__total_batch__ = self.__total_batch__ return new_op # Iterator bootstrap will be called on iterator construction. @@ -1943,6 +1945,7 @@ class BatchDataset(Dataset): idx = 0 # Wrap per_batch_map into _PythonCallable self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool) + self.hook = _ExceptHookHandler() def __del__(self): if hasattr(self, 'process_pool') and self.process_pool is not None: @@ -2209,15 +2212,12 @@ class _PythonCallable: class _ExceptHookHandler: - def __init__(self, pool): - self.__pool = pool + def __init__(self): sys.excepthook = self.__handler_exception def __handler_exception(self, type, value, tb): logger.error("Uncaught exception: ", exc_info=(type, value, tb)) - if self.__pool is not None: - _set_iterator_cleanup() - self.__pool.terminate() + _set_iterator_cleanup() class MapDataset(Dataset): @@ -2354,7 +2354,8 @@ class MapDataset(Dataset): # Pass #1, look for Python callables and build list for op in self.operations: - if callable(op): + # our c transforms is now callable and should not be run in python multithreading + if callable(op) and str(op).find("c_transform") < 0: callable_list.append(op) if callable_list: @@ -2366,7 +2367,8 @@ class MapDataset(Dataset): # Pass #2 idx = 0 for op in self.operations: - if callable(op): + # our c transforms is now callable and should not be run in python multithreading + if callable(op) and str(op).find("c_transform") < 0: # Wrap Python callable into _PythonCallable iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool)) idx += 1 @@ -2374,6 +2376,7 @@ class MapDataset(Dataset): # CPP ops remain the same iter_specific_operations.append(op) self.operations = iter_specific_operations + self.hook = _ExceptHookHandler() def __del__(self): if hasattr(self, 'process_pool') and self.process_pool is not None: