diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 9b9c2c967d..51aa020006 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -37,6 +37,7 @@ from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp from mindspore._c_expression import typing from mindspore import log as logger +from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched import mindspore.dataset.transforms.py_transforms as py_transforms @@ -152,10 +153,9 @@ class Dataset: self._num_classes = None self._repeat_count = None self._sync = False - self.ms_role = os.getenv("MS_ROLE") def _noop_mode(self): - if self.ms_role in ("MS_PSERVER", "MS_SCHED"): + if _is_role_sched() or _is_role_pserver(): return True return False @@ -2146,7 +2146,6 @@ class MapDataset(DatasetOp): new_op.column_order = copy.deepcopy(self.column_order, memodict) new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict) new_op.parent = copy.deepcopy(self.parent, memodict) - new_op.ms_role = copy.deepcopy(self.ms_role, memodict) new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict) new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict) new_op.cache = copy.deepcopy(self.cache, memodict) @@ -3622,7 +3621,6 @@ class GeneratorDataset(MappableDataset): memodict[id(self)] = new_op new_op.children = copy.deepcopy(self.children, memodict) new_op.parent = copy.deepcopy(self.parent, memodict) - new_op.ms_role = copy.deepcopy(self.ms_role, memodict) new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict) new_op.column_types = copy.deepcopy(self.column_types, memodict) new_op.column_names = copy.deepcopy(self.column_names, memodict) diff --git a/tests/ut/python/dataset/test_noop_mode.py b/tests/ut/python/dataset/test_noop_mode.py index f02180df34..0e2eaf40fe 100644 --- a/tests/ut/python/dataset/test_noop_mode.py +++ b/tests/ut/python/dataset/test_noop_mode.py @@ -17,27 +17,32 @@ Test No-op mode support with Dummy Iterator """ import os import mindspore.dataset as ds +from mindspore import context DATA_DIR = "../data/dataset/testVOC2012" def test_noop_pserver(): os.environ['MS_ROLE'] = 'MS_PSERVER' + context.set_ps_context(enable_ps=True) data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True) num = 0 for _ in data1.create_dict_iterator(num_epochs=1): num += 1 assert num == 0 del os.environ['MS_ROLE'] + context.set_ps_context(enable_ps=False) def test_noop_sched(): os.environ['MS_ROLE'] = 'MS_SCHED' + context.set_ps_context(enable_ps=True) data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True) num = 0 for _ in data1.create_dict_iterator(num_epochs=1): num += 1 assert num == 0 del os.environ['MS_ROLE'] + context.set_ps_context(enable_ps=False) if __name__ == '__main__':