From 75bf8a471485111fe0e0cce7b247ec165ae4db43 Mon Sep 17 00:00:00 2001 From: heleiwang Date: Mon, 1 Mar 2021 17:39:24 +0800 Subject: [PATCH] fix multithreading error in GeneratorDataset --- mindspore/dataset/engine/datasets.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 292de85c0f..231d7b2fe7 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -3174,7 +3174,7 @@ class SamplerFn: self.workers = [] self.num_worker = num_worker self.multi_process = multi_process - self.joined = False + self.need_join = False self.ppid = os.getpid() self.pid = [] # Event for end of epoch @@ -3192,6 +3192,7 @@ class SamplerFn: # In this phase, the main process is not locked. worker.start() self.pid.append(worker.pid) + self.need_join = True else: worker = _GeneratorWorkerMt(dataset, self.eof) worker.daemon = True @@ -3237,9 +3238,9 @@ class SamplerFn: def _stop_subprocess(self): # Only the main process can call join - if self.joined is False and self.ppid == os.getpid(): + if self.need_join is True and self.ppid == os.getpid(): self.eof.set() - self.joined = True + self.need_join = False for w in self.workers: w.join()