Browse Source

!671 Added testcase for sync_wait

Merge pull request !671 from EricZ/master
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 6 years ago
parent
commit
f82e63fecc
3 changed files with 34 additions and 4 deletions
  1. +2
    -2
      mindspore/ccsrc/dataset/engine/datasetops/barrier_op.cc
  2. +1
    -1
      mindspore/dataset/engine/datasets.py
  3. +31
    -1
      tests/ut/python/dataset/test_sync_wait.py

+ 2
- 2
mindspore/ccsrc/dataset/engine/datasetops/barrier_op.cc View File

@@ -65,8 +65,8 @@ Status BarrierOp::operator()() {
TaskManager::FindMe()->Post();

// create child iterator, right now this barrier is a pipeline operator
int32_t worker_id = 0;
int32_t child_idx = 0;
const int32_t worker_id = 0;
const int32_t child_idx = 0;
child_iterator_ = std::make_unique<ChildIterator>(this, worker_id, child_idx);

// Loop until eof is true


+ 1
- 1
mindspore/dataset/engine/datasets.py View File

@@ -922,7 +922,7 @@ class Dataset:
def sync_update(self, condition_name, num_batch=None, data=None):
"""
condition_name (str): The condition name that is used to toggle sending next row
step_size (int or None): The number of steps(rows) that are released
num_batch (int or None): The number of batches(rows) that are released
when num_batch is None, will release the same number of rows as specified by sync_wait
data (dict or None): The data passed to the callback
"""


+ 31
- 1
tests/ut/python/dataset/test_sync_wait.py View File

@@ -107,6 +107,7 @@ def test_two_sync():
if count % 2 == 0:
dataset.sync_update(condition_name="every 2 batches")


def test_sync_epoch():
"""
Test sync wait with epochs: test sync with epochs in dataset pipeline
@@ -130,6 +131,34 @@ def test_sync_epoch():
dataset.sync_update(condition_name="policy", data=data)


def test_multiple_iterators():
    """
    Test sync_wait with multiple iterators: build two independent but identical
    pipelines (each with its own sync_wait/"policy" condition and Augment
    callback), iterate them in lockstep, and verify both produce the same
    batches while each is released via its own sync_update call.
    """
    # Fixed: log message previously said "test_sync_epoch" (copy-paste from the
    # test above), which made log output misleading when this test ran.
    logger.info("test_multiple_iterators")
    batch_size = 30
    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    dataset = dataset.batch(batch_size, drop_remainder=True)

    # 2nd dataset: a separate pipeline with its own Augment instance so the two
    # callbacks do not share state. (Previously the `aug` name was rebound,
    # which worked only because the first pipeline held bound-method refs.)
    dataset2 = ds.GeneratorDataset(gen, column_names=["input"])

    aug2 = Augment(0)
    dataset2 = dataset2.sync_wait(condition_name="policy", callback=aug2.update)
    dataset2 = dataset2.map(input_columns=["input"], operations=[aug2.preprocess])
    dataset2 = dataset2.batch(batch_size, drop_remainder=True)

    for item1, item2 in zip(dataset.create_dict_iterator(), dataset2.create_dict_iterator()):
        # Both pipelines must stay in sync batch-by-batch.
        assert item1["input"][0] == item2["input"][0]
        data1 = {"loss": item1["input"][0]}
        data2 = {"loss": item2["input"][0]}
        # Release the next batch on each pipeline independently.
        dataset.sync_update(condition_name="policy", data=data1)
        dataset2.sync_update(condition_name="policy", data=data2)


def test_sync_exception_01():
"""
Test sync: with shuffle in sync mode
@@ -179,4 +208,5 @@ if __name__ == "__main__":
test_two_sync()
test_sync_exception_01()
test_sync_exception_02()
test_sync_epoch()
test_sync_epoch()
test_multiple_iterators()

Loading…
Cancel
Save