Merge pull request !711 from ZiruiWu/mindspore_master_onlytags/v0.3.0-alpha
| @@ -53,6 +53,7 @@ Status RandomSampler::InitSampler() { | |||||
| num_samples_ = (user_num_samples_ < num_samples_) ? user_num_samples_ : num_samples_; | num_samples_ = (user_num_samples_ < num_samples_) ? user_num_samples_ : num_samples_; | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive"); | CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive"); | ||||
| samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; | samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; | ||||
| rnd_.seed(seed_++); | |||||
| if (replacement_ == false) { | if (replacement_ == false) { | ||||
| shuffled_ids_.reserve(num_rows_); | shuffled_ids_.reserve(num_rows_); | ||||
| for (int64_t i = 0; i < num_rows_; i++) { | for (int64_t i = 0; i < num_rows_; i++) { | ||||
| @@ -62,7 +63,6 @@ Status RandomSampler::InitSampler() { | |||||
| } else { | } else { | ||||
| dist = std::make_unique<std::uniform_int_distribution<int64_t>>(0, num_rows_ - 1); | dist = std::make_unique<std::uniform_int_distribution<int64_t>>(0, num_rows_ - 1); | ||||
| } | } | ||||
| rnd_.seed(seed_++); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -20,7 +20,7 @@ DATA_DIR = "../data/dataset/testCelebAData/" | |||||
| def test_celeba_dataset_label(): | def test_celeba_dataset_label(): | ||||
| data = ds.CelebADataset(DATA_DIR, decode=True) | |||||
| data = ds.CelebADataset(DATA_DIR, decode=True, shuffle=False) | |||||
| expect_labels = [ | expect_labels = [ | ||||
| [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, | [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, | ||||
| 0, 0, 1], | 0, 0, 1], | ||||