| @@ -3279,14 +3279,13 @@ class SamplerFn: | |||||
| # Event for end of epoch | # Event for end of epoch | ||||
| if multi_process is True: | if multi_process is True: | ||||
| self.eoe = multiprocessing.Event() | self.eoe = multiprocessing.Event() | ||||
| self.eof = multiprocessing.Event() | |||||
| else: | else: | ||||
| self.eoe = threading.Event() | self.eoe = threading.Event() | ||||
| self.eof = threading.Event() | self.eof = threading.Event() | ||||
| # Create workers | # Create workers | ||||
| for _ in range(num_worker): | for _ in range(num_worker): | ||||
| if multi_process is True: | if multi_process is True: | ||||
| worker = _GeneratorWorkerMp(dataset, self.eoe, self.eof) | |||||
| worker = _GeneratorWorkerMp(dataset, self.eoe) | |||||
| else: | else: | ||||
| worker = _GeneratorWorkerMt(dataset, self.eoe, self.eof) | worker = _GeneratorWorkerMt(dataset, self.eoe, self.eof) | ||||
| worker.daemon = True | worker.daemon = True | ||||
| @@ -3327,15 +3326,40 @@ class SamplerFn: | |||||
| def __del__(self): | def __del__(self): | ||||
| self.eoe.set() | self.eoe.set() | ||||
| self.eof.set() | |||||
| if self.multi_process is False: | if self.multi_process is False: | ||||
| self.eof.set() | |||||
| for w in self.workers: | for w in self.workers: | ||||
| w.join() | w.join() | ||||
| def _generator_worker_loop(dataset, idx_queue, result_queue, eoe, eof): | |||||
| def _generator_worker_loop_mp(dataset, idx_queue, result_queue, eoe): | |||||
| """ | """ | ||||
| Multiprocessing or multithread generator worker process loop. | |||||
| Multiprocessing generator worker process loop | |||||
| """ | |||||
| while True: | |||||
| # Fetch index, block | |||||
| try: | |||||
| idx = idx_queue.get() | |||||
| except KeyboardInterrupt: | |||||
| raise Exception("Generator worker receives KeyboardInterrupt") | |||||
| if idx is None: | |||||
| # When the queue is out of scope from master process, a None item can be fetched from the queue. | |||||
| # Upon receiving None, worker process should check if EOE is set. | |||||
| assert eoe.is_set(), "" | |||||
| return | |||||
| # Fetch data, any exception from __getitem__ will terminate worker and timeout master process | |||||
| result = dataset[idx] | |||||
| # Send data, block | |||||
| try: | |||||
| result_queue.put(result) | |||||
| except KeyboardInterrupt: | |||||
| raise Exception("Generator worker receives KeyboardInterrupt") | |||||
| del result, idx | |||||
| def _generator_worker_loop_mt(dataset, idx_queue, result_queue, eoe, eof): | |||||
| """ | |||||
| Multithread generator worker process loop. | |||||
| """ | """ | ||||
| while True: | while True: | ||||
| # Fetch index, block | # Fetch index, block | ||||
| @@ -3383,7 +3407,7 @@ class _GeneratorWorkerMt(threading.Thread): | |||||
| def __init__(self, dataset, eoe, eof): | def __init__(self, dataset, eoe, eof): | ||||
| self.idx_queue = queue.Queue(16) | self.idx_queue = queue.Queue(16) | ||||
| self.res_queue = queue.Queue(16) | self.res_queue = queue.Queue(16) | ||||
| super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe, eof)) | |||||
| super().__init__(target=_generator_worker_loop_mt, args=(dataset, self.idx_queue, self.res_queue, eoe, eof)) | |||||
| def put(self, item): | def put(self, item): | ||||
| """ | """ | ||||
| @@ -3403,10 +3427,10 @@ class _GeneratorWorkerMp(multiprocessing.Process): | |||||
| Worker process for multiprocess Generator. | Worker process for multiprocess Generator. | ||||
| """ | """ | ||||
| def __init__(self, dataset, eoe, eof): | |||||
| def __init__(self, dataset, eoe): | |||||
| self.idx_queue = multiprocessing.Queue(16) | self.idx_queue = multiprocessing.Queue(16) | ||||
| self.res_queue = multiprocessing.Queue(16) | self.res_queue = multiprocessing.Queue(16) | ||||
| super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe, eof)) | |||||
| super().__init__(target=_generator_worker_loop_mp, args=(dataset, self.idx_queue, self.res_queue, eoe)) | |||||
| def put(self, item): | def put(self, item): | ||||
| """ | """ | ||||