From e53bab18f90fc1772bae64e822b57798a222fb82 Mon Sep 17 00:00:00 2001 From: Jamie Nisbet Date: Wed, 28 Apr 2021 13:59:26 -0400 Subject: [PATCH] fix MR sampler with replacement case --- mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc | 3 +-- mindspore/ccsrc/minddata/mindrecord/meta/shard_task_list.cc | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc index 92d1addd00..b7c149a419 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc @@ -50,7 +50,7 @@ MSRStatus ShardShuffle::CategoryShuffle(ShardTaskList &tasks) { for (uint32_t j = 0; j < individual_size; j++) new_permutations[i][j] = static_cast<int>(j); std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_)); } - tasks.permutation_.clear(); // Jamie replace this we setting flag to false or something + tasks.permutation_.clear(); for (uint32_t j = 0; j < individual_size; j++) { for (uint32_t i = 0; i < tasks.categories; i++) { tasks.permutation_.push_back(new_permutations[i][j] * static_cast<int>(tasks.categories) + static_cast<int>(i)); @@ -82,7 +82,6 @@ MSRStatus ShardShuffle::Execute(ShardTaskList &tasks) { MS_LOG(ERROR) << "no_of_samples need to be positive."; return FAILED; } - new_tasks.task_list_.reserve(no_of_samples_); for (uint32_t i = 0; i < no_of_samples_; ++i) { new_tasks.AssignTask(tasks, tasks.GetRandomTaskID()); } diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_task_list.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_task_list.cc index 0210d2057d..c55d783aef 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_task_list.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_task_list.cc @@ -94,7 +94,7 @@ int ShardTaskList::GetTaskSampleByID(size_t id) { int ShardTaskList::GetRandomTaskID() { std::mt19937 gen = 
mindspore::dataset::GetRandomDevice(); - std::uniform_int_distribution<> dis(0, task_list_.size() - 1); + std::uniform_int_distribution<> dis(0, sample_ids_.size() - 1); return dis(gen); }