Browse Source

Fix timeout of GeneratorDataset multiprocessing

tags/v1.1.0
YangLuo 5 years ago
parent
commit
977c41db01
1 changed files with 32 additions and 8 deletions
  1. +32
    -8
      mindspore/dataset/engine/datasets.py

+ 32
- 8
mindspore/dataset/engine/datasets.py View File

@@ -3279,14 +3279,13 @@ class SamplerFn:
# Event for end of epoch
if multi_process is True:
self.eoe = multiprocessing.Event()
self.eof = multiprocessing.Event()
else:
self.eoe = threading.Event()
self.eof = threading.Event()
# Create workers
for _ in range(num_worker):
if multi_process is True:
worker = _GeneratorWorkerMp(dataset, self.eoe, self.eof)
worker = _GeneratorWorkerMp(dataset, self.eoe)
else:
worker = _GeneratorWorkerMt(dataset, self.eoe, self.eof)
worker.daemon = True
@@ -3327,15 +3326,40 @@ class SamplerFn:

def __del__(self):
self.eoe.set()
self.eof.set()
if self.multi_process is False:
self.eof.set()
for w in self.workers:
w.join()


def _generator_worker_loop(dataset, idx_queue, result_queue, eoe, eof):
def _generator_worker_loop_mp(dataset, idx_queue, result_queue, eoe):
"""
Multiprocessing or multithread generator worker process loop.
Multiprocessing generator worker process loop
"""
while True:
# Fetch index, block
try:
idx = idx_queue.get()
except KeyboardInterrupt:
raise Exception("Generator worker receives KeyboardInterrupt")
if idx is None:
# When the queue is out of scope from master process, a None item can be fetched from the queue.
# Upon receiving None, worker process should check if EOE is set.
assert eoe.is_set(), ""
return
# Fetch data, any exception from __getitem__ will terminate worker and timeout master process
result = dataset[idx]
# Send data, block
try:
result_queue.put(result)
except KeyboardInterrupt:
raise Exception("Generator worker receives KeyboardInterrupt")
del result, idx


def _generator_worker_loop_mt(dataset, idx_queue, result_queue, eoe, eof):
"""
Multithread generator worker process loop.
"""
while True:
# Fetch index, block
@@ -3383,7 +3407,7 @@ class _GeneratorWorkerMt(threading.Thread):
def __init__(self, dataset, eoe, eof):
self.idx_queue = queue.Queue(16)
self.res_queue = queue.Queue(16)
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe, eof))
super().__init__(target=_generator_worker_loop_mt, args=(dataset, self.idx_queue, self.res_queue, eoe, eof))

def put(self, item):
"""
@@ -3403,10 +3427,10 @@ class _GeneratorWorkerMp(multiprocessing.Process):
Worker process for multiprocess Generator.
"""

def __init__(self, dataset, eoe, eof):
def __init__(self, dataset, eoe):
self.idx_queue = multiprocessing.Queue(16)
self.res_queue = multiprocessing.Queue(16)
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe, eof))
super().__init__(target=_generator_worker_loop_mp, args=(dataset, self.idx_queue, self.res_queue, eoe))

def put(self, item):
"""


Loading…
Cancel
Save