diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc index 13863143c0..47c2c8b0d2 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc @@ -44,6 +44,13 @@ Status WeightedRandomSampler::InitSampler() { } CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && num_samples_, "num_samples & num_rows need to be positive"); CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0, "samples_per_buffer<=0\n"); + if (weights_.size() > static_cast(num_rows_)) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "number of samples weights is more than num of rows. Might generate id out of bound OR other errors"); + } + if (!replacement_ && (weights_.size() < static_cast(num_samples_))) { + RETURN_STATUS_UNEXPECTED("Without replacement, sample weights less than numSamples"); + } // Initialize random generator with seed from config manager rand_gen_.seed(GetSeed());