Browse Source

!7730 [MD] Fix bugs in WeightedRandomSampler & SubsetRandomSampler

Merge pull request !7730 from luoyang/c-api-pyfunc
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
7982f05038
4 changed files with 45 additions and 1 deletions
  1. +18
    -0
      mindspore/ccsrc/minddata/dataset/api/samplers.cc
  2. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc
  3. +9
    -0
      mindspore/dataset/engine/samplers.py
  4. +17
    -0
      tests/ut/cpp/dataset/c_api_samplers_test.cc

+ 18
- 0
mindspore/ccsrc/minddata/dataset/api/samplers.cc View File

@@ -299,6 +299,24 @@ WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights,
: weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {}

bool WeightedRandomSamplerObj::ValidateParams() {
if (weights_.empty()) {
MS_LOG(ERROR) << "WeightedRandomSampler: weights vector must not be empty";
return false;
}
int32_t zero_elem = 0;
for (int32_t i = 0; i < weights_.size(); ++i) {
if (weights_[i] < 0) {
MS_LOG(ERROR) << "WeightedRandomSampler: weights vector must not contain negative number, got: " << weights_[i];
return false;
}
if (weights_[i] == 0.0) {
zero_elem++;
}
}
if (zero_elem == weights_.size()) {
MS_LOG(ERROR) << "WeightedRandomSampler: elements of weights vector must not be all zero";
return false;
}
if (num_samples_ < 0) {
MS_LOG(ERROR) << "WeightedRandomSampler: invalid num_samples: " << num_samples_;
return false;


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc View File

@@ -103,7 +103,7 @@ Status SubsetRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffe
RETURN_STATUS_UNEXPECTED(err_msg);
}

int64_t sampled_id = indices_[sample_id_];
int64_t sampled_id = ((indices_[sample_id_] % num_rows_) + num_rows_) % num_rows_;
if (HasChildSampler()) {
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
}


+ 9
- 0
mindspore/dataset/engine/samplers.py View File

@@ -585,6 +585,15 @@ class WeightedRandomSampler(BuiltinSampler):
if not isinstance(weights, list):
weights = [weights]

if weights == []:
raise ValueError("weights size should not be 0")

if list(filter(lambda x: x < 0, weights)):
raise ValueError("weights should not contain negative numbers")

if list(filter(lambda x: x == 0, weights)) == weights:
raise ValueError("elements of weights should not be all zero")

if num_samples is not None:
if num_samples <= 0:
raise ValueError("num_samples should be a positive integer "


+ 17
- 0
tests/ut/cpp/dataset/c_api_samplers_test.cc View File

@@ -99,3 +99,20 @@ TEST_F(MindDataTestPipeline, TestSamplersMoveParameters) {
EXPECT_TRUE(indices.empty());
EXPECT_NE(sampl2->Build(), nullptr);
}

// Negative test: each invalid weights vector below is expected to make
// WeightedRandomSampler() hand back a null sampler object.
TEST_F(MindDataTestPipeline, TestWeightedRandomSamplerFail) {
  // Case 1: no weights at all.
  std::vector<double> empty_weights;
  std::shared_ptr<SamplerObj> empty_sampler = WeightedRandomSampler(empty_weights);
  EXPECT_EQ(empty_sampler, nullptr);

  // Case 2: one of the weights is negative.
  std::vector<double> negative_weights{0.5, 0.2, -0.4};
  std::shared_ptr<SamplerObj> negative_sampler = WeightedRandomSampler(negative_weights);
  EXPECT_EQ(negative_sampler, nullptr);

  // Case 3: every weight is zero.
  std::vector<double> zero_weights{0.0, 0.0, 0.0};
  std::shared_ptr<SamplerObj> zero_sampler = WeightedRandomSampler(zero_weights);
  EXPECT_EQ(zero_sampler, nullptr);
}

Loading…
Cancel
Save