|
|
|
@@ -27,6 +27,7 @@ import signal |
|
|
|
import time |
|
|
|
import uuid |
|
|
|
import multiprocessing |
|
|
|
from multiprocessing.pool import RUN |
|
|
|
import queue |
|
|
|
from enum import Enum |
|
|
|
from functools import partial |
|
|
|
@@ -1997,6 +1998,9 @@ class BatchDataset(Dataset): |
|
|
|
self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool) |
|
|
|
self.hook = _ExceptHookHandler() |
|
|
|
atexit.register(_mp_pool_exit_preprocess) |
|
|
|
# If python version greater than 3.8, we need to close ThreadPool in atexit for unclean pool teardown. |
|
|
|
if sys.version_info >= (3, 8): |
|
|
|
atexit.register(self.process_pool.close) |
|
|
|
|
|
|
|
def __del__(self): |
|
|
|
if hasattr(self, 'process_pool') and self.process_pool is not None: |
|
|
|
@@ -2230,7 +2234,11 @@ class _PythonCallable: |
|
|
|
self.idx = idx |
|
|
|
|
|
|
|
def __call__(self, *args): |
|
|
|
if self.pool is not None and self.pool._state == 0 and check_iterator_cleanup() is False: # pylint: disable=W0212 |
|
|
|
# note here: the RUN state of python3.7 and python3.8 is different: |
|
|
|
# python3.7: RUN = 0 |
|
|
|
# python3.8: RUN = "RUN" |
|
|
|
# so we use self.pool._state == RUN instead and we can't use _state == 0 any more. |
|
|
|
if self.pool is not None and self.pool._state == RUN and check_iterator_cleanup() is False: # pylint: disable=W0212 |
|
|
|
# This call will send the tensors along with Python callable index to the process pool. |
|
|
|
# Block, yield GIL. Current thread will reacquire GIL once result is returned. |
|
|
|
result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, *args]) |
|
|
|
@@ -2384,6 +2392,9 @@ class MapDataset(Dataset): |
|
|
|
self.operations = iter_specific_operations |
|
|
|
self.hook = _ExceptHookHandler() |
|
|
|
atexit.register(_mp_pool_exit_preprocess) |
|
|
|
# If python version greater than 3.8, we need to close ThreadPool in atexit for unclean pool teardown. |
|
|
|
if sys.version_info >= (3, 8): |
|
|
|
atexit.register(self.process_pool.close) |
|
|
|
|
|
|
|
def __del__(self): |
|
|
|
if hasattr(self, 'process_pool') and self.process_pool is not None: |
|
|
|
|