From 4f946bc54bc08733f4d7551ac829ff706ff4f1d4 Mon Sep 17 00:00:00 2001 From: heleiwang Date: Wed, 21 Oct 2020 19:18:03 +0800 Subject: [PATCH] Modify the processing method of multiple processes in the GeneratorDataset: 1. Start the child process in the init phase 2. At the beginning of each epoch, the child process is not recreated, but the child process created at the beginning is used --- mindspore/dataset/engine/datasets.py | 119 ++++++++++++--------------- 1 file changed, 53 insertions(+), 66 deletions(-) diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index a7a8389c67..58ff041d35 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -3239,21 +3239,19 @@ def _cpp_sampler_fn(sampler, dataset): yield tuple([np.array(x, copy=False) for x in val]) -def _cpp_sampler_fn_mp(sampler, dataset, num_worker, multi_process): +def _cpp_sampler_fn_mp(sampler, sample_fn): """ Multiprocessing generator function wrapper for mappable dataset with cpp sampler. """ indices = sampler.get_indices() - sample_fn = SamplerFn(dataset, num_worker, multi_process) return sample_fn.process(indices) -def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker, multi_process): +def _py_sampler_fn_mp(sampler, num_samples, sample_fn): """ Multiprocessing generator function wrapper for mappable dataset with Python sampler. 
 """ indices = _fetch_py_sampler_indices(sampler, num_samples) - sample_fn = SamplerFn(dataset, num_worker, multi_process) return sample_fn.process(indices) @@ -3299,17 +3297,21 @@ class SamplerFn: self.multi_process = multi_process # Event for end of epoch if multi_process is True: - self.eoe = multiprocessing.Event() + self.eof = multiprocessing.Event() else: - self.eoe = threading.Event() self.eof = threading.Event() # Create workers for _ in range(num_worker): if multi_process is True: - worker = _GeneratorWorkerMp(dataset, self.eoe) + worker = _GeneratorWorkerMp(dataset, self.eof) + worker.daemon = True + # When multiple processes fork a subprocess, the lock of the main process is copied to the subprocess, + # which may cause deadlock. Therefore, the subprocess startup is performed in the initialization phase. + # In this phase, the main process is not locked. + worker.start() else: - worker = _GeneratorWorkerMt(dataset, self.eoe, self.eof) - worker.daemon = True + worker = _GeneratorWorkerMt(dataset, self.eof) + worker.daemon = True self.workers.append(worker) def process(self, indices): @@ -3317,14 +3319,18 @@ class SamplerFn: The main process, start the child process or child thread, and fill the index queue. Get the result and return. """ + for w in self.workers: + # Check whether the queue of the subprocess is empty. 
+ if not w.queue_empty(): + raise Exception("The queue of the subprocess is not empty.") + # Start all workers + if not w.is_alive(): + w.start() + # Fill initial index queues idx_cursor = 0 idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor) - # Start all workers - for w in self.workers: - w.start() - # Fetch results for i in range(len(indices)): # Fetch result and put index @@ -3340,64 +3346,31 @@ class SamplerFn: raise Exception("Generator worker receives KeyboardInterrupt") if idx_cursor < len(indices): idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor) - # Set end-of-epoch (eoe) event once all indices are sent - if idx_cursor == len(indices) and not self.eoe.is_set(): - self.eoe.set() yield tuple([np.array(x, copy=False) for x in result]) def __del__(self): - self.eoe.set() - if self.multi_process is False: - self.eof.set() - for w in self.workers: - w.join() - - -def _generator_worker_loop_mp(dataset, idx_queue, result_queue, eoe): - """ - Multiprocessing generator worker process loop - """ - while True: - # Fetch index, block - try: - idx = idx_queue.get() - except KeyboardInterrupt: - raise Exception("Generator worker receives KeyboardInterrupt") - if idx is None: - # When the queue is out of scope from master process, a None item can be fetched from the queue. - # Upon receiving None, worker process should check if EOE is set. - assert eoe.is_set(), "" - return - # Fetch data, any exception from __getitem__ will terminate worker and timeout master process - result = dataset[idx] - # Send data, block - try: - result_queue.put(result) - except KeyboardInterrupt: - raise Exception("Generator worker receives KeyboardInterrupt") - del result, idx + self.eof.set() -def _generator_worker_loop_mt(dataset, idx_queue, result_queue, eoe, eof): +def _generator_worker_loop(dataset, idx_queue, result_queue, eof): """ - Multithread generator worker process loop. + Multithread or multiprocess generator worker process loop. 
""" while True: # Fetch index, block try: - # Index is generated very fast, so the timeout is very short - idx = idx_queue.get(timeout=0.01) + idx = idx_queue.get(timeout=1) except KeyboardInterrupt: raise Exception("Generator worker receives KeyboardInterrupt") except queue.Empty: - if eof.is_set() or eoe.is_set(): + if eof.is_set(): return - # If end-of-epoch (eoe) or end-of-file (eof) is not set, continue to get data from idx_queue + # If end-of-file (eof) is not set, continue to get data from idx_queue continue if idx is None: # When the queue is out of scope from master process, a None item can be fetched from the queue. - # Upon receiving None, worker process should check if EOE is set. - assert eoe.is_set(), "" + # Upon receiving None, worker process should check if eof is set. + assert eof.is_set(), "" return if eof.is_set(): return @@ -3416,8 +3389,6 @@ def _generator_worker_loop_mt(dataset, idx_queue, result_queue, eoe, eof): continue break del result, idx - if eoe.is_set() and idx_queue.empty(): - return class _GeneratorWorkerMt(threading.Thread): @@ -3425,10 +3396,10 @@ class _GeneratorWorkerMt(threading.Thread): Worker process for multithread Generator. """ - def __init__(self, dataset, eoe, eof): + def __init__(self, dataset, eof): self.idx_queue = queue.Queue(16) self.res_queue = queue.Queue(16) - super().__init__(target=_generator_worker_loop_mt, args=(dataset, self.idx_queue, self.res_queue, eoe, eof)) + super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof)) def put(self, item): """ @@ -3442,16 +3413,25 @@ class _GeneratorWorkerMt(threading.Thread): """ return self.res_queue.get(timeout=30) + def queue_empty(self): + if not self.idx_queue.empty(): + logger.error("idx_queue is not empty") + return False + if not self.res_queue.empty(): + logger.error("res_queue is not empty") + return False + return True + class _GeneratorWorkerMp(multiprocessing.Process): """ Worker process for multiprocess Generator. 
""" - def __init__(self, dataset, eoe): + def __init__(self, dataset, eof): self.idx_queue = multiprocessing.Queue(16) self.res_queue = multiprocessing.Queue(16) - super().__init__(target=_generator_worker_loop_mp, args=(dataset, self.idx_queue, self.res_queue, eoe)) + super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof)) def put(self, item): """ @@ -3467,6 +3447,15 @@ class _GeneratorWorkerMp(multiprocessing.Process): # when we run too many iterators with infinite epoch(num_epoch=-1) return self.res_queue.get(timeout=30) + def queue_empty(self): + if not self.idx_queue.empty(): + logger.error("idx_queue is not empty") + return False + if not self.res_queue.empty(): + logger.error("res_queue is not empty") + return False + return True + def __del__(self): # Try to destruct here, sometimes the class itself will be destructed in advance, # so "self" will be a NoneType @@ -3657,16 +3646,14 @@ class GeneratorDataset(MappableDataset): sampler_instance.set_num_rows(len(self.source)) sampler_instance.initialize() if new_op.num_parallel_workers > 1: - new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, self.source, - new_op.num_parallel_workers, - self.python_multiprocessing)) + sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing) + new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, sample_fn)) else: new_op.source = (lambda: _cpp_sampler_fn(sampler_instance, self.source)) else: if new_op.num_parallel_workers > 1: - new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, self.source, - new_op.num_parallel_workers, - self.python_multiprocessing)) + sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing) + new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, sample_fn)) else: new_op.source = (lambda: _py_sampler_fn(new_op.sampler, new_op.num_samples, self.source)) else: