@@ -3239,21 +3239,19 @@ def _cpp_sampler_fn(sampler, dataset):
         yield tuple([np.array(x, copy=False) for x in val])


-def _cpp_sampler_fn_mp(sampler, dataset, num_worker, multi_process):
+def _cpp_sampler_fn_mp(sampler, sample_fn):
     """
     Multiprocessing generator function wrapper for mappable dataset with cpp sampler.
     """
     indices = sampler.get_indices()
-    sample_fn = SamplerFn(dataset, num_worker, multi_process)
     return sample_fn.process(indices)


-def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker, multi_process):
+def _py_sampler_fn_mp(sampler, num_samples, sample_fn):
     """
     Multiprocessing generator function wrapper for mappable dataset with Python sampler.
     """
     indices = _fetch_py_sampler_indices(sampler, num_samples)
-    sample_fn = SamplerFn(dataset, num_worker, multi_process)
     return sample_fn.process(indices)

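After this hunk, _cpp_sampler_fn_mp and _py_sampler_fn_mp no longer build their own SamplerFn; the caller constructs it once and the wrappers only forward the sampler's indices to it. A minimal standalone sketch of that pattern (hypothetical names, not MindSpore code):

class _PoolStub:
    # Stand-in for SamplerFn: owns the workers and turns indices into rows.
    def __init__(self, dataset):
        self.dataset = dataset

    def process(self, indices):
        for i in indices:
            yield self.dataset[i]


def sampler_fn_mp(indices, pool):
    # The wrapper only forwards indices; the pool it uses is built by the caller.
    return pool.process(indices)


if __name__ == "__main__":
    data = [[0, 1], [2, 3], [4, 5]]
    pool = _PoolStub(data)                      # built once, like SamplerFn in the hunk above
    print(list(sampler_fn_mp([2, 0], pool)))    # [[4, 5], [0, 1]]
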
@@ -3299,17 +3297,21 @@ class SamplerFn:
         self.multi_process = multi_process
         # Event for end of epoch
         if multi_process is True:
-            self.eoe = multiprocessing.Event()
             self.eof = multiprocessing.Event()
         else:
-            self.eoe = threading.Event()
             self.eof = threading.Event()
         # Create workers
         for _ in range(num_worker):
             if multi_process is True:
-                worker = _GeneratorWorkerMp(dataset, self.eoe)
+                worker = _GeneratorWorkerMp(dataset, self.eof)
+                worker.daemon = True
+                # When multi processes fork a subprocess, the lock of the main process is copied to the subprocess,
+                # which may cause deadlock. Therefore, the subprocess startup is performed in the initialization phase.
+                # In this phase, the main process is not locked.
+                worker.start()
             else:
-                worker = _GeneratorWorkerMt(dataset, self.eoe, self.eof)
-                worker.daemon = True
+                worker = _GeneratorWorkerMt(dataset, self.eof)
+                worker.daemon = True
             self.workers.append(worker)

     def process(self, indices):
@@ -3317,14 +3319,18 @@
         The main process, start the child process or child thread, and fill the index queue.
         Get the result and return.
         """
+        for w in self.workers:
+            # Check whether the queue of the subprocess is empty.
+            if not w.queue_empty():
+                raise Exception("The queue of the subprocess is not empty.")
+            # Start all workers
+            if not w.is_alive():
+                w.start()
+
         # Fill initial index queues
         idx_cursor = 0
         idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)

-        # Start all workers
-        for w in self.workers:
-            w.start()
-
         # Fetch results
         for i in range(len(indices)):
             # Fetch result and put index
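
The two hunks above move multiprocessing worker startup into SamplerFn.__init__ (so the fork happens before the main process takes any lock) and make process() reusable: it verifies the per-worker queues are drained and only starts workers that are not already alive. A standalone sketch of that reuse pattern, with hypothetical names and a trivial thread worker:

import queue
import threading


class Worker(threading.Thread):
    # Toy stand-in for _GeneratorWorkerMt: one index queue in, one result queue out.
    def __init__(self):
        super().__init__(daemon=True)
        self.idx_queue = queue.Queue(16)
        self.res_queue = queue.Queue(16)

    def queue_empty(self):
        return self.idx_queue.empty() and self.res_queue.empty()

    def run(self):
        while True:
            self.res_queue.put(self.idx_queue.get() * 2)


def run_epoch(worker, indices):
    # Mirrors the checks added to SamplerFn.process(): refuse a dirty queue,
    # start the worker only if it is not already running, then fill and fetch.
    if not worker.queue_empty():
        raise Exception("The queue of the subprocess is not empty.")
    if not worker.is_alive():
        worker.start()
    for i in indices:
        worker.idx_queue.put(i)
    return [worker.res_queue.get(timeout=5) for _ in indices]


if __name__ == "__main__":
    w = Worker()
    print(run_epoch(w, [1, 2, 3]))   # [2, 4, 6]
    print(run_epoch(w, [4, 5]))      # [8, 10] -- the same worker is reused, not restarted
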
@@ -3340,64 +3346,31 @@
                 raise Exception("Generator worker receives KeyboardInterrupt")
             if idx_cursor < len(indices):
                 idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)
-            # Set end-of-epoch (eoe) event once all indices are sent
-            if idx_cursor == len(indices) and not self.eoe.is_set():
-                self.eoe.set()
             yield tuple([np.array(x, copy=False) for x in result])

     def __del__(self):
-        self.eoe.set()
-        if self.multi_process is False:
-            self.eof.set()
-        for w in self.workers:
-            w.join()
-
-
-def _generator_worker_loop_mp(dataset, idx_queue, result_queue, eoe):
-    """
-    Multiprocessing generator worker process loop
-    """
-    while True:
-        # Fetch index, block
-        try:
-            idx = idx_queue.get()
-        except KeyboardInterrupt:
-            raise Exception("Generator worker receives KeyboardInterrupt")
-        if idx is None:
-            # When the queue is out of scope from master process, a None item can be fetched from the queue.
-            # Upon receiving None, worker process should check if EOE is set.
-            assert eoe.is_set(), ""
-            return
-        # Fetch data, any exception from __getitem__ will terminate worker and timeout master process
-        result = dataset[idx]
-        # Send data, block
-        try:
-            result_queue.put(result)
-        except KeyboardInterrupt:
-            raise Exception("Generator worker receives KeyboardInterrupt")
-        del result, idx
+        self.eof.set()


-def _generator_worker_loop_mt(dataset, idx_queue, result_queue, eoe, eof):
+def _generator_worker_loop(dataset, idx_queue, result_queue, eof):
     """
-    Multithread generator worker process loop.
+    Multithread or multiprocess generator worker process loop.
     """
     while True:
         # Fetch index, block
         try:
-            # Index is generated very fast, so the timeout is very short
-            idx = idx_queue.get(timeout=0.01)
+            idx = idx_queue.get(timeout=1)
         except KeyboardInterrupt:
             raise Exception("Generator worker receives KeyboardInterrupt")
         except queue.Empty:
-            if eof.is_set() or eoe.is_set():
+            if eof.is_set():
                 return
-            # If end-of-epoch (eoe) or end-of-file (eof) is not set, continue to get data from idx_queue
+            # If end-of-file (eof) is not set, continue to get data from idx_queue
             continue
         if idx is None:
             # When the queue is out of scope from master process, a None item can be fetched from the queue.
-            # Upon receiving None, worker process should check if EOE is set.
-            assert eoe.is_set(), ""
+            # Upon receiving None, worker process should check if eof is set.
+            assert eof.is_set(), ""
             return
         if eof.is_set():
             return
@@ -3416,8 +3389,6 @@ def _generator_worker_loop_mt(dataset, idx_queue, result_queue, eoe, eof):
                 continue
             break
         del result, idx
-        if eoe.is_set() and idx_queue.empty():
-            return


 class _GeneratorWorkerMt(threading.Thread):
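
With eoe gone, the merged _generator_worker_loop relies on a single eof event plus the None sentinel: a worker that times out on its index queue exits only if eof is set, otherwise it keeps polling. A runnable standalone sketch of that shutdown protocol (hypothetical names, not MindSpore code):

import queue
import threading


def worker_loop(dataset, idx_queue, result_queue, eof):
    # Same shape as the merged loop: poll with a timeout, exit on eof or a None sentinel.
    while True:
        try:
            idx = idx_queue.get(timeout=1)
        except queue.Empty:
            if eof.is_set():
                return
            continue
        if idx is None or eof.is_set():
            return
        result_queue.put(dataset[idx])


if __name__ == "__main__":
    eof = threading.Event()
    idx_q, res_q = queue.Queue(16), queue.Queue(16)
    t = threading.Thread(target=worker_loop, args=([10, 20, 30], idx_q, res_q, eof), daemon=True)
    t.start()
    idx_q.put(0)
    idx_q.put(2)
    print([res_q.get(timeout=5) for _ in range(2)])  # [10, 30]
    eof.set()          # the only stop signal left after this change
    t.join(timeout=5)
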
@@ -3425,10 +3396,10 @@ class _GeneratorWorkerMt(threading.Thread):
     Worker process for multithread Generator.
     """

-    def __init__(self, dataset, eoe, eof):
+    def __init__(self, dataset, eof):
         self.idx_queue = queue.Queue(16)
         self.res_queue = queue.Queue(16)
-        super().__init__(target=_generator_worker_loop_mt, args=(dataset, self.idx_queue, self.res_queue, eoe, eof))
+        super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof))

     def put(self, item):
         """
@@ -3442,16 +3413,25 @@ class _GeneratorWorkerMt(threading.Thread):
         """
         return self.res_queue.get(timeout=30)

+    def queue_empty(self):
+        if not self.idx_queue.empty():
+            logger.error("idx_queue is not empty")
+            return False
+        if not self.res_queue.empty():
+            logger.error("res_queue is not empty")
+            return False
+        return True
+

 class _GeneratorWorkerMp(multiprocessing.Process):
     """
     Worker process for multiprocess Generator.
     """

-    def __init__(self, dataset, eoe):
+    def __init__(self, dataset, eof):
         self.idx_queue = multiprocessing.Queue(16)
         self.res_queue = multiprocessing.Queue(16)
-        super().__init__(target=_generator_worker_loop_mp, args=(dataset, self.idx_queue, self.res_queue, eoe))
+        super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof))

     def put(self, item):
         """
@@ -3467,6 +3447,15 @@ class _GeneratorWorkerMp(multiprocessing.Process):
         # when we run too many iterators with infinite epoch(num_epoch=-1)
         return self.res_queue.get(timeout=30)

+    def queue_empty(self):
+        if not self.idx_queue.empty():
+            logger.error("idx_queue is not empty")
+            return False
+        if not self.res_queue.empty():
+            logger.error("res_queue is not empty")
+            return False
+        return True
+
     def __del__(self):
         # Try to destruct here, sometimes the class itself will be destructed in advance,
         # so "self" will be a NoneType
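
Both worker classes now expose the same queue_empty() helper and run the same _generator_worker_loop target; only the queue and event primitives differ (threading vs multiprocessing). A standalone sketch showing the process-backed variant of the same loop (hypothetical names, not MindSpore code):

import multiprocessing


def square_worker(idx_queue, result_queue, eof):
    # Same contract as the threaded sketch above, but driven as a child process.
    while True:
        idx = idx_queue.get()
        if idx is None or eof.is_set():
            return
        result_queue.put(idx * idx)


if __name__ == "__main__":
    eof = multiprocessing.Event()
    idx_q, res_q = multiprocessing.Queue(16), multiprocessing.Queue(16)
    p = multiprocessing.Process(target=square_worker, args=(idx_q, res_q, eof), daemon=True)
    p.start()
    for i in (2, 3, 4):
        idx_q.put(i)
    print(sorted(res_q.get(timeout=10) for _ in range(3)))  # [4, 9, 16]
    idx_q.put(None)    # sentinel lets the child exit cleanly
    p.join(timeout=10)
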
@@ -3657,16 +3646,14 @@ class GeneratorDataset(MappableDataset):
                 sampler_instance.set_num_rows(len(self.source))
                 sampler_instance.initialize()
                 if new_op.num_parallel_workers > 1:
-                    new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, self.source,
-                                                                new_op.num_parallel_workers,
-                                                                self.python_multiprocessing))
+                    sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
+                    new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, sample_fn))
                 else:
                     new_op.source = (lambda: _cpp_sampler_fn(sampler_instance, self.source))
             else:
                 if new_op.num_parallel_workers > 1:
-                    new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, self.source,
-                                                               new_op.num_parallel_workers,
-                                                               self.python_multiprocessing))
+                    sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
+                    new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, sample_fn))
                 else:
                     new_op.source = (lambda: _py_sampler_fn(new_op.sampler, new_op.num_samples, self.source))
         else:
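
In this last hunk the SamplerFn is created once, outside the lambda, so every invocation of new_op.source reuses the same worker pool instead of constructing a fresh one inside _cpp_sampler_fn_mp / _py_sampler_fn_mp on each call. A small standalone sketch of why moving the construction out of the lambda matters (hypothetical names, not MindSpore code):

class Pool:
    # Toy stand-in for SamplerFn; counts how many pools get built.
    created = 0

    def __init__(self):
        Pool.created += 1

    def process(self, indices):
        return [i * i for i in indices]


def make_source_per_call(indices):
    # Old wiring: the pool is built inside the wrapper, once per call of the lambda.
    return lambda: Pool().process(indices)


def make_source_shared(indices):
    # New wiring: the pool is captured by the closure and shared across calls.
    pool = Pool()
    return lambda: pool.process(indices)


if __name__ == "__main__":
    shared = make_source_shared([1, 2, 3])
    shared()
    shared()
    print(Pool.created)   # 1 -- one pool, reused

    per_call = make_source_per_call([1, 2, 3])
    per_call()
    per_call()
    print(Pool.created)   # 3 -- two more pools were built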