| @@ -3279,14 +3279,13 @@ class SamplerFn: | |||
| # Event for end of epoch | |||
| if multi_process is True: | |||
| self.eoe = multiprocessing.Event() | |||
| self.eof = multiprocessing.Event() | |||
| else: | |||
| self.eoe = threading.Event() | |||
| self.eof = threading.Event() | |||
| # Create workers | |||
| for _ in range(num_worker): | |||
| if multi_process is True: | |||
| worker = _GeneratorWorkerMp(dataset, self.eoe, self.eof) | |||
| worker = _GeneratorWorkerMp(dataset, self.eoe) | |||
| else: | |||
| worker = _GeneratorWorkerMt(dataset, self.eoe, self.eof) | |||
| worker.daemon = True | |||
| @@ -3327,15 +3326,40 @@ class SamplerFn: | |||
| def __del__(self): | |||
| self.eoe.set() | |||
| self.eof.set() | |||
| if self.multi_process is False: | |||
| self.eof.set() | |||
| for w in self.workers: | |||
| w.join() | |||
| def _generator_worker_loop(dataset, idx_queue, result_queue, eoe, eof): | |||
| def _generator_worker_loop_mp(dataset, idx_queue, result_queue, eoe): | |||
| """ | |||
| Multiprocessing or multithread generator worker process loop. | |||
| Multiprocessing generator worker process loop | |||
| """ | |||
| while True: | |||
| # Fetch index, block | |||
| try: | |||
| idx = idx_queue.get() | |||
| except KeyboardInterrupt: | |||
| raise Exception("Generator worker receives KeyboardInterrupt") | |||
| if idx is None: | |||
| # When the queue is out of scope from master process, a None item can be fetched from the queue. | |||
| # Upon receiving None, worker process should check if EOE is set. | |||
| assert eoe.is_set(), "" | |||
| return | |||
| # Fetch data, any exception from __getitem__ will terminate worker and timeout master process | |||
| result = dataset[idx] | |||
| # Send data, block | |||
| try: | |||
| result_queue.put(result) | |||
| except KeyboardInterrupt: | |||
| raise Exception("Generator worker receives KeyboardInterrupt") | |||
| del result, idx | |||
| def _generator_worker_loop_mt(dataset, idx_queue, result_queue, eoe, eof): | |||
| """ | |||
| Multithread generator worker process loop. | |||
| """ | |||
| while True: | |||
| # Fetch index, block | |||
| @@ -3383,7 +3407,7 @@ class _GeneratorWorkerMt(threading.Thread): | |||
| def __init__(self, dataset, eoe, eof): | |||
| self.idx_queue = queue.Queue(16) | |||
| self.res_queue = queue.Queue(16) | |||
| super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe, eof)) | |||
| super().__init__(target=_generator_worker_loop_mt, args=(dataset, self.idx_queue, self.res_queue, eoe, eof)) | |||
| def put(self, item): | |||
| """ | |||
| @@ -3403,10 +3427,10 @@ class _GeneratorWorkerMp(multiprocessing.Process): | |||
| Worker process for multiprocess Generator. | |||
| """ | |||
| def __init__(self, dataset, eoe, eof): | |||
| def __init__(self, dataset, eoe): | |||
| self.idx_queue = multiprocessing.Queue(16) | |||
| self.res_queue = multiprocessing.Queue(16) | |||
| super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe, eof)) | |||
| super().__init__(target=_generator_worker_loop_mp, args=(dataset, self.idx_queue, self.res_queue, eoe)) | |||
| def put(self, item): | |||
| """ | |||