|
|
|
@@ -3452,6 +3452,7 @@ class SamplerFn: |
|
|
|
self.num_worker = num_worker |
|
|
|
self.multi_process = multi_process |
|
|
|
self.joined = False |
|
|
|
self.ppid = os.getpid() |
|
|
|
# Event for end of epoch |
|
|
|
if multi_process is True: |
|
|
|
self.eof = multiprocessing.Event() |
|
|
|
@@ -3510,11 +3511,12 @@ class SamplerFn: |
|
|
|
yield tuple([np.array(x, copy=False) for x in result]) |
|
|
|
|
|
|
|
def _stop_subprocess(self): |
|
|
|
self.eof.set() |
|
|
|
if self.joined is False: |
|
|
|
# Only the main process can call join |
|
|
|
if self.joined is False and self.ppid == os.getpid(): |
|
|
|
self.eof.set() |
|
|
|
self.joined = True |
|
|
|
for w in self.workers: |
|
|
|
w.join() |
|
|
|
self.joined = True |
|
|
|
|
|
|
|
def __del__(self): |
|
|
|
self._stop_subprocess() |
|
|
|
|