| @@ -1982,17 +1982,21 @@ class BatchDataset(Dataset): | |||
| # The callable list and _pyfunc_worker_init are used to pass lambda functions into subprocesses | |||
| self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers, | |||
| initializer=_pyfunc_worker_init, initargs=([self.per_batch_map],)) | |||
| idx = 0 | |||
| global _OP_NAME | |||
| global _OP_NAME, _OP_PROCESS, _LOCK | |||
| op_id = _OP_NAME[str(self)] | |||
| _manager = multiprocessing.Manager() | |||
| _op_process = _manager.dict() | |||
| _process_lock = _manager.Lock() | |||
| process_id = {op_id: [self.num_parallel_workers, set()]} | |||
| # Obtain the worker process ids from the multiprocessing pool | |||
| for pool in self.process_pool._pool: # pylint: disable=W0212 | |||
| process_id[op_id][1].add(pool.pid) | |||
| with _LOCK: | |||
| _OP_PROCESS.update(process_id) | |||
| # Wrap per_batch_map into _PythonCallable | |||
| self.per_batch_map = _PythonCallable(self.per_batch_map, idx, op_id, _op_process, _process_lock, | |||
| self.num_parallel_workers, self.process_pool) | |||
| self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool) | |||
| self.hook = _ExceptHookHandler() | |||
| atexit.register(_mp_pool_exit_preprocess, _manager) | |||
| atexit.register(_mp_pool_exit_preprocess) | |||
| def __del__(self): | |||
| if hasattr(self, 'process_pool') and self.process_pool is not None: | |||
| @@ -2201,19 +2205,13 @@ def _pyfunc_worker_init(pyfunc_list): | |||
| # Pyfunc worker execution function | |||
| # All exceptions will be raised to main processes | |||
| def _pyfunc_worker_exec(index, op_id, mapping, lock, record, *args): | |||
| def _pyfunc_worker_exec(index, *args): | |||
| """ | |||
| Internal function that calls the indexed pyfunc inside a Python worker process. | |||
| """ | |||
| # Some threads in multiprocessing.Pool cannot handle the SIGINT signal | |||
| # and would hang, so SIGINT is ignored here and Ctrl+C is left to the parent process. | |||
| signal.signal(signal.SIGINT, signal.SIG_IGN) | |||
| if record: | |||
| pid = os.getpid() | |||
| with lock: | |||
| data = mapping[op_id] | |||
| data[1].add(pid) | |||
| mapping[op_id] = data | |||
| return _GLOBAL_PYFUNC_LIST[index](*args) | |||
| @@ -2223,40 +2221,20 @@ class _PythonCallable: | |||
| Internal Python function wrapper for multiprocessing pyfunc. | |||
| """ | |||
| def __init__(self, py_callable, idx, op_id, mapping, lock, worker_num, pool=None): | |||
| def __init__(self, py_callable, idx, pool=None): | |||
| # Original Python callable from user. | |||
| self.py_callable = py_callable | |||
| # Process pool created for current iterator. | |||
| self.pool = pool | |||
| # Python callable index for subprocess _GLOBAL_PYFUNC_LIST | |||
| self.idx = idx | |||
| self.op_id = op_id | |||
| self.mapping = mapping | |||
| self.lock = lock | |||
| self.worker_num = worker_num | |||
| self.record = True | |||
| self.mapping[op_id] = [self.worker_num, set()] | |||
| global _OP_PROCESS, _LOCK | |||
| with _LOCK: | |||
| _OP_PROCESS.update(self.mapping) | |||
| def __call__(self, *args): | |||
| if self.pool is not None and self.pool._state == 0 and check_iterator_cleanup() is False: # pylint: disable=W0212 | |||
| # This call will send the tensors along with Python callable index to the process pool. | |||
| # Block, yield GIL. Current thread will reacquire GIL once result is returned. | |||
| if self.record: | |||
| result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, self.op_id, self.mapping, self.lock, | |||
| self.record, *args]) | |||
| else: | |||
| result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, self.op_id, None, None, self.record, | |||
| *args]) | |||
| if self.record: | |||
| data = self.mapping | |||
| if len(data[self.op_id][1]) == self.worker_num: | |||
| self.record = False | |||
| global _OP_PROCESS, _LOCK | |||
| with _LOCK: | |||
| _OP_PROCESS.update(data) | |||
| result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, *args]) | |||
| # TODO: this check might be wrong | |||
| while check_iterator_cleanup() is False: | |||
| try: | |||
| @@ -2273,15 +2251,13 @@ class _PythonCallable: | |||
| return self.py_callable(*args) | |||
| def _mp_pool_exit_preprocess(manager=None): | |||
| def _mp_pool_exit_preprocess(): | |||
| if check_iterator_cleanup() is False: | |||
| logger.info("Execution preprocessing process before map exit.") | |||
| # Set the iterator_cleanup flag to True before exiting, and wait 3s for all apply_async calls | |||
| # applied to the multiprocessing tasks, to prevent multiprocessing from hanging when exiting | |||
| _set_iterator_cleanup() | |||
| time.sleep(3) | |||
| if manager is not None: | |||
| manager.shutdown() | |||
| class _ExceptHookHandler: | |||
| @@ -2385,26 +2361,29 @@ class MapDataset(Dataset): | |||
| # The callable list and _pyfunc_worker_init are used to pass lambda functions into subprocesses | |||
| self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers, | |||
| initializer=_pyfunc_worker_init, initargs=(callable_list,)) | |||
| # Pass #2 | |||
| global _OP_NAME | |||
| op_id = _OP_NAME[str(self)] | |||
| idx = 0 | |||
| _manager = multiprocessing.Manager() | |||
| _op_process = _manager.dict() | |||
| _process_lock = _manager.Lock() | |||
| global _OP_NAME, _OP_PROCESS, _LOCK | |||
| op_id = _OP_NAME[str(self)] | |||
| # Obtain the worker process ids from the multiprocessing pool | |||
| process_id = {op_id: [self.num_parallel_workers, set()]} | |||
| for pool in self.process_pool._pool: # pylint: disable=W0212 | |||
| process_id[op_id][1].add(pool.pid) | |||
| with _LOCK: | |||
| _OP_PROCESS.update(process_id) | |||
| for op in self.operations: | |||
| # Our C transforms are now callable and should not be run in Python multithreading | |||
| if callable(op) and str(op).find("c_transform") < 0: | |||
| # Wrap Python callable into _PythonCallable | |||
| iter_specific_operations.append(_PythonCallable(op, idx, op_id, _op_process, _process_lock, | |||
| self.num_parallel_workers, self.process_pool)) | |||
| iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool)) | |||
| idx += 1 | |||
| else: | |||
| # CPP ops remain the same | |||
| iter_specific_operations.append(op) | |||
| self.operations = iter_specific_operations | |||
| self.hook = _ExceptHookHandler() | |||
| atexit.register(_mp_pool_exit_preprocess, _manager) | |||
| atexit.register(_mp_pool_exit_preprocess) | |||
| def __del__(self): | |||
| if hasattr(self, 'process_pool') and self.process_pool is not None: | |||