Browse Source

changed sampler parameter to support std::move

tags/v0.7.0-beta
ervinzhang 5 years ago
parent
commit
c65670b1a7
3 changed files with 25 additions and 18 deletions
  1. +8
    -9
      mindspore/ccsrc/minddata/dataset/api/samplers.cc
  2. +7
    -9
      mindspore/ccsrc/minddata/dataset/include/samplers.h
  3. +10
    -0
      tests/ut/cpp/dataset/c_api_test.cc

+ 8
- 9
mindspore/ccsrc/minddata/dataset/api/samplers.cc View File

@@ -71,8 +71,8 @@ std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index, int
} }


/// Function to create a Subset Random Sampler. /// Function to create a Subset Random Sampler.
std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(const std::vector<int64_t> &indices, int64_t num_samples) {
auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices, num_samples);
std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples) {
auto sampler = std::make_shared<SubsetRandomSamplerObj>(std::move(indices), num_samples);
// Input validation // Input validation
if (!sampler->ValidateParams()) { if (!sampler->ValidateParams()) {
return nullptr; return nullptr;
@@ -81,9 +81,9 @@ std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(const std::vector<in
} }


/// Function to create a Weighted Random Sampler. /// Function to create a Weighted Random Sampler.
std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples,
std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(std::vector<double> weights, int64_t num_samples,
bool replacement) { bool replacement) {
auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights, num_samples, replacement);
auto sampler = std::make_shared<WeightedRandomSamplerObj>(std::move(weights), num_samples, replacement);
// Input validation // Input validation
if (!sampler->ValidateParams()) { if (!sampler->ValidateParams()) {
return nullptr; return nullptr;
@@ -190,8 +190,8 @@ std::shared_ptr<Sampler> SequentialSamplerObj::Build() {
} }


// SubsetRandomSampler // SubsetRandomSampler
SubsetRandomSamplerObj::SubsetRandomSamplerObj(const std::vector<int64_t> &indices, int64_t num_samples)
: indices_(indices), num_samples_(num_samples) {}
SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples)
: indices_(std::move(indices)), num_samples_(num_samples) {}


bool SubsetRandomSamplerObj::ValidateParams() { bool SubsetRandomSamplerObj::ValidateParams() {
if (num_samples_ < 0) { if (num_samples_ < 0) {
@@ -208,9 +208,8 @@ std::shared_ptr<Sampler> SubsetRandomSamplerObj::Build() {
} }


// WeightedRandomSampler // WeightedRandomSampler
WeightedRandomSamplerObj::WeightedRandomSamplerObj(const std::vector<double> &weights, int64_t num_samples,
bool replacement)
: weights_(weights), num_samples_(num_samples), replacement_(replacement) {}
WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement)
: weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {}


bool WeightedRandomSamplerObj::ValidateParams() { bool WeightedRandomSamplerObj::ValidateParams() {
if (num_samples_ < 0) { if (num_samples_ < 0) {


+ 7
- 9
mindspore/ccsrc/minddata/dataset/include/samplers.h View File

@@ -87,8 +87,7 @@ std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index = 0,
/// \param[in] indices - A vector sequence of indices. /// \param[in] indices - A vector sequence of indices.
/// \param[in] num_samples - The number of samples to draw (default to all elements). /// \param[in] num_samples - The number of samples to draw (default to all elements).
/// \return Shared pointer to the current Sampler. /// \return Shared pointer to the current Sampler.
std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(const std::vector<int64_t> &indices,
int64_t num_samples = 0);
std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples = 0);


/// Function to create a Weighted Random Sampler. /// Function to create a Weighted Random Sampler.
/// \notes Samples the elements from [0, len(weights) - 1] randomly with the given /// \notes Samples the elements from [0, len(weights) - 1] randomly with the given
@@ -97,8 +96,8 @@ std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(const std::vector<in
/// \param[in] num_samples - The number of samples to draw (default to all elements). /// \param[in] num_samples - The number of samples to draw (default to all elements).
/// \param[in] replacement - If True, put the sample ID back for the next draw. /// \param[in] replacement - If True, put the sample ID back for the next draw.
/// \return Shared pointer to the current Sampler. /// \return Shared pointer to the current Sampler.
std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(const std::vector<double> &weights,
int64_t num_samples = 0, bool replacement = true);
std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(std::vector<double> weights, int64_t num_samples = 0,
bool replacement = true);


/* ####################################### Derived Sampler classes ################################# */ /* ####################################### Derived Sampler classes ################################# */
class DistributedSamplerObj : public SamplerObj { class DistributedSamplerObj : public SamplerObj {
@@ -169,7 +168,7 @@ class SequentialSamplerObj : public SamplerObj {


class SubsetRandomSamplerObj : public SamplerObj { class SubsetRandomSamplerObj : public SamplerObj {
public: public:
SubsetRandomSamplerObj(const std::vector<int64_t> &indices, int64_t num_samples);
SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples);


~SubsetRandomSamplerObj() = default; ~SubsetRandomSamplerObj() = default;


@@ -178,14 +177,13 @@ class SubsetRandomSamplerObj : public SamplerObj {
bool ValidateParams() override; bool ValidateParams() override;


private: private:
const std::vector<int64_t> &indices_;
const std::vector<int64_t> indices_;
int64_t num_samples_; int64_t num_samples_;
}; };


class WeightedRandomSamplerObj : public SamplerObj { class WeightedRandomSamplerObj : public SamplerObj {
public: public:
explicit WeightedRandomSamplerObj(const std::vector<double> &weights, int64_t num_samples = 0,
bool replacement = true);
explicit WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples = 0, bool replacement = true);


~WeightedRandomSamplerObj() = default; ~WeightedRandomSamplerObj() = default;


@@ -194,7 +192,7 @@ class WeightedRandomSamplerObj : public SamplerObj {
bool ValidateParams() override; bool ValidateParams() override;


private: private:
const std::vector<double> &weights_;
const std::vector<double> weights_;
int64_t num_samples_; int64_t num_samples_;
bool replacement_; bool replacement_;
}; };


+ 10
- 0
tests/ut/cpp/dataset/c_api_test.cc View File

@@ -369,6 +369,16 @@ TEST_F(MindDataTestPipeline, TestImageFolderWithSamplers) {
iter->Stop(); iter->Stop();
} }


TEST_F(MindDataTestPipeline, TestSamplersMoveParameters) {
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23};
std::shared_ptr<SamplerObj> sampl1 = SubsetRandomSampler(indices);
EXPECT_FALSE(indices.empty());
EXPECT_NE(sampl1->Build(), nullptr);
std::shared_ptr<SamplerObj> sampl2 = SubsetRandomSampler(std::move(indices));
EXPECT_TRUE(indices.empty());
EXPECT_NE(sampl2->Build(), nullptr);
}

TEST_F(MindDataTestPipeline, TestPad) { TEST_F(MindDataTestPipeline, TestPad) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPad."; MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPad.";




Loading…
Cancel
Save