|
|
|
@@ -49,16 +49,6 @@ Status WeightedRandomSampler::InitSampler() { |
|
|
|
"Invalid parameter, samples_per_buffer must be greater than 0, but got " + |
|
|
|
std::to_string(samples_per_buffer_) + ".\n"); |
|
|
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(weights_.size() != 0, "Invalid parameter, weights size must not be 0.\n"); |
|
|
|
int32_t zero_elem = 0; |
|
|
|
for (auto &elem : weights_) { |
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(elem >= 0.0, "Invalid parameter, weights must not contain negative number, but got " + |
|
|
|
std::to_string(elem) + ".\n"); |
|
|
|
if (elem == 0.0) zero_elem++; |
|
|
|
} |
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(zero_elem != weights_.size(), |
|
|
|
"Invalid parameter, elements of weights must not be all zero.\n"); |
|
|
|
|
|
|
|
if (weights_.size() > static_cast<size_t>(num_rows_)) { |
|
|
|
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, |
|
|
|
"Invalid parameter, size of sample weights must be less than or equal to num of data, " |
|
|
|
|