|
|
@@ -87,8 +87,7 @@ std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index = 0, |
|
|
/// \param[in] indices - A vector sequence of indices. |
|
|
/// \param[in] indices - A vector sequence of indices. |
|
|
/// \param[in] num_samples - The number of samples to draw (default to all elements). |
|
|
/// \param[in] num_samples - The number of samples to draw (default to all elements). |
|
|
/// \return Shared pointer to the current Sampler. |
|
|
/// \return Shared pointer to the current Sampler. |
|
|
std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(const std::vector<int64_t> &indices, |
|
|
|
|
|
int64_t num_samples = 0); |
|
|
|
|
|
|
|
|
std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples = 0); |
|
|
|
|
|
|
|
|
/// Function to create a Weighted Random Sampler. |
|
|
/// Function to create a Weighted Random Sampler. |
|
|
/// \notes Samples the elements from [0, len(weights) - 1] randomly with the given |
|
|
/// \notes Samples the elements from [0, len(weights) - 1] randomly with the given |
|
|
@@ -97,8 +96,8 @@ std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(const std::vector<in |
|
|
/// \param[in] num_samples - The number of samples to draw (default to all elements). |
|
|
/// \param[in] num_samples - The number of samples to draw (default to all elements). |
|
|
/// \param[in] replacement - If True, put the sample ID back for the next draw. |
|
|
/// \param[in] replacement - If True, put the sample ID back for the next draw. |
|
|
/// \return Shared pointer to the current Sampler. |
|
|
/// \return Shared pointer to the current Sampler. |
|
|
std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(const std::vector<double> &weights, |
|
|
|
|
|
int64_t num_samples = 0, bool replacement = true); |
|
|
|
|
|
|
|
|
std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(std::vector<double> weights, int64_t num_samples = 0, |
|
|
|
|
|
bool replacement = true); |
|
|
|
|
|
|
|
|
/* ####################################### Derived Sampler classes ################################# */ |
|
|
/* ####################################### Derived Sampler classes ################################# */ |
|
|
class DistributedSamplerObj : public SamplerObj { |
|
|
class DistributedSamplerObj : public SamplerObj { |
|
|
@@ -169,7 +168,7 @@ class SequentialSamplerObj : public SamplerObj { |
|
|
|
|
|
|
|
|
class SubsetRandomSamplerObj : public SamplerObj { |
|
|
class SubsetRandomSamplerObj : public SamplerObj { |
|
|
public: |
|
|
public: |
|
|
SubsetRandomSamplerObj(const std::vector<int64_t> &indices, int64_t num_samples); |
|
|
|
|
|
|
|
|
SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples); |
|
|
|
|
|
|
|
|
~SubsetRandomSamplerObj() = default; |
|
|
~SubsetRandomSamplerObj() = default; |
|
|
|
|
|
|
|
|
@@ -178,14 +177,13 @@ class SubsetRandomSamplerObj : public SamplerObj { |
|
|
bool ValidateParams() override; |
|
|
bool ValidateParams() override; |
|
|
|
|
|
|
|
|
private: |
|
|
private: |
|
|
const std::vector<int64_t> &indices_; |
|
|
|
|
|
|
|
|
const std::vector<int64_t> indices_; |
|
|
int64_t num_samples_; |
|
|
int64_t num_samples_; |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
class WeightedRandomSamplerObj : public SamplerObj { |
|
|
class WeightedRandomSamplerObj : public SamplerObj { |
|
|
public: |
|
|
public: |
|
|
explicit WeightedRandomSamplerObj(const std::vector<double> &weights, int64_t num_samples = 0, |
|
|
|
|
|
bool replacement = true); |
|
|
|
|
|
|
|
|
explicit WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples = 0, bool replacement = true); |
|
|
|
|
|
|
|
|
~WeightedRandomSamplerObj() = default; |
|
|
~WeightedRandomSamplerObj() = default; |
|
|
|
|
|
|
|
|
@@ -194,7 +192,7 @@ class WeightedRandomSamplerObj : public SamplerObj { |
|
|
bool ValidateParams() override; |
|
|
bool ValidateParams() override; |
|
|
|
|
|
|
|
|
private: |
|
|
private: |
|
|
const std::vector<double> &weights_; |
|
|
|
|
|
|
|
|
const std::vector<double> weights_; |
|
|
int64_t num_samples_; |
|
|
int64_t num_samples_; |
|
|
bool replacement_; |
|
|
bool replacement_; |
|
|
}; |
|
|
}; |
|
|
|