diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc index 92d1addd00..b7c149a419 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc @@ -50,7 +50,7 @@ MSRStatus ShardShuffle::CategoryShuffle(ShardTaskList &tasks) { for (uint32_t j = 0; j < individual_size; j++) new_permutations[i][j] = static_cast(j); std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_)); } - tasks.permutation_.clear(); // Jamie replace this we setting flag to false or something + tasks.permutation_.clear(); for (uint32_t j = 0; j < individual_size; j++) { for (uint32_t i = 0; i < tasks.categories; i++) { tasks.permutation_.push_back(new_permutations[i][j] * static_cast(tasks.categories) + static_cast(i)); @@ -82,7 +82,6 @@ MSRStatus ShardShuffle::Execute(ShardTaskList &tasks) { MS_LOG(ERROR) << "no_of_samples need to be positive."; return FAILED; } - new_tasks.task_list_.reserve(no_of_samples_); for (uint32_t i = 0; i < no_of_samples_; ++i) { new_tasks.AssignTask(tasks, tasks.GetRandomTaskID()); } diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_task_list.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_task_list.cc index 0210d2057d..c55d783aef 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_task_list.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_task_list.cc @@ -94,7 +94,7 @@ int ShardTaskList::GetTaskSampleByID(size_t id) { int ShardTaskList::GetRandomTaskID() { std::mt19937 gen = mindspore::dataset::GetRandomDevice(); - std::uniform_int_distribution<> dis(0, task_list_.size() - 1); + std::uniform_int_distribution<> dis(0, sample_ids_.size() - 1); return dis(gen); }