Browse Source

fix weighted random sampler core

tags/v0.7.0-beta
yanghaitao 5 years ago
parent
commit
6cf0c29461
1 changed files with 7 additions and 0 deletions
  1. +7
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc

+ 7
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc View File

@@ -44,6 +44,13 @@ Status WeightedRandomSampler::InitSampler() {
}
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && num_samples_, "num_samples & num_rows need to be positive");
CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0, "samples_per_buffer<=0\n");
if (weights_.size() > static_cast<size_t>(num_rows_)) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"number of samples weights is more than num of rows. Might generate id out of bound OR other errors");
}
if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) {
RETURN_STATUS_UNEXPECTED("Without replacement, sample weights less than numSamples");
}

// Initialize random generator with seed from config manager
rand_gen_.seed(GetSeed());


Loading…
Cancel
Save