|
|
@@ -50,11 +50,15 @@ Status WeightedRandomSampler::InitSampler() { |
|
|
std::to_string(samples_per_buffer_) + ".\n"); |
|
|
std::to_string(samples_per_buffer_) + ".\n"); |
|
|
if (weights_.size() > static_cast<size_t>(num_rows_)) { |
|
|
if (weights_.size() > static_cast<size_t>(num_rows_)) { |
|
|
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, |
|
|
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, |
|
|
"Invalid parameter, number of samples weights is more than num of rows. " |
|
|
|
|
|
"Might generate id out of bound OR other errors"); |
|
|
|
|
|
|
|
|
"Invalid parameter, size of sample weights must be less than or equal to num of data, " |
|
|
|
|
|
"otherwise might cause generated id out of bound or other errors, but got weight size: " + |
|
|
|
|
|
std::to_string(weights_.size()) + ", num of data: " + std::to_string(num_rows_)); |
|
|
} |
|
|
} |
|
|
if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) { |
|
|
if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) { |
|
|
RETURN_STATUS_UNEXPECTED("Invalid parameter, without replacement, weights size must be greater than num_samples."); |
|
|
|
|
|
|
|
|
RETURN_STATUS_UNEXPECTED( |
|
|
|
|
|
"Invalid parameter, without replacement, weights size must be greater than or equal to num_samples, " |
|
|
|
|
|
"but got weight size: " + |
|
|
|
|
|
std::to_string(weights_.size()) + ", num_samples: " + std::to_string(num_samples_)); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
// Initialize random generator with seed from config manager |
|
|
// Initialize random generator with seed from config manager |
|
|
@@ -110,11 +114,16 @@ Status WeightedRandomSampler::ResetSampler() { |
|
|
Status WeightedRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { |
|
|
Status WeightedRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { |
|
|
if (weights_.size() > static_cast<size_t>(num_rows_)) { |
|
|
if (weights_.size() > static_cast<size_t>(num_rows_)) { |
|
|
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, |
|
|
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, |
|
|
"number of samples weights is more than num of rows. Might generate id out of bound OR other errors"); |
|
|
|
|
|
|
|
|
"Invalid parameter, size of sample weights must be less than or equal to num of data, " |
|
|
|
|
|
"otherwise might cause generated id out of bound or other errors, but got weight size: " + |
|
|
|
|
|
std::to_string(weights_.size()) + ", num of data: " + std::to_string(num_rows_)); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) { |
|
|
if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) { |
|
|
RETURN_STATUS_UNEXPECTED("Without replacement, sample weights less than numSamples"); |
|
|
|
|
|
|
|
|
RETURN_STATUS_UNEXPECTED( |
|
|
|
|
|
"Invalid parameter, without replacement, weights size must be greater than or equal to num_samples, " |
|
|
|
|
|
"but got weight size: " + |
|
|
|
|
|
std::to_string(weights_.size()) + ", num_samples: " + std::to_string(num_samples_)); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if (sample_id_ == num_samples_) { |
|
|
if (sample_id_ == num_samples_) { |
|
|
@@ -150,7 +159,8 @@ Status WeightedRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buf |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if (genId >= num_rows_) { |
|
|
if (genId >= num_rows_) { |
|
|
RETURN_STATUS_UNEXPECTED("generated id is bigger than numRows (out of bound)."); |
|
|
|
|
|
|
|
|
RETURN_STATUS_UNEXPECTED("Generated indice is out of bound, expect range [0, num_data-1], got indice: " + |
|
|
|
|
|
std::to_string(genId) + ", num_data: " + std::to_string(num_rows_ - 1)); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if (HasChildSampler()) { |
|
|
if (HasChildSampler()) { |
|
|
|