| @@ -299,6 +299,24 @@ WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, | |||||
| : weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {} | : weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {} | ||||
| bool WeightedRandomSamplerObj::ValidateParams() { | bool WeightedRandomSamplerObj::ValidateParams() { | ||||
| if (weights_.empty()) { | |||||
| MS_LOG(ERROR) << "WeightedRandomSampler: weights vector must not be empty"; | |||||
| return false; | |||||
| } | |||||
| int32_t zero_elem = 0; | |||||
| for (int32_t i = 0; i < weights_.size(); ++i) { | |||||
| if (weights_[i] < 0) { | |||||
| MS_LOG(ERROR) << "WeightedRandomSampler: weights vector must not contain negative number, got: " << weights_[i]; | |||||
| return false; | |||||
| } | |||||
| if (weights_[i] == 0.0) { | |||||
| zero_elem++; | |||||
| } | |||||
| } | |||||
| if (zero_elem == weights_.size()) { | |||||
| MS_LOG(ERROR) << "WeightedRandomSampler: elements of weights vector must not be all zero"; | |||||
| return false; | |||||
| } | |||||
| if (num_samples_ < 0) { | if (num_samples_ < 0) { | ||||
| MS_LOG(ERROR) << "WeightedRandomSampler: invalid num_samples: " << num_samples_; | MS_LOG(ERROR) << "WeightedRandomSampler: invalid num_samples: " << num_samples_; | ||||
| return false; | return false; | ||||
| @@ -103,7 +103,7 @@ Status SubsetRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffe | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | RETURN_STATUS_UNEXPECTED(err_msg); | ||||
| } | } | ||||
| int64_t sampled_id = indices_[sample_id_]; | |||||
| int64_t sampled_id = ((indices_[sample_id_] % num_rows_) + num_rows_) % num_rows_; | |||||
| if (HasChildSampler()) { | if (HasChildSampler()) { | ||||
| RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id)); | RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id)); | ||||
| } | } | ||||
| @@ -585,6 +585,15 @@ class WeightedRandomSampler(BuiltinSampler): | |||||
| if not isinstance(weights, list): | if not isinstance(weights, list): | ||||
| weights = [weights] | weights = [weights] | ||||
| if weights == []: | |||||
| raise ValueError("weights size should not be 0") | |||||
| if list(filter(lambda x: x < 0, weights)): | |||||
| raise ValueError("weights should not contain negative numbers") | |||||
| if list(filter(lambda x: x == 0, weights)) == weights: | |||||
| raise ValueError("elements of weights should not be all zero") | |||||
| if num_samples is not None: | if num_samples is not None: | ||||
| if num_samples <= 0: | if num_samples <= 0: | ||||
| raise ValueError("num_samples should be a positive integer " | raise ValueError("num_samples should be a positive integer " | ||||
| @@ -99,3 +99,20 @@ TEST_F(MindDataTestPipeline, TestSamplersMoveParameters) { | |||||
| EXPECT_TRUE(indices.empty()); | EXPECT_TRUE(indices.empty()); | ||||
| EXPECT_NE(sampl2->Build(), nullptr); | EXPECT_NE(sampl2->Build(), nullptr); | ||||
| } | } | ||||
| TEST_F(MindDataTestPipeline, TestWeightedRandomSamplerFail) { | |||||
| // weights is empty | |||||
| std::vector<double> weights1 = {}; | |||||
| std::shared_ptr<SamplerObj> sampl1 = WeightedRandomSampler(weights1); | |||||
| EXPECT_EQ(sampl1, nullptr); | |||||
| // weights has negative number | |||||
| std::vector<double> weights2 = {0.5, 0.2, -0.4}; | |||||
| std::shared_ptr<SamplerObj> sampl2 = WeightedRandomSampler(weights2); | |||||
| EXPECT_EQ(sampl2, nullptr); | |||||
| // weights elements are all zero | |||||
| std::vector<double> weights3 = {0, 0, 0}; | |||||
| std::shared_ptr<SamplerObj> sampl3 = WeightedRandomSampler(weights3); | |||||
| EXPECT_EQ(sampl3, nullptr); | |||||
| } | |||||