diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index b00bfd6efe..db6f1e3c1a 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -1921,6 +1921,8 @@ class BatchDataset(Dataset): new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict) new_op.hook = copy.deepcopy(self.hook, memodict) new_op.pad_info = copy.deepcopy(self.pad_info, memodict) + if hasattr(self, "__total_batch__"): + new_op.__total_batch__ = self.__total_batch__ return new_op # Iterator bootstrap will be called on iterator construction. @@ -1939,6 +1941,7 @@ class BatchDataset(Dataset): idx = 0 # Wrap per_batch_map into _PythonCallable self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool) + self.hook = _ExceptHookHandler() def __del__(self): if hasattr(self, 'process_pool') and self.process_pool is not None: @@ -2205,15 +2208,12 @@ class _PythonCallable: class _ExceptHookHandler: - def __init__(self, pool): - self.__pool = pool + def __init__(self): sys.excepthook = self.__handler_exception def __handler_exception(self, type, value, tb): logger.error("Uncaught exception: ", exc_info=(type, value, tb)) - if self.__pool is not None: - _set_iterator_cleanup() - self.__pool.terminate() + _set_iterator_cleanup() class MapDataset(Dataset): @@ -2350,7 +2350,8 @@ class MapDataset(Dataset): # Pass #1, look for Python callables and build list for op in self.operations: - if callable(op): + # our c transforms is now callable and should not be run in python multithreading + if callable(op) and str(op).find("c_transform") < 0: callable_list.append(op) if callable_list: @@ -2362,7 +2363,8 @@ class MapDataset(Dataset): # Pass #2 idx = 0 for op in self.operations: - if callable(op): + # our c transforms is now callable and should not be run in python multithreading + if callable(op) and str(op).find("c_transform") < 0: # Wrap Python callable into _PythonCallable iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool)) idx += 1 @@ -2370,6 +2372,7 @@ class MapDataset(Dataset): # CPP ops remain the same iter_specific_operations.append(op) self.operations = iter_specific_operations + self.hook = _ExceptHookHandler() def __del__(self): if hasattr(self, 'process_pool') and self.process_pool is not None: