Browse Source

!7600 fix GeneratorDataset timeout

Merge pull request !7600 from heleiwang/fix_generator
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
0218f8a06f
1 changed files with 53 additions and 66 deletions
  1. +53
    -66
      mindspore/dataset/engine/datasets.py

+ 53
- 66
mindspore/dataset/engine/datasets.py View File

@@ -3239,21 +3239,19 @@ def _cpp_sampler_fn(sampler, dataset):
yield tuple([np.array(x, copy=False) for x in val]) yield tuple([np.array(x, copy=False) for x in val])




def _cpp_sampler_fn_mp(sampler, dataset, num_worker, multi_process):
def _cpp_sampler_fn_mp(sampler, sample_fn):
""" """
Multiprocessing generator function wrapper for mappable dataset with cpp sampler. Multiprocessing generator function wrapper for mappable dataset with cpp sampler.
""" """
indices = sampler.get_indices() indices = sampler.get_indices()
sample_fn = SamplerFn(dataset, num_worker, multi_process)
return sample_fn.process(indices) return sample_fn.process(indices)




def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker, multi_process):
def _py_sampler_fn_mp(sampler, num_samples, sample_fn):
""" """
Multiprocessing generator function wrapper for mappable dataset with Python sampler. Multiprocessing generator function wrapper for mappable dataset with Python sampler.
""" """
indices = _fetch_py_sampler_indices(sampler, num_samples) indices = _fetch_py_sampler_indices(sampler, num_samples)
sample_fn = SamplerFn(dataset, num_worker, multi_process)
return sample_fn.process(indices) return sample_fn.process(indices)




@@ -3299,17 +3297,21 @@ class SamplerFn:
self.multi_process = multi_process self.multi_process = multi_process
# Event for end of epoch # Event for end of epoch
if multi_process is True: if multi_process is True:
self.eoe = multiprocessing.Event()
self.eof = multiprocessing.Event()
else: else:
self.eoe = threading.Event()
self.eof = threading.Event() self.eof = threading.Event()
# Create workers # Create workers
for _ in range(num_worker): for _ in range(num_worker):
if multi_process is True: if multi_process is True:
worker = _GeneratorWorkerMp(dataset, self.eoe)
worker = _GeneratorWorkerMp(dataset, self.eof)
worker.daemon = True
# When multi processes fork a subprocess, the lock of the main process is copied to the subprocess,
# which may cause deadlock. Therefore, the subprocess startup is performed in the initialization phase.
# In this phase, the main process is not locked.
worker.start()
else: else:
worker = _GeneratorWorkerMt(dataset, self.eoe, self.eof)
worker.daemon = True
worker = _GeneratorWorkerMt(dataset, self.eof)
worker.daemon = True
self.workers.append(worker) self.workers.append(worker)


def process(self, indices): def process(self, indices):
@@ -3317,14 +3319,18 @@ class SamplerFn:
The main process, start the child process or child thread, and fill the index queue. The main process, start the child process or child thread, and fill the index queue.
Get the result and return. Get the result and return.
""" """
for w in self.workers:
# Check whether the queue of the subprocess is empty.
if not w.queue_empty():
raise Exception("The queue of the subprocess is not empty.")
# Start all workers
if not w.is_alive():
w.start()

# Fill initial index queues # Fill initial index queues
idx_cursor = 0 idx_cursor = 0
idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor) idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)


# Start all workers
for w in self.workers:
w.start()

# Fetch results # Fetch results
for i in range(len(indices)): for i in range(len(indices)):
# Fetch result and put index # Fetch result and put index
@@ -3340,64 +3346,31 @@ class SamplerFn:
raise Exception("Generator worker receives KeyboardInterrupt") raise Exception("Generator worker receives KeyboardInterrupt")
if idx_cursor < len(indices): if idx_cursor < len(indices):
idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor) idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)
# Set end-of-epoch (eoe) event once all indices are sent
if idx_cursor == len(indices) and not self.eoe.is_set():
self.eoe.set()
yield tuple([np.array(x, copy=False) for x in result]) yield tuple([np.array(x, copy=False) for x in result])


def __del__(self): def __del__(self):
self.eoe.set()
if self.multi_process is False:
self.eof.set()
for w in self.workers:
w.join()


def _generator_worker_loop_mp(dataset, idx_queue, result_queue, eoe):
"""
Multiprocessing generator worker process loop
"""
while True:
# Fetch index, block
try:
idx = idx_queue.get()
except KeyboardInterrupt:
raise Exception("Generator worker receives KeyboardInterrupt")
if idx is None:
# When the queue is out of scope from master process, a None item can be fetched from the queue.
# Upon receiving None, worker process should check if EOE is set.
assert eoe.is_set(), ""
return
# Fetch data, any exception from __getitem__ will terminate worker and timeout master process
result = dataset[idx]
# Send data, block
try:
result_queue.put(result)
except KeyboardInterrupt:
raise Exception("Generator worker receives KeyboardInterrupt")
del result, idx
self.eof.set()




def _generator_worker_loop_mt(dataset, idx_queue, result_queue, eoe, eof):
def _generator_worker_loop(dataset, idx_queue, result_queue, eof):
""" """
Multithread generator worker process loop.
Multithread or multiprocess generator worker process loop.
""" """
while True: while True:
# Fetch index, block # Fetch index, block
try: try:
# Index is generated very fast, so the timeout is very short
idx = idx_queue.get(timeout=0.01)
idx = idx_queue.get(timeout=1)
except KeyboardInterrupt: except KeyboardInterrupt:
raise Exception("Generator worker receives KeyboardInterrupt") raise Exception("Generator worker receives KeyboardInterrupt")
except queue.Empty: except queue.Empty:
if eof.is_set() or eoe.is_set():
if eof.is_set():
return return
# If end-of-epoch (eoe) or end-of-file (eof) is not set, continue to get data from idx_queue
# If end-of-file (eof) is not set, continue to get data from idx_queue
continue continue
if idx is None: if idx is None:
# When the queue is out of scope from master process, a None item can be fetched from the queue. # When the queue is out of scope from master process, a None item can be fetched from the queue.
# Upon receiving None, worker process should check if EOE is set.
assert eoe.is_set(), ""
# Upon receiving None, worker process should check if eof is set.
assert eof.is_set(), ""
return return
if eof.is_set(): if eof.is_set():
return return
@@ -3416,8 +3389,6 @@ def _generator_worker_loop_mt(dataset, idx_queue, result_queue, eoe, eof):
continue continue
break break
del result, idx del result, idx
if eoe.is_set() and idx_queue.empty():
return




class _GeneratorWorkerMt(threading.Thread): class _GeneratorWorkerMt(threading.Thread):
@@ -3425,10 +3396,10 @@ class _GeneratorWorkerMt(threading.Thread):
Worker process for multithread Generator. Worker process for multithread Generator.
""" """


def __init__(self, dataset, eoe, eof):
def __init__(self, dataset, eof):
self.idx_queue = queue.Queue(16) self.idx_queue = queue.Queue(16)
self.res_queue = queue.Queue(16) self.res_queue = queue.Queue(16)
super().__init__(target=_generator_worker_loop_mt, args=(dataset, self.idx_queue, self.res_queue, eoe, eof))
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof))


def put(self, item): def put(self, item):
""" """
@@ -3442,16 +3413,25 @@ class _GeneratorWorkerMt(threading.Thread):
""" """
return self.res_queue.get(timeout=30) return self.res_queue.get(timeout=30)


def queue_empty(self):
if not self.idx_queue.empty():
logger.error("idx_queue is not empty")
return False
if not self.res_queue.empty():
logger.error("res_queue is not empty")
return False
return True



class _GeneratorWorkerMp(multiprocessing.Process): class _GeneratorWorkerMp(multiprocessing.Process):
""" """
Worker process for multiprocess Generator. Worker process for multiprocess Generator.
""" """


def __init__(self, dataset, eoe):
def __init__(self, dataset, eof):
self.idx_queue = multiprocessing.Queue(16) self.idx_queue = multiprocessing.Queue(16)
self.res_queue = multiprocessing.Queue(16) self.res_queue = multiprocessing.Queue(16)
super().__init__(target=_generator_worker_loop_mp, args=(dataset, self.idx_queue, self.res_queue, eoe))
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof))


def put(self, item): def put(self, item):
""" """
@@ -3467,6 +3447,15 @@ class _GeneratorWorkerMp(multiprocessing.Process):
# when we run too many iterators with infinite epoch(num_epoch=-1) # when we run too many iterators with infinite epoch(num_epoch=-1)
return self.res_queue.get(timeout=30) return self.res_queue.get(timeout=30)


def queue_empty(self):
if not self.idx_queue.empty():
logger.error("idx_queue is not empty")
return False
if not self.res_queue.empty():
logger.error("res_queue is not empty")
return False
return True

def __del__(self): def __del__(self):
# Try to destruct here, sometimes the class itself will be destructed in advance, # Try to destruct here, sometimes the class itself will be destructed in advance,
# so "self" will be a NoneType # so "self" will be a NoneType
@@ -3657,16 +3646,14 @@ class GeneratorDataset(MappableDataset):
sampler_instance.set_num_rows(len(self.source)) sampler_instance.set_num_rows(len(self.source))
sampler_instance.initialize() sampler_instance.initialize()
if new_op.num_parallel_workers > 1: if new_op.num_parallel_workers > 1:
new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, self.source,
new_op.num_parallel_workers,
self.python_multiprocessing))
sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, sample_fn))
else: else:
new_op.source = (lambda: _cpp_sampler_fn(sampler_instance, self.source)) new_op.source = (lambda: _cpp_sampler_fn(sampler_instance, self.source))
else: else:
if new_op.num_parallel_workers > 1: if new_op.num_parallel_workers > 1:
new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, self.source,
new_op.num_parallel_workers,
self.python_multiprocessing))
sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, sample_fn))
else: else:
new_op.source = (lambda: _py_sampler_fn(new_op.sampler, new_op.num_samples, self.source)) new_op.source = (lambda: _py_sampler_fn(new_op.sampler, new_op.num_samples, self.source))
else: else:


Loading…
Cancel
Save