Browse Source

fix MR sampler with replacement case

pull/15859/head
Jamie Nisbet 4 years ago
parent
commit
e53bab18f9
2 changed files with 2 additions and 3 deletions
  1. +1
    -2
      mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc
  2. +1
    -1
      mindspore/ccsrc/minddata/mindrecord/meta/shard_task_list.cc

+ 1
- 2
mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc View File

@@ -50,7 +50,7 @@ MSRStatus ShardShuffle::CategoryShuffle(ShardTaskList &tasks) {
for (uint32_t j = 0; j < individual_size; j++) new_permutations[i][j] = static_cast<int>(j);
std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_));
}
tasks.permutation_.clear(); // Jamie replace this we setting flag to false or something
tasks.permutation_.clear();
for (uint32_t j = 0; j < individual_size; j++) {
for (uint32_t i = 0; i < tasks.categories; i++) {
tasks.permutation_.push_back(new_permutations[i][j] * static_cast<int>(tasks.categories) + static_cast<int>(i));
@@ -82,7 +82,6 @@ MSRStatus ShardShuffle::Execute(ShardTaskList &tasks) {
MS_LOG(ERROR) << "no_of_samples need to be positive.";
return FAILED;
}
new_tasks.task_list_.reserve(no_of_samples_);
for (uint32_t i = 0; i < no_of_samples_; ++i) {
new_tasks.AssignTask(tasks, tasks.GetRandomTaskID());
}


+ 1
- 1
mindspore/ccsrc/minddata/mindrecord/meta/shard_task_list.cc View File

@@ -94,7 +94,7 @@ int ShardTaskList::GetTaskSampleByID(size_t id) {

int ShardTaskList::GetRandomTaskID() {
std::mt19937 gen = mindspore::dataset::GetRandomDevice();
std::uniform_int_distribution<> dis(0, task_list_.size() - 1);
std::uniform_int_distribution<> dis(0, sample_ids_.size() - 1);
return dis(gen);
}



Loading…
Cancel
Save