| @@ -71,8 +71,8 @@ std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index, int | |||
| } | |||
| /// Function to create a Subset Random Sampler. | |||
| std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(const std::vector<int64_t> &indices, int64_t num_samples) { | |||
| auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices, num_samples); | |||
| std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples) { | |||
| auto sampler = std::make_shared<SubsetRandomSamplerObj>(std::move(indices), num_samples); | |||
| // Input validation | |||
| if (!sampler->ValidateParams()) { | |||
| return nullptr; | |||
| @@ -81,9 +81,9 @@ std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(const std::vector<in | |||
| } | |||
| /// Function to create a Weighted Random Sampler. | |||
| std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples, | |||
| std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(std::vector<double> weights, int64_t num_samples, | |||
| bool replacement) { | |||
| auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights, num_samples, replacement); | |||
| auto sampler = std::make_shared<WeightedRandomSamplerObj>(std::move(weights), num_samples, replacement); | |||
| // Input validation | |||
| if (!sampler->ValidateParams()) { | |||
| return nullptr; | |||
| @@ -190,8 +190,8 @@ std::shared_ptr<Sampler> SequentialSamplerObj::Build() { | |||
| } | |||
| // SubsetRandomSampler | |||
| SubsetRandomSamplerObj::SubsetRandomSamplerObj(const std::vector<int64_t> &indices, int64_t num_samples) | |||
| : indices_(indices), num_samples_(num_samples) {} | |||
| SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples) | |||
| : indices_(std::move(indices)), num_samples_(num_samples) {} | |||
| bool SubsetRandomSamplerObj::ValidateParams() { | |||
| if (num_samples_ < 0) { | |||
| @@ -208,9 +208,8 @@ std::shared_ptr<Sampler> SubsetRandomSamplerObj::Build() { | |||
| } | |||
| // WeightedRandomSampler | |||
| WeightedRandomSamplerObj::WeightedRandomSamplerObj(const std::vector<double> &weights, int64_t num_samples, | |||
| bool replacement) | |||
| : weights_(weights), num_samples_(num_samples), replacement_(replacement) {} | |||
| WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement) | |||
| : weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {} | |||
| bool WeightedRandomSamplerObj::ValidateParams() { | |||
| if (num_samples_ < 0) { | |||
| @@ -87,8 +87,7 @@ std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index = 0, | |||
| /// \param[in] indices - A vector sequence of indices. | |||
| /// \param[in] num_samples - The number of samples to draw (default to all elements). | |||
| /// \return Shared pointer to the current Sampler. | |||
| std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(const std::vector<int64_t> &indices, | |||
| int64_t num_samples = 0); | |||
| std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples = 0); | |||
| /// Function to create a Weighted Random Sampler. | |||
| /// \notes Samples the elements from [0, len(weights) - 1] randomly with the given | |||
| @@ -97,8 +96,8 @@ std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(const std::vector<in | |||
| /// \param[in] num_samples - The number of samples to draw (default to all elements). | |||
| /// \param[in] replacement - If True, put the sample ID back for the next draw. | |||
| /// \return Shared pointer to the current Sampler. | |||
| std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(const std::vector<double> &weights, | |||
| int64_t num_samples = 0, bool replacement = true); | |||
| std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(std::vector<double> weights, int64_t num_samples = 0, | |||
| bool replacement = true); | |||
| /* ####################################### Derived Sampler classes ################################# */ | |||
| class DistributedSamplerObj : public SamplerObj { | |||
| @@ -169,7 +168,7 @@ class SequentialSamplerObj : public SamplerObj { | |||
| class SubsetRandomSamplerObj : public SamplerObj { | |||
| public: | |||
| SubsetRandomSamplerObj(const std::vector<int64_t> &indices, int64_t num_samples); | |||
| SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples); | |||
| ~SubsetRandomSamplerObj() = default; | |||
| @@ -178,14 +177,13 @@ class SubsetRandomSamplerObj : public SamplerObj { | |||
| bool ValidateParams() override; | |||
| private: | |||
| const std::vector<int64_t> &indices_; | |||
| const std::vector<int64_t> indices_; | |||
| int64_t num_samples_; | |||
| }; | |||
| class WeightedRandomSamplerObj : public SamplerObj { | |||
| public: | |||
| explicit WeightedRandomSamplerObj(const std::vector<double> &weights, int64_t num_samples = 0, | |||
| bool replacement = true); | |||
| explicit WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples = 0, bool replacement = true); | |||
| ~WeightedRandomSamplerObj() = default; | |||
| @@ -194,7 +192,7 @@ class WeightedRandomSamplerObj : public SamplerObj { | |||
| bool ValidateParams() override; | |||
| private: | |||
| const std::vector<double> &weights_; | |||
| const std::vector<double> weights_; | |||
| int64_t num_samples_; | |||
| bool replacement_; | |||
| }; | |||
| @@ -369,6 +369,16 @@ TEST_F(MindDataTestPipeline, TestImageFolderWithSamplers) { | |||
| iter->Stop(); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestSamplersMoveParameters) { | |||
| std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23}; | |||
| std::shared_ptr<SamplerObj> sampl1 = SubsetRandomSampler(indices); | |||
| EXPECT_FALSE(indices.empty()); | |||
| EXPECT_NE(sampl1->Build(), nullptr); | |||
| std::shared_ptr<SamplerObj> sampl2 = SubsetRandomSampler(std::move(indices)); | |||
| EXPECT_TRUE(indices.empty()); | |||
| EXPECT_NE(sampl2->Build(), nullptr); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestPad) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPad."; | |||