Browse Source

made shuffle deterministic for subsequent epochs

tags/v0.3.0-alpha
Peilin Wang 5 years ago
parent
commit
0cbcc7200b
3 changed files with 8 additions and 8 deletions
  1. +6
    -6
      mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc
  2. BIN
      tests/ut/data/dataset/golden/test_2ops_shuffle_repeat.npz
  3. +2
    -2
      tests/ut/python/dataset/test_2ops.py

+ 6
- 6
mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc View File

@@ -83,14 +83,14 @@ ShuffleOp::ShuffleOp(int32_t shuffle_size, uint32_t shuffle_seed, int32_t op_con
// itself rather than waiting for the reset driven from operators above it in the pipeline. // itself rather than waiting for the reset driven from operators above it in the pipeline.
Status ShuffleOp::SelfReset() { Status ShuffleOp::SelfReset() {
MS_LOG(DEBUG) << "Shuffle operator performing a self-reset."; MS_LOG(DEBUG) << "Shuffle operator performing a self-reset.";
// If ReshuffleEachEpoch is false, then we always use the same seed for every
// If reshuffle_each_epoch is false, then we always use the same seed for every
// epoch. // epoch.
// If ReshuffleEachEpoch is true, then the first epoch uses the given seed,
// and all subsequent epochs will then reset the seed based on random device.
if (reshuffle_each_epoch_) {
shuffle_seed_ = GetNewSeed();
// If reshuffle_each_epoch is true, then the first epoch uses the given seed,
// and all subsequent epochs will then keep on using the rng_ without resetting it
if (!reshuffle_each_epoch_) {
rng_ = std::mt19937_64(shuffle_seed_);
} }
rng_ = std::mt19937_64(shuffle_seed_);
shuffle_buffer_ = std::make_unique<TensorTable>(); shuffle_buffer_ = std::make_unique<TensorTable>();
buffer_counter_ = 0; buffer_counter_ = 0;
shuffle_last_row_idx_ = 0; shuffle_last_row_idx_ = 0;


BIN
tests/ut/data/dataset/golden/test_2ops_shuffle_repeat.npz View File


+ 2
- 2
tests/ut/python/dataset/test_2ops.py View File

@@ -47,7 +47,7 @@ def test_2ops_repeat_shuffle():
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)




def skip_test_2ops_shuffle_repeat():
def test_2ops_shuffle_repeat():
""" """
Test Shuffle then Repeat Test Shuffle then Repeat
""" """
@@ -159,7 +159,7 @@ def test_2ops_shuffle_batch():


if __name__ == '__main__': if __name__ == '__main__':
test_2ops_repeat_shuffle() test_2ops_repeat_shuffle()
# test_2ops_shuffle_repeat()
test_2ops_shuffle_repeat()
test_2ops_repeat_batch() test_2ops_repeat_batch()
test_2ops_batch_repeat() test_2ops_batch_repeat()
test_2ops_batch_shuffle() test_2ops_batch_shuffle()


Loading…
Cancel
Save