|
|
@@ -3279,14 +3279,13 @@ class SamplerFn: |
|
|
# Event for end of epoch |
|
|
# Event for end of epoch |
|
|
if multi_process is True: |
|
|
if multi_process is True: |
|
|
self.eoe = multiprocessing.Event() |
|
|
self.eoe = multiprocessing.Event() |
|
|
self.eof = multiprocessing.Event() |
|
|
|
|
|
else: |
|
|
else: |
|
|
self.eoe = threading.Event() |
|
|
self.eoe = threading.Event() |
|
|
self.eof = threading.Event() |
|
|
self.eof = threading.Event() |
|
|
# Create workers |
|
|
# Create workers |
|
|
for _ in range(num_worker): |
|
|
for _ in range(num_worker): |
|
|
if multi_process is True: |
|
|
if multi_process is True: |
|
|
worker = _GeneratorWorkerMp(dataset, self.eoe, self.eof) |
|
|
|
|
|
|
|
|
worker = _GeneratorWorkerMp(dataset, self.eoe) |
|
|
else: |
|
|
else: |
|
|
worker = _GeneratorWorkerMt(dataset, self.eoe, self.eof) |
|
|
worker = _GeneratorWorkerMt(dataset, self.eoe, self.eof) |
|
|
worker.daemon = True |
|
|
worker.daemon = True |
|
|
@@ -3327,15 +3326,40 @@ class SamplerFn: |
|
|
|
|
|
|
|
|
def __del__(self): |
|
|
def __del__(self): |
|
|
self.eoe.set() |
|
|
self.eoe.set() |
|
|
self.eof.set() |
|
|
|
|
|
if self.multi_process is False: |
|
|
if self.multi_process is False: |
|
|
|
|
|
self.eof.set() |
|
|
for w in self.workers: |
|
|
for w in self.workers: |
|
|
w.join() |
|
|
w.join() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _generator_worker_loop(dataset, idx_queue, result_queue, eoe, eof): |
|
|
|
|
|
|
|
|
def _generator_worker_loop_mp(dataset, idx_queue, result_queue, eoe): |
|
|
""" |
|
|
""" |
|
|
Multiprocessing or multithread generator worker process loop. |
|
|
|
|
|
|
|
|
Multiprocessing generator worker process loop |
|
|
|
|
|
""" |
|
|
|
|
|
while True: |
|
|
|
|
|
# Fetch index, block |
|
|
|
|
|
try: |
|
|
|
|
|
idx = idx_queue.get() |
|
|
|
|
|
except KeyboardInterrupt: |
|
|
|
|
|
raise Exception("Generator worker receives KeyboardInterrupt") |
|
|
|
|
|
if idx is None: |
|
|
|
|
|
# When the queue is out of scope from master process, a None item can be fetched from the queue. |
|
|
|
|
|
# Upon receiving None, worker process should check if EOE is set. |
|
|
|
|
|
assert eoe.is_set(), "" |
|
|
|
|
|
return |
|
|
|
|
|
# Fetch data, any exception from __getitem__ will terminate worker and timeout master process |
|
|
|
|
|
result = dataset[idx] |
|
|
|
|
|
# Send data, block |
|
|
|
|
|
try: |
|
|
|
|
|
result_queue.put(result) |
|
|
|
|
|
except KeyboardInterrupt: |
|
|
|
|
|
raise Exception("Generator worker receives KeyboardInterrupt") |
|
|
|
|
|
del result, idx |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _generator_worker_loop_mt(dataset, idx_queue, result_queue, eoe, eof): |
|
|
|
|
|
""" |
|
|
|
|
|
Multithread generator worker process loop. |
|
|
""" |
|
|
""" |
|
|
while True: |
|
|
while True: |
|
|
# Fetch index, block |
|
|
# Fetch index, block |
|
|
@@ -3383,7 +3407,7 @@ class _GeneratorWorkerMt(threading.Thread): |
|
|
def __init__(self, dataset, eoe, eof): |
|
|
def __init__(self, dataset, eoe, eof): |
|
|
self.idx_queue = queue.Queue(16) |
|
|
self.idx_queue = queue.Queue(16) |
|
|
self.res_queue = queue.Queue(16) |
|
|
self.res_queue = queue.Queue(16) |
|
|
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe, eof)) |
|
|
|
|
|
|
|
|
super().__init__(target=_generator_worker_loop_mt, args=(dataset, self.idx_queue, self.res_queue, eoe, eof)) |
|
|
|
|
|
|
|
|
def put(self, item): |
|
|
def put(self, item): |
|
|
""" |
|
|
""" |
|
|
@@ -3403,10 +3427,10 @@ class _GeneratorWorkerMp(multiprocessing.Process): |
|
|
Worker process for multiprocess Generator. |
|
|
Worker process for multiprocess Generator. |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
def __init__(self, dataset, eoe, eof): |
|
|
|
|
|
|
|
|
def __init__(self, dataset, eoe): |
|
|
self.idx_queue = multiprocessing.Queue(16) |
|
|
self.idx_queue = multiprocessing.Queue(16) |
|
|
self.res_queue = multiprocessing.Queue(16) |
|
|
self.res_queue = multiprocessing.Queue(16) |
|
|
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe, eof)) |
|
|
|
|
|
|
|
|
super().__init__(target=_generator_worker_loop_mp, args=(dataset, self.idx_queue, self.res_queue, eoe)) |
|
|
|
|
|
|
|
|
def put(self, item): |
|
|
def put(self, item): |
|
|
""" |
|
|
""" |
|
|
|