Added Distributed sampler option Fix styletags/v0.7.0-beta
| @@ -31,8 +31,8 @@ SamplerObj::SamplerObj() {} | |||
| /// Function to create a Distributed Sampler. | |||
| std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle, | |||
| int64_t num_samples, uint32_t seed) { | |||
| auto sampler = std::make_shared<DistributedSamplerObj>(num_shards, shard_id, shuffle, num_samples, seed); | |||
| int64_t num_samples, uint32_t seed, bool even_dist) { | |||
| auto sampler = std::make_shared<DistributedSamplerObj>(num_shards, shard_id, shuffle, num_samples, seed, even_dist); | |||
| // Input validation | |||
| if (!sampler->ValidateParams()) { | |||
| return nullptr; | |||
| @@ -95,8 +95,13 @@ std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(const std::vecto | |||
| // DistributedSampler | |||
| DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, | |||
| uint32_t seed) | |||
| : num_shards_(num_shards), shard_id_(shard_id), shuffle_(shuffle), num_samples_(num_samples), seed_(seed) {} | |||
| uint32_t seed, bool even_dist) | |||
| : num_shards_(num_shards), | |||
| shard_id_(shard_id), | |||
| shuffle_(shuffle), | |||
| num_samples_(num_samples), | |||
| seed_(seed), | |||
| even_dist_(even_dist) {} | |||
| bool DistributedSamplerObj::ValidateParams() { | |||
| if (num_shards_ <= 0) { | |||
| @@ -118,7 +123,8 @@ bool DistributedSamplerObj::ValidateParams() { | |||
| } | |||
| std::shared_ptr<Sampler> DistributedSamplerObj::Build() { | |||
| return std::make_shared<dataset::DistributedSampler>(num_samples_, num_shards_, shard_id_, shuffle_, seed_); | |||
| return std::make_shared<dataset::DistributedSampler>(num_samples_, num_shards_, shard_id_, shuffle_, seed_, | |||
| even_dist_); | |||
| } | |||
| // PKSampler | |||
| @@ -24,13 +24,14 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| DistributedSampler::DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle, | |||
| uint32_t seed) | |||
| uint32_t seed, bool even_dist) | |||
| : Sampler(num_samples, std::numeric_limits<int64_t>::max()), | |||
| cnt_(0), | |||
| seed_(seed == std::numeric_limits<uint32_t>::max() ? GetSeed() : seed), | |||
| device_id_(dev_id), | |||
| num_devices_(num_dev), | |||
| shuffle_(shuffle) {} | |||
| shuffle_(shuffle), | |||
| even_dist_(even_dist) {} | |||
| Status DistributedSampler::InitSampler() { | |||
| // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | |||
| @@ -43,7 +44,15 @@ Status DistributedSampler::InitSampler() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(device_id_ < num_devices_ && device_id_ >= 0 && num_rows_ > 0 && num_samples_ > 0, | |||
| "fail to init DistributedSampler"); | |||
| rnd_.seed(seed_++); | |||
| samples_per_buffer_ = (num_rows_ + num_devices_ - 1) / num_devices_; // equals to ceil(num_rows/num_devices) | |||
| if (even_dist_) { | |||
| samples_per_buffer_ = (num_rows_ + num_devices_ - 1) / num_devices_; // equals to ceil(num_rows/num_devices) | |||
| } else { | |||
| int64_t mod = num_rows_ % num_devices_; | |||
| samples_per_buffer_ = num_rows_ / num_devices_; | |||
| if (mod > device_id_) { | |||
| samples_per_buffer_++; | |||
| } | |||
| } | |||
| samples_per_buffer_ = num_samples_ < samples_per_buffer_ ? num_samples_ : samples_per_buffer_; | |||
| if (shuffle_ == true) { | |||
| shuffle_vec_.reserve(num_rows_); | |||
| @@ -27,26 +27,32 @@ namespace mindspore { | |||
| namespace dataset { | |||
| class DistributedSampler : public Sampler { | |||
| public: | |||
| // @param num_samples | |||
| // @param int64_t num_dev | |||
| // @param int64_t dev_id | |||
| // @param bool shuffle | |||
| /// \brief Constructor | |||
| /// \param[in] num_samples The number of samples to draw; 0 means sample the entire dataset | |||
| /// \param[in] num_dev Total number of shards for the distributed sampler | |||
| /// \param[in] dev_id Device id of the shard | |||
| /// \param[in] shuffle Option to shuffle | |||
| /// \param[in] seed Seed used for shuffling. Defaults to max unsigned int, in which case GetSeed() supplies | |||
| /// the seed (a different seed in the sampler will result in different samples being picked) | |||
| /// \param[in] even_dist The option to indicate whether or not each shard returns the same number of rows. | |||
| /// This option is not exposed in the python API. When false, the remainder rows are handled by the | |||
| /// first n shards (one extra row each), n being num_rows % num_dev. | |||
| DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle, | |||
| uint32_t seed = std::numeric_limits<uint32_t>::max()); | |||
| uint32_t seed = std::numeric_limits<uint32_t>::max(), bool even_dist = true); | |||
| // default destructor | |||
| /// \brief default destructor | |||
| ~DistributedSampler() = default; | |||
| // @param std::unique_ptr<DataBuffer> * pBuffer | |||
| // @param int32_t workerId | |||
| // @return - The error code return | |||
| /// \brief Get the next buffer from the sampler | |||
| /// \param[out] out_buffer Receives the buffer produced by this call | |||
| /// \return Status code | |||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | |||
| // Init sampler, called by base class or python | |||
| /// Init sampler, called by base class or python | |||
| Status InitSampler() override; | |||
| // for next epoch of sampleIds | |||
| // @return - The error code return | |||
| /// \brief for next epoch of sampleIds | |||
| /// \return Status code | |||
| Status ResetSampler() override; | |||
| void Print(std::ostream &out, bool show_all) const override; | |||
| @@ -59,6 +65,7 @@ class DistributedSampler : public Sampler { | |||
| bool shuffle_; | |||
| std::mt19937 rnd_; | |||
| std::vector<int64_t> shuffle_vec_; | |||
| bool even_dist_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -52,9 +52,12 @@ class WeightedRandomSamplerObj; | |||
| /// \param[in] shuffle - If true, the indices are shuffled. | |||
| /// \param[in] num_samples - The number of samples to draw (default to all elements). | |||
| /// \param[in] seed - The seed in use when shuffle is true. | |||
| /// \param[in] even_dist - If true, each shard would return the same number of rows (default to true). | |||
| /// If false the total rows returned by all the shards would not have overlap. | |||
| /// \return Shared pointer to the current Sampler. | |||
| std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle = true, | |||
| int64_t num_samples = 0, uint32_t seed = 1); | |||
| int64_t num_samples = 0, uint32_t seed = 1, | |||
| bool even_dist = true); | |||
| /// Function to create a PK Sampler. | |||
| /// \notes Samples K elements for each P class in the dataset. | |||
| @@ -100,7 +103,8 @@ std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(const std::vecto | |||
| /* ####################################### Derived Sampler classes ################################# */ | |||
| class DistributedSamplerObj : public SamplerObj { | |||
| public: | |||
| DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, uint32_t seed); | |||
| DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, uint32_t seed, | |||
| bool even_dist); | |||
| ~DistributedSamplerObj() = default; | |||
| @@ -114,6 +118,7 @@ class DistributedSamplerObj : public SamplerObj { | |||
| bool shuffle_; | |||
| int64_t num_samples_; | |||
| uint32_t seed_; | |||
| bool even_dist_; | |||
| }; | |||
| class PKSamplerObj : public SamplerObj { | |||
| @@ -92,7 +92,9 @@ SET(DE_UT_SRCS | |||
| tensor_op_fusion_pass_test.cc | |||
| sliding_window_op_test.cc | |||
| epoch_ctrl_op_test.cc | |||
| swap_red_blue_test.cc | |||
| sentence_piece_vocab_op_test.cc | |||
| swap_red_blue_test.cc | |||
| distributed_sampler_test.cc | |||
| ) | |||
| if (ENABLE_PYTHON) | |||
| @@ -0,0 +1,123 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/common.h" | |||
| #include "gtest/gtest.h" | |||
| #include "minddata/dataset/core/constants.h" | |||
| #include "minddata/dataset/core/tensor.h" | |||
| #include "minddata/dataset/engine/data_buffer.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" | |||
| #include "utils/log_adapter.h" | |||
| #include <vector> | |||
| #include <unordered_set> | |||
| using namespace mindspore::dataset; | |||
| using mindspore::MsLogLevel::INFO; | |||
| using mindspore::ExceptionType::NoExceptionType; | |||
| using mindspore::LogStream; | |||
| /// Test fixture for the DistributedSampler unit tests. | |||
| class MindDataTestDistributedSampler : public UT::Common { | |||
| public: | |||
| /// Minimal RandomAccessOp stand-in whose only job is to carry a fixed row count for the sampler under test. | |||
| class DummyRandomAccessOp : public RandomAccessOp { | |||
| public: | |||
| /// \param[in] num_rows Row count stored in the protected base-class member num_rows_. | |||
| // explicit: a bare integer must not silently convert into a DummyRandomAccessOp. | |||
| explicit DummyRandomAccessOp(uint64_t num_rows) { | |||
| // row count is in base class as protected member | |||
| // GetNumRowsInDataset does not need an override, the default from base class is fine. | |||
| num_rows_ = num_rows; | |||
| } | |||
| }; | |||
| }; | |||
| TEST_F(MindDataTestDistributedSampler, TestTwoShardsOne) { | |||
| // 7 rows over 2 shards with even_dist = false: 7 % 2 = 1, so shard 0 takes the extra row (4 ids). | |||
| uint64_t total_rows = 7; | |||
| // Arguments: num_samples, num_dev, dev_id, shuffle, seed, even_dist. | |||
| DistributedSampler sampler(total_rows, 2, 0, false, 0, false); | |||
| DummyRandomAccessOp random_access_op(total_rows); | |||
| sampler.HandshakeRandomAccessOp(&random_access_op); | |||
| std::unique_ptr<DataBuffer> buffer; | |||
| ASSERT_EQ(sampler.GetNextSample(&buffer), Status::OK()); | |||
| // Flatten every tensor of the returned row into one list of sample ids. | |||
| TensorRow sampled_row; | |||
| buffer->PopRow(&sampled_row); | |||
| std::vector<uint64_t> sample_ids; | |||
| for (const auto &tensor : sampled_row) { | |||
| auto cursor = tensor->begin<uint64_t>(); | |||
| while (cursor != tensor->end<uint64_t>()) { | |||
| sample_ids.push_back(*cursor); | |||
| cursor++; | |||
| } | |||
| } | |||
| ASSERT_EQ(4, sample_ids.size()); | |||
| // The buffer that follows must report eoe(). | |||
| ASSERT_EQ(sampler.GetNextSample(&buffer), Status::OK()); | |||
| ASSERT_EQ(buffer->eoe(), true); | |||
| } | |||
| TEST_F(MindDataTestDistributedSampler, TestTwoShardsTwo) { | |||
| // 7 rows over 2 shards with even_dist = false: shard 1 does not receive the remainder row (3 ids). | |||
| uint64_t total_rows = 7; | |||
| // Arguments: num_samples, num_dev, dev_id, shuffle, seed, even_dist. | |||
| DistributedSampler sampler(total_rows, 2, 1, false, 0, false); | |||
| DummyRandomAccessOp random_access_op(total_rows); | |||
| sampler.HandshakeRandomAccessOp(&random_access_op); | |||
| std::unique_ptr<DataBuffer> buffer; | |||
| ASSERT_EQ(sampler.GetNextSample(&buffer), Status::OK()); | |||
| // Flatten every tensor of the returned row into one list of sample ids. | |||
| TensorRow sampled_row; | |||
| buffer->PopRow(&sampled_row); | |||
| std::vector<uint64_t> sample_ids; | |||
| for (const auto &tensor : sampled_row) { | |||
| auto cursor = tensor->begin<uint64_t>(); | |||
| while (cursor != tensor->end<uint64_t>()) { | |||
| sample_ids.push_back(*cursor); | |||
| cursor++; | |||
| } | |||
| } | |||
| ASSERT_EQ(3, sample_ids.size()); | |||
| // The buffer that follows must report eoe(). | |||
| ASSERT_EQ(sampler.GetNextSample(&buffer), Status::OK()); | |||
| ASSERT_EQ(buffer->eoe(), true); | |||
| } | |||
| TEST_F(MindDataTestDistributedSampler, TestThreeShards) { | |||
| // 2 rows over 3 shards with even_dist = false: only shards 0 and 1 get a row, so shard 2 yields no ids. | |||
| uint64_t total_rows = 2; | |||
| // Arguments: num_samples, num_dev, dev_id, shuffle, seed, even_dist. | |||
| DistributedSampler sampler(total_rows, 3, 2, false, 0, false); | |||
| DummyRandomAccessOp random_access_op(total_rows); | |||
| sampler.HandshakeRandomAccessOp(&random_access_op); | |||
| std::unique_ptr<DataBuffer> buffer; | |||
| ASSERT_EQ(sampler.GetNextSample(&buffer), Status::OK()); | |||
| // Flatten every tensor of the returned row into one list of sample ids. | |||
| TensorRow sampled_row; | |||
| buffer->PopRow(&sampled_row); | |||
| std::vector<uint64_t> sample_ids; | |||
| for (const auto &tensor : sampled_row) { | |||
| auto cursor = tensor->begin<uint64_t>(); | |||
| while (cursor != tensor->end<uint64_t>()) { | |||
| sample_ids.push_back(*cursor); | |||
| cursor++; | |||
| } | |||
| } | |||
| ASSERT_EQ(0, sample_ids.size()); | |||
| // The buffer that follows must report eoe(). | |||
| ASSERT_EQ(sampler.GetNextSample(&buffer), Status::OK()); | |||
| ASSERT_EQ(buffer->eoe(), true); | |||
| } | |||