@@ -3239,21 +3239,19 @@ def _cpp_sampler_fn(sampler, dataset):
yield tuple([np.array(x, copy=False) for x in val])
yield tuple([np.array(x, copy=False) for x in val])
def _cpp_sampler_fn_mp(sampler, dataset, num_worker, multi_process ):
def _cpp_sampler_fn_mp(sampler, sample_fn ):
"""
"""
Multiprocessing generator function wrapper for mappable dataset with cpp sampler.
Multiprocessing generator function wrapper for mappable dataset with cpp sampler.
"""
"""
indices = sampler.get_indices()
indices = sampler.get_indices()
sample_fn = SamplerFn(dataset, num_worker, multi_process)
return sample_fn.process(indices)
return sample_fn.process(indices)
def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker, multi_process ):
def _py_sampler_fn_mp(sampler, num_samples, sample_fn ):
"""
"""
Multiprocessing generator function wrapper for mappable dataset with Python sampler.
Multiprocessing generator function wrapper for mappable dataset with Python sampler.
"""
"""
indices = _fetch_py_sampler_indices(sampler, num_samples)
indices = _fetch_py_sampler_indices(sampler, num_samples)
sample_fn = SamplerFn(dataset, num_worker, multi_process)
return sample_fn.process(indices)
return sample_fn.process(indices)
@@ -3299,17 +3297,21 @@ class SamplerFn:
self.multi_process = multi_process
self.multi_process = multi_process
# Event for end of epoch
# Event for end of epoch
if multi_process is True:
if multi_process is True:
self.eoe = multiprocessing.Event()
self.eof = multiprocessing.Event()
else:
else:
self.eoe = threading.Event()
self.eof = threading.Event()
self.eof = threading.Event()
# Create workers
# Create workers
for _ in range(num_worker):
for _ in range(num_worker):
if multi_process is True:
if multi_process is True:
worker = _GeneratorWorkerMp(dataset, self.eoe)
worker = _GeneratorWorkerMp(dataset, self.eof)
worker.daemon = True
# When multi processes fork a subprocess, the lock of the main process is copied to the subprocess,
# which may cause deadlock. Therefore, the subprocess startup is performed in the initialization phase.
# In this phase, the main process is not locked.
worker.start()
else:
else:
worker = _GeneratorWorkerMt(dataset, self.eoe, self.eof)
worker.daemon = True
worker = _GeneratorWorkerMt(dataset, self.eof)
worker.daemon = True
self.workers.append(worker)
self.workers.append(worker)
def process(self, indices):
def process(self, indices):
@@ -3317,14 +3319,18 @@ class SamplerFn:
The main process, start the child process or child thread, and fill the index queue.
The main process, start the child process or child thread, and fill the index queue.
Get the result and return.
Get the result and return.
"""
"""
for w in self.workers:
# Check whether the queue of the subprocess is empty.
if not w.queue_empty():
raise Exception("The queue of the subprocess is not empty.")
# Start all workers
if not w.is_alive():
w.start()
# Fill initial index queues
# Fill initial index queues
idx_cursor = 0
idx_cursor = 0
idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)
idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)
# Start all workers
for w in self.workers:
w.start()
# Fetch results
# Fetch results
for i in range(len(indices)):
for i in range(len(indices)):
# Fetch result and put index
# Fetch result and put index
@@ -3340,64 +3346,31 @@ class SamplerFn:
raise Exception("Generator worker receives KeyboardInterrupt")
raise Exception("Generator worker receives KeyboardInterrupt")
if idx_cursor < len(indices):
if idx_cursor < len(indices):
idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)
idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)
# Set end-of-epoch (eoe) event once all indices are sent
if idx_cursor == len(indices) and not self.eoe.is_set():
self.eoe.set()
yield tuple([np.array(x, copy=False) for x in result])
yield tuple([np.array(x, copy=False) for x in result])
def __del__(self):
def __del__(self):
self.eoe.set()
if self.multi_process is False:
self.eof.set()
for w in self.workers:
w.join()
def _generator_worker_loop_mp(dataset, idx_queue, result_queue, eoe):
"""
Multiprocessing generator worker process loop
"""
while True:
# Fetch index, block
try:
idx = idx_queue.get()
except KeyboardInterrupt:
raise Exception("Generator worker receives KeyboardInterrupt")
if idx is None:
# When the queue is out of scope from master process, a None item can be fetched from the queue.
# Upon receiving None, worker process should check if EOE is set.
assert eoe.is_set(), ""
return
# Fetch data, any exception from __getitem__ will terminate worker and timeout master process
result = dataset[idx]
# Send data, block
try:
result_queue.put(result)
except KeyboardInterrupt:
raise Exception("Generator worker receives KeyboardInterrupt")
del result, idx
self.eof.set()
def _generator_worker_loop_mt (dataset, idx_queue, result_queue, eo e, eof):
def _generator_worker_loop(dataset, idx_queue, result_queue, eof):
"""
"""
Multithread generator worker process loop.
Multithread or multiprocess generator worker process loop.
"""
"""
while True:
while True:
# Fetch index, block
# Fetch index, block
try:
try:
# Index is generated very fast, so the timeout is very short
idx = idx_queue.get(timeout=0.01)
idx = idx_queue.get(timeout=1)
except KeyboardInterrupt:
except KeyboardInterrupt:
raise Exception("Generator worker receives KeyboardInterrupt")
raise Exception("Generator worker receives KeyboardInterrupt")
except queue.Empty:
except queue.Empty:
if eof.is_set() or eoe.is_set() :
if eof.is_set():
return
return
# If end-of-epoch (eoe) or end-of- file (eof) is not set, continue to get data from idx_queue
# If end-of-file (eof) is not set, continue to get data from idx_queue
continue
continue
if idx is None:
if idx is None:
# When the queue is out of scope from master process, a None item can be fetched from the queue.
# When the queue is out of scope from master process, a None item can be fetched from the queue.
# Upon receiving None, worker process should check if EOE is set.
assert eoe .is_set(), ""
# Upon receiving None, worker process should check if eof is set.
assert eof.is_set(), ""
return
return
if eof.is_set():
if eof.is_set():
return
return
@@ -3416,8 +3389,6 @@ def _generator_worker_loop_mt(dataset, idx_queue, result_queue, eoe, eof):
continue
continue
break
break
del result, idx
del result, idx
if eoe.is_set() and idx_queue.empty():
return
class _GeneratorWorkerMt(threading.Thread):
class _GeneratorWorkerMt(threading.Thread):
@@ -3425,10 +3396,10 @@ class _GeneratorWorkerMt(threading.Thread):
Worker process for multithread Generator.
Worker process for multithread Generator.
"""
"""
def __init__(self, dataset, eoe, eo f):
def __init__(self, dataset, eof):
self.idx_queue = queue.Queue(16)
self.idx_queue = queue.Queue(16)
self.res_queue = queue.Queue(16)
self.res_queue = queue.Queue(16)
super().__init__(target=_generator_worker_loop_mt , args=(dataset, self.idx_queue, self.res_queue, eo e, eof))
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof))
def put(self, item):
def put(self, item):
"""
"""
@@ -3442,16 +3413,25 @@ class _GeneratorWorkerMt(threading.Thread):
"""
"""
return self.res_queue.get(timeout=30)
return self.res_queue.get(timeout=30)
def queue_empty(self):
if not self.idx_queue.empty():
logger.error("idx_queue is not empty")
return False
if not self.res_queue.empty():
logger.error("res_queue is not empty")
return False
return True
class _GeneratorWorkerMp(multiprocessing.Process):
class _GeneratorWorkerMp(multiprocessing.Process):
"""
"""
Worker process for multiprocess Generator.
Worker process for multiprocess Generator.
"""
"""
def __init__(self, dataset, eoe):
def __init__(self, dataset, eof ):
self.idx_queue = multiprocessing.Queue(16)
self.idx_queue = multiprocessing.Queue(16)
self.res_queue = multiprocessing.Queue(16)
self.res_queue = multiprocessing.Queue(16)
super().__init__(target=_generator_worker_loop_mp , args=(dataset, self.idx_queue, self.res_queue, eoe ))
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof ))
def put(self, item):
def put(self, item):
"""
"""
@@ -3467,6 +3447,15 @@ class _GeneratorWorkerMp(multiprocessing.Process):
# when we run too many iterators with infinite epoch(num_epoch=-1)
# when we run too many iterators with infinite epoch(num_epoch=-1)
return self.res_queue.get(timeout=30)
return self.res_queue.get(timeout=30)
def queue_empty(self):
if not self.idx_queue.empty():
logger.error("idx_queue is not empty")
return False
if not self.res_queue.empty():
logger.error("res_queue is not empty")
return False
return True
def __del__(self):
def __del__(self):
# Try to destruct here, sometimes the class itself will be destructed in advance,
# Try to destruct here, sometimes the class itself will be destructed in advance,
# so "self" will be a NoneType
# so "self" will be a NoneType
@@ -3657,16 +3646,14 @@ class GeneratorDataset(MappableDataset):
sampler_instance.set_num_rows(len(self.source))
sampler_instance.set_num_rows(len(self.source))
sampler_instance.initialize()
sampler_instance.initialize()
if new_op.num_parallel_workers > 1:
if new_op.num_parallel_workers > 1:
new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, self.source,
new_op.num_parallel_workers,
self.python_multiprocessing))
sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, sample_fn))
else:
else:
new_op.source = (lambda: _cpp_sampler_fn(sampler_instance, self.source))
new_op.source = (lambda: _cpp_sampler_fn(sampler_instance, self.source))
else:
else:
if new_op.num_parallel_workers > 1:
if new_op.num_parallel_workers > 1:
new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, self.source,
new_op.num_parallel_workers,
self.python_multiprocessing))
sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, sample_fn))
else:
else:
new_op.source = (lambda: _py_sampler_fn(new_op.sampler, new_op.num_samples, self.source))
new_op.source = (lambda: _py_sampler_fn(new_op.sampler, new_op.num_samples, self.source))
else:
else: