Browse Source

!6943 [MD] minddata use new ps api when ps mode

Merge pull request !6943 from xiefangqi/md_replace_ps_api
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
42e1fc8e24
2 changed files with 7 additions and 4 deletions
  1. +2
    -4
      mindspore/dataset/engine/datasets.py
  2. +5
    -0
      tests/ut/python/dataset/test_noop_mode.py

+ 2
- 4
mindspore/dataset/engine/datasets.py View File

@@ -37,6 +37,7 @@ from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp
from mindspore._c_expression import typing from mindspore._c_expression import typing


from mindspore import log as logger from mindspore import log as logger
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched


import mindspore.dataset.transforms.py_transforms as py_transforms import mindspore.dataset.transforms.py_transforms as py_transforms


@@ -152,10 +153,9 @@ class Dataset:
self._num_classes = None self._num_classes = None
self._repeat_count = None self._repeat_count = None
self._sync = False self._sync = False
self.ms_role = os.getenv("MS_ROLE")


def _noop_mode(self): def _noop_mode(self):
if self.ms_role in ("MS_PSERVER", "MS_SCHED"):
if _is_role_sched() or _is_role_pserver():
return True return True
return False return False


@@ -2146,7 +2146,6 @@ class MapDataset(DatasetOp):
new_op.column_order = copy.deepcopy(self.column_order, memodict) new_op.column_order = copy.deepcopy(self.column_order, memodict)
new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict) new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
new_op.parent = copy.deepcopy(self.parent, memodict) new_op.parent = copy.deepcopy(self.parent, memodict)
new_op.ms_role = copy.deepcopy(self.ms_role, memodict)
new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict) new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict)
new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict) new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict)
new_op.cache = copy.deepcopy(self.cache, memodict) new_op.cache = copy.deepcopy(self.cache, memodict)
@@ -3622,7 +3621,6 @@ class GeneratorDataset(MappableDataset):
memodict[id(self)] = new_op memodict[id(self)] = new_op
new_op.children = copy.deepcopy(self.children, memodict) new_op.children = copy.deepcopy(self.children, memodict)
new_op.parent = copy.deepcopy(self.parent, memodict) new_op.parent = copy.deepcopy(self.parent, memodict)
new_op.ms_role = copy.deepcopy(self.ms_role, memodict)
new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict) new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
new_op.column_types = copy.deepcopy(self.column_types, memodict) new_op.column_types = copy.deepcopy(self.column_types, memodict)
new_op.column_names = copy.deepcopy(self.column_names, memodict) new_op.column_names = copy.deepcopy(self.column_names, memodict)


+ 5
- 0
tests/ut/python/dataset/test_noop_mode.py View File

@@ -17,27 +17,32 @@ Test No-op mode support with Dummy Iterator
""" """
import os import os
import mindspore.dataset as ds import mindspore.dataset as ds
from mindspore import context


DATA_DIR = "../data/dataset/testVOC2012" DATA_DIR = "../data/dataset/testVOC2012"


def test_noop_pserver(): def test_noop_pserver():
os.environ['MS_ROLE'] = 'MS_PSERVER' os.environ['MS_ROLE'] = 'MS_PSERVER'
context.set_ps_context(enable_ps=True)
data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True) data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True)
num = 0 num = 0
for _ in data1.create_dict_iterator(num_epochs=1): for _ in data1.create_dict_iterator(num_epochs=1):
num += 1 num += 1
assert num == 0 assert num == 0
del os.environ['MS_ROLE'] del os.environ['MS_ROLE']
context.set_ps_context(enable_ps=False)




def test_noop_sched(): def test_noop_sched():
os.environ['MS_ROLE'] = 'MS_SCHED' os.environ['MS_ROLE'] = 'MS_SCHED'
context.set_ps_context(enable_ps=True)
data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True) data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True)
num = 0 num = 0
for _ in data1.create_dict_iterator(num_epochs=1): for _ in data1.create_dict_iterator(num_epochs=1):
num += 1 num += 1
assert num == 0 assert num == 0
del os.environ['MS_ROLE'] del os.environ['MS_ROLE']
context.set_ps_context(enable_ps=False)




if __name__ == '__main__': if __name__ == '__main__':


Loading…
Cancel
Save