| @@ -1982,17 +1982,21 @@ class BatchDataset(Dataset): | |||||
| # The callable list and _pyfunc_worker_init are used to pass lambda functions into subprocesses | # The callable list and _pyfunc_worker_init are used to pass lambda functions into subprocesses | ||||
| self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers, | self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers, | ||||
| initializer=_pyfunc_worker_init, initargs=([self.per_batch_map],)) | initializer=_pyfunc_worker_init, initargs=([self.per_batch_map],)) | ||||
| idx = 0 | idx = 0 | ||||
| global _OP_NAME | |||||
| global _OP_NAME, _OP_PROCESS, _LOCK | |||||
| op_id = _OP_NAME[str(self)] | op_id = _OP_NAME[str(self)] | ||||
| _manager = multiprocessing.Manager() | |||||
| _op_process = _manager.dict() | |||||
| _process_lock = _manager.Lock() | |||||
| process_id = {op_id: [self.num_parallel_workers, set()]} | |||||
| # Obtain the worker process ids from the multiprocessing pool | |||||
| for pool in self.process_pool._pool: # pylint: disable=W0212 | |||||
| process_id[op_id][1].add(pool.pid) | |||||
| with _LOCK: | |||||
| _OP_PROCESS.update(process_id) | |||||
| # Wrap per_batch_map into _PythonCallable | # Wrap per_batch_map into _PythonCallable | ||||
| self.per_batch_map = _PythonCallable(self.per_batch_map, idx, op_id, _op_process, _process_lock, | |||||
| self.num_parallel_workers, self.process_pool) | |||||
| self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool) | |||||
| self.hook = _ExceptHookHandler() | self.hook = _ExceptHookHandler() | ||||
| atexit.register(_mp_pool_exit_preprocess, _manager) | |||||
| atexit.register(_mp_pool_exit_preprocess) | |||||
| def __del__(self): | def __del__(self): | ||||
| if hasattr(self, 'process_pool') and self.process_pool is not None: | if hasattr(self, 'process_pool') and self.process_pool is not None: | ||||
| @@ -2201,19 +2205,13 @@ def _pyfunc_worker_init(pyfunc_list): | |||||
| # Pyfunc worker execution function | # Pyfunc worker execution function | ||||
| # All exceptions will be raised to main processes | # All exceptions will be raised to main processes | ||||
| def _pyfunc_worker_exec(index, op_id, mapping, lock, record, *args): | |||||
| def _pyfunc_worker_exec(index, *args): | |||||
| """ | """ | ||||
| Internal function for calling a certain pyfunc in a Python process. | Internal function for calling a certain pyfunc in a Python process. | ||||
| """ | """ | ||||
| # Some threads in multiprocessing.Pool cannot handle the SIGINT signal, | # Some threads in multiprocessing.Pool cannot handle the SIGINT signal, | ||||
| # which can cause a hang, so SIGINT is ignored here and Ctrl+C is left to the parent process. | # which can cause a hang, so SIGINT is ignored here and Ctrl+C is left to the parent process. | ||||
| signal.signal(signal.SIGINT, signal.SIG_IGN) | signal.signal(signal.SIGINT, signal.SIG_IGN) | ||||
| if record: | |||||
| pid = os.getpid() | |||||
| with lock: | |||||
| data = mapping[op_id] | |||||
| data[1].add(pid) | |||||
| mapping[op_id] = data | |||||
| return _GLOBAL_PYFUNC_LIST[index](*args) | return _GLOBAL_PYFUNC_LIST[index](*args) | ||||
| @@ -2223,40 +2221,20 @@ class _PythonCallable: | |||||
| Internal Python function wrapper for multiprocessing pyfunc. | Internal Python function wrapper for multiprocessing pyfunc. | ||||
| """ | """ | ||||
| def __init__(self, py_callable, idx, op_id, mapping, lock, worker_num, pool=None): | |||||
| def __init__(self, py_callable, idx, pool=None): | |||||
| # Original Python callable from user. | # Original Python callable from user. | ||||
| self.py_callable = py_callable | self.py_callable = py_callable | ||||
| # Process pool created for current iterator. | # Process pool created for current iterator. | ||||
| self.pool = pool | self.pool = pool | ||||
| # Python callable index for subprocess _GLOBAL_PYFUNC_LIST | # Python callable index for subprocess _GLOBAL_PYFUNC_LIST | ||||
| self.idx = idx | self.idx = idx | ||||
| self.op_id = op_id | |||||
| self.mapping = mapping | |||||
| self.lock = lock | |||||
| self.worker_num = worker_num | |||||
| self.record = True | |||||
| self.mapping[op_id] = [self.worker_num, set()] | |||||
| global _OP_PROCESS, _LOCK | |||||
| with _LOCK: | |||||
| _OP_PROCESS.update(self.mapping) | |||||
| def __call__(self, *args): | def __call__(self, *args): | ||||
| if self.pool is not None and self.pool._state == 0 and check_iterator_cleanup() is False: # pylint: disable=W0212 | if self.pool is not None and self.pool._state == 0 and check_iterator_cleanup() is False: # pylint: disable=W0212 | ||||
| # This call will send the tensors along with Python callable index to the process pool. | # This call will send the tensors along with Python callable index to the process pool. | ||||
| # Block, yield GIL. Current thread will reacquire GIL once result is returned. | # Block, yield GIL. Current thread will reacquire GIL once result is returned. | ||||
| if self.record: | |||||
| result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, self.op_id, self.mapping, self.lock, | |||||
| self.record, *args]) | |||||
| else: | |||||
| result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, self.op_id, None, None, self.record, | |||||
| *args]) | |||||
| if self.record: | |||||
| data = self.mapping | |||||
| if len(data[self.op_id][1]) == self.worker_num: | |||||
| self.record = False | |||||
| global _OP_PROCESS, _LOCK | |||||
| with _LOCK: | |||||
| _OP_PROCESS.update(data) | |||||
| result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, *args]) | |||||
| # TODO: this check might be wrong | # TODO: this check might be wrong | ||||
| while check_iterator_cleanup() is False: | while check_iterator_cleanup() is False: | ||||
| try: | try: | ||||
| @@ -2273,15 +2251,13 @@ class _PythonCallable: | |||||
| return self.py_callable(*args) | return self.py_callable(*args) | ||||
| def _mp_pool_exit_preprocess(manager=None): | |||||
| def _mp_pool_exit_preprocess(): | |||||
| if check_iterator_cleanup() is False: | if check_iterator_cleanup() is False: | ||||
| logger.info("Execution preprocessing process before map exit.") | logger.info("Execution preprocessing process before map exit.") | ||||
| # Set the iterator_cleanup flag to True before exiting, and wait 3s for all apply_async | # Set the iterator_cleanup flag to True before exiting, and wait 3s for all apply_async | ||||
| # tasks submitted to multiprocessing, to prevent multiprocessing from hanging when exiting | # tasks submitted to multiprocessing, to prevent multiprocessing from hanging when exiting | ||||
| _set_iterator_cleanup() | _set_iterator_cleanup() | ||||
| time.sleep(3) | time.sleep(3) | ||||
| if manager is not None: | |||||
| manager.shutdown() | |||||
| class _ExceptHookHandler: | class _ExceptHookHandler: | ||||
| @@ -2385,26 +2361,29 @@ class MapDataset(Dataset): | |||||
| # The callable list and _pyfunc_worker_init are used to pass lambda functions into subprocesses | # The callable list and _pyfunc_worker_init are used to pass lambda functions into subprocesses | ||||
| self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers, | self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers, | ||||
| initializer=_pyfunc_worker_init, initargs=(callable_list,)) | initializer=_pyfunc_worker_init, initargs=(callable_list,)) | ||||
| # Pass #2 | # Pass #2 | ||||
| global _OP_NAME | |||||
| op_id = _OP_NAME[str(self)] | |||||
| idx = 0 | idx = 0 | ||||
| _manager = multiprocessing.Manager() | |||||
| _op_process = _manager.dict() | |||||
| _process_lock = _manager.Lock() | |||||
| global _OP_NAME, _OP_PROCESS, _LOCK | |||||
| op_id = _OP_NAME[str(self)] | |||||
| # Obtain the worker process ids from the multiprocessing pool | |||||
| process_id = {op_id: [self.num_parallel_workers, set()]} | |||||
| for pool in self.process_pool._pool: # pylint: disable=W0212 | |||||
| process_id[op_id][1].add(pool.pid) | |||||
| with _LOCK: | |||||
| _OP_PROCESS.update(process_id) | |||||
| for op in self.operations: | for op in self.operations: | ||||
| # Our C transforms are now callable and should not be run in Python multithreading | # Our C transforms are now callable and should not be run in Python multithreading | ||||
| if callable(op) and str(op).find("c_transform") < 0: | if callable(op) and str(op).find("c_transform") < 0: | ||||
| # Wrap Python callable into _PythonCallable | # Wrap Python callable into _PythonCallable | ||||
| iter_specific_operations.append(_PythonCallable(op, idx, op_id, _op_process, _process_lock, | |||||
| self.num_parallel_workers, self.process_pool)) | |||||
| iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool)) | |||||
| idx += 1 | idx += 1 | ||||
| else: | else: | ||||
| # CPP ops remain the same | # CPP ops remain the same | ||||
| iter_specific_operations.append(op) | iter_specific_operations.append(op) | ||||
| self.operations = iter_specific_operations | self.operations = iter_specific_operations | ||||
| self.hook = _ExceptHookHandler() | self.hook = _ExceptHookHandler() | ||||
| atexit.register(_mp_pool_exit_preprocess, _manager) | |||||
| atexit.register(_mp_pool_exit_preprocess) | |||||
| def __del__(self): | def __del__(self): | ||||
| if hasattr(self, 'process_pool') and self.process_pool is not None: | if hasattr(self, 'process_pool') and self.process_pool is not None: | ||||