|
|
|
@@ -1982,17 +1982,21 @@ class BatchDataset(Dataset): |
|
|
|
# The callable list and _pyfunc_worker_init are used to pass lambda function in to subprocesses |
|
|
|
self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers, |
|
|
|
initializer=_pyfunc_worker_init, initargs=([self.per_batch_map],)) |
|
|
|
|
|
|
|
idx = 0 |
|
|
|
global _OP_NAME |
|
|
|
global _OP_NAME, _OP_PROCESS, _LOCK |
|
|
|
op_id = _OP_NAME[str(self)] |
|
|
|
_manager = multiprocessing.Manager() |
|
|
|
_op_process = _manager.dict() |
|
|
|
_process_lock = _manager.Lock() |
|
|
|
process_id = {op_id: [self.num_parallel_workers, set()]} |
|
|
|
# obtain process id from multiprocessing.pool |
|
|
|
for pool in self.process_pool._pool: # pylint: disable=W0212 |
|
|
|
process_id[op_id][1].add(pool.pid) |
|
|
|
with _LOCK: |
|
|
|
_OP_PROCESS.update(process_id) |
|
|
|
|
|
|
|
# Wrap per_batch_map into _PythonCallable |
|
|
|
self.per_batch_map = _PythonCallable(self.per_batch_map, idx, op_id, _op_process, _process_lock, |
|
|
|
self.num_parallel_workers, self.process_pool) |
|
|
|
self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool) |
|
|
|
self.hook = _ExceptHookHandler() |
|
|
|
atexit.register(_mp_pool_exit_preprocess, _manager) |
|
|
|
atexit.register(_mp_pool_exit_preprocess) |
|
|
|
|
|
|
|
def __del__(self): |
|
|
|
if hasattr(self, 'process_pool') and self.process_pool is not None: |
|
|
|
@@ -2201,19 +2205,13 @@ def _pyfunc_worker_init(pyfunc_list): |
|
|
|
|
|
|
|
# Pyfunc worker execution function
# All exceptions will be raised to main processes
def _pyfunc_worker_exec(index, *args):
    """
    Internal function for call certain pyfunc in python process.

    Runs inside a multiprocessing.Pool worker. Selects the user callable
    previously installed by _pyfunc_worker_init and invokes it.

    Args:
        index (int): Index of the target callable in the per-worker
            _GLOBAL_PYFUNC_LIST (set up by _pyfunc_worker_init).
        *args: Positional arguments forwarded unchanged to the callable.

    Returns:
        Whatever the selected Python callable returns. Any exception it
        raises propagates back to the main process via apply_async.
    """
    # Some threads in multiprocess.pool can't process sigint signal,
    # and will occur hang problem, so ctrl+c will pass to parent process.
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    return _GLOBAL_PYFUNC_LIST[index](*args)
|
|
|
|
|
|
|
|
|
|
|
@@ -2223,40 +2221,20 @@ class _PythonCallable: |
|
|
|
Internal Python function wrapper for multiprocessing pyfunc. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, py_callable, idx, op_id, mapping, lock, worker_num, pool=None): |
|
|
|
def __init__(self, py_callable, idx, pool=None): |
|
|
|
# Original Python callable from user. |
|
|
|
self.py_callable = py_callable |
|
|
|
# Process pool created for current iterator. |
|
|
|
self.pool = pool |
|
|
|
# Python callable index for subprocess _GLOBAL_PYFUNC_LIST |
|
|
|
self.idx = idx |
|
|
|
self.op_id = op_id |
|
|
|
self.mapping = mapping |
|
|
|
self.lock = lock |
|
|
|
self.worker_num = worker_num |
|
|
|
self.record = True |
|
|
|
self.mapping[op_id] = [self.worker_num, set()] |
|
|
|
global _OP_PROCESS, _LOCK |
|
|
|
with _LOCK: |
|
|
|
_OP_PROCESS.update(self.mapping) |
|
|
|
|
|
|
|
def __call__(self, *args): |
|
|
|
if self.pool is not None and self.pool._state == 0 and check_iterator_cleanup() is False: # pylint: disable=W0212 |
|
|
|
# This call will send the tensors along with Python callable index to the process pool. |
|
|
|
# Block, yield GIL. Current thread will reacquire GIL once result is returned. |
|
|
|
if self.record: |
|
|
|
result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, self.op_id, self.mapping, self.lock, |
|
|
|
self.record, *args]) |
|
|
|
else: |
|
|
|
result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, self.op_id, None, None, self.record, |
|
|
|
*args]) |
|
|
|
if self.record: |
|
|
|
data = self.mapping |
|
|
|
if len(data[self.op_id][1]) == self.worker_num: |
|
|
|
self.record = False |
|
|
|
global _OP_PROCESS, _LOCK |
|
|
|
with _LOCK: |
|
|
|
_OP_PROCESS.update(data) |
|
|
|
result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, *args]) |
|
|
|
|
|
|
|
# todo this check might be wrong |
|
|
|
while check_iterator_cleanup() is False: |
|
|
|
try: |
|
|
|
@@ -2273,15 +2251,13 @@ class _PythonCallable: |
|
|
|
return self.py_callable(*args) |
|
|
|
|
|
|
|
|
|
|
|
def _mp_pool_exit_preprocess():
    """
    atexit hook: mark iterators as cleaning up before the process exits.

    Sets the global iterator-cleanup flag so in-flight apply_async tasks
    stop being scheduled, then waits briefly for outstanding work, which
    prevents multiprocessing from hanging on interpreter shutdown.
    """
    if check_iterator_cleanup() is False:
        logger.info("Execution preprocessing process before map exit.")
        # Set the iterator_cleanup flag to True before exiting, and wait 3s for all apply_async
        # applied to the multiprocessing task to prevent multiprocessing from hang when exiting
        _set_iterator_cleanup()
        time.sleep(3)
|
|
|
|
|
|
|
|
|
|
|
class _ExceptHookHandler: |
|
|
|
@@ -2385,26 +2361,29 @@ class MapDataset(Dataset): |
|
|
|
# The callable list and _pyfunc_worker_init are used to pass lambda function in to subprocesses |
|
|
|
self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers, |
|
|
|
initializer=_pyfunc_worker_init, initargs=(callable_list,)) |
|
|
|
|
|
|
|
# Pass #2 |
|
|
|
global _OP_NAME |
|
|
|
op_id = _OP_NAME[str(self)] |
|
|
|
idx = 0 |
|
|
|
_manager = multiprocessing.Manager() |
|
|
|
_op_process = _manager.dict() |
|
|
|
_process_lock = _manager.Lock() |
|
|
|
global _OP_NAME, _OP_PROCESS, _LOCK |
|
|
|
op_id = _OP_NAME[str(self)] |
|
|
|
# obtain process id from multiprocessing.pool |
|
|
|
process_id = {op_id: [self.num_parallel_workers, set()]} |
|
|
|
for pool in self.process_pool._pool: # pylint: disable=W0212 |
|
|
|
process_id[op_id][1].add(pool.pid) |
|
|
|
with _LOCK: |
|
|
|
_OP_PROCESS.update(process_id) |
|
|
|
for op in self.operations: |
|
|
|
# our c transforms is now callable and should not be run in python multithreading |
|
|
|
if callable(op) and str(op).find("c_transform") < 0: |
|
|
|
# Wrap Python callable into _PythonCallable |
|
|
|
iter_specific_operations.append(_PythonCallable(op, idx, op_id, _op_process, _process_lock, |
|
|
|
self.num_parallel_workers, self.process_pool)) |
|
|
|
iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool)) |
|
|
|
idx += 1 |
|
|
|
else: |
|
|
|
# CPP ops remain the same |
|
|
|
iter_specific_operations.append(op) |
|
|
|
self.operations = iter_specific_operations |
|
|
|
self.hook = _ExceptHookHandler() |
|
|
|
atexit.register(_mp_pool_exit_preprocess, _manager) |
|
|
|
atexit.register(_mp_pool_exit_preprocess) |
|
|
|
|
|
|
|
def __del__(self): |
|
|
|
if hasattr(self, 'process_pool') and self.process_pool is not None: |
|
|
|
|