diff --git a/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc index d0ec89671f..69f7e09986 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc @@ -83,14 +83,14 @@ ShuffleOp::ShuffleOp(int32_t shuffle_size, uint32_t shuffle_seed, int32_t op_con // itself rather than waiting for the reset driven from operators above it in the pipeline. Status ShuffleOp::SelfReset() { MS_LOG(DEBUG) << "Shuffle operator performing a self-reset."; - // If ReshuffleEachEpoch is false, then we always use the same seed for every + // If reshuffle_each_epoch is false, then we always use the same seed for every // epoch. - // If ReshuffleEachEpoch is true, then the first epoch uses the given seed, - // and all subsequent epochs will then reset the seed based on random device. - if (reshuffle_each_epoch_) { - shuffle_seed_ = GetNewSeed(); + // If reshuffle_each_epoch is true, then the first epoch uses the given seed, + // and all subsequent epochs will then keep on using the rng_ without resetting it + if (!reshuffle_each_epoch_) { + rng_ = std::mt19937_64(shuffle_seed_); } - rng_ = std::mt19937_64(shuffle_seed_); + shuffle_buffer_ = std::make_unique(); buffer_counter_ = 0; shuffle_last_row_idx_ = 0; diff --git a/tests/ut/data/dataset/golden/test_2ops_shuffle_repeat.npz b/tests/ut/data/dataset/golden/test_2ops_shuffle_repeat.npz index 44af6db750..5c39c0e64b 100644 Binary files a/tests/ut/data/dataset/golden/test_2ops_shuffle_repeat.npz and b/tests/ut/data/dataset/golden/test_2ops_shuffle_repeat.npz differ diff --git a/tests/ut/python/dataset/test_2ops.py b/tests/ut/python/dataset/test_2ops.py index cdf59ecf22..ef60a42e27 100644 --- a/tests/ut/python/dataset/test_2ops.py +++ b/tests/ut/python/dataset/test_2ops.py @@ -47,7 +47,7 @@ def test_2ops_repeat_shuffle(): save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) -def skip_test_2ops_shuffle_repeat(): +def test_2ops_shuffle_repeat(): """ Test Shuffle then Repeat """ @@ -159,7 +159,7 @@ def test_2ops_shuffle_batch(): if __name__ == '__main__': test_2ops_repeat_shuffle() - # test_2ops_shuffle_repeat() + test_2ops_shuffle_repeat() test_2ops_repeat_batch() test_2ops_batch_repeat() test_2ops_batch_shuffle()