Browse Source

Fixing issues in distributed sampler

Added Distributed sampler option

Fix style
tags/v0.7.0-beta
Eric 5 years ago
parent
commit
8c018da468
6 changed files with 175 additions and 23 deletions
  1. +11
    -5
      mindspore/ccsrc/minddata/dataset/api/samplers.cc
  2. +12
    -3
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
  3. +19
    -12
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h
  4. +7
    -2
      mindspore/ccsrc/minddata/dataset/include/samplers.h
  5. +3
    -1
      tests/ut/cpp/dataset/CMakeLists.txt
  6. +123
    -0
      tests/ut/cpp/dataset/distributed_sampler_test.cc

+ 11
- 5
mindspore/ccsrc/minddata/dataset/api/samplers.cc View File

@@ -31,8 +31,8 @@ SamplerObj::SamplerObj() {}

/// Function to create a Distributed Sampler.
std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle,
int64_t num_samples, uint32_t seed) {
auto sampler = std::make_shared<DistributedSamplerObj>(num_shards, shard_id, shuffle, num_samples, seed);
int64_t num_samples, uint32_t seed, bool even_dist) {
auto sampler = std::make_shared<DistributedSamplerObj>(num_shards, shard_id, shuffle, num_samples, seed, even_dist);
// Input validation
if (!sampler->ValidateParams()) {
return nullptr;
@@ -95,8 +95,13 @@ std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(const std::vecto

// DistributedSampler
DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples,
uint32_t seed)
: num_shards_(num_shards), shard_id_(shard_id), shuffle_(shuffle), num_samples_(num_samples), seed_(seed) {}
uint32_t seed, bool even_dist)
: num_shards_(num_shards),
shard_id_(shard_id),
shuffle_(shuffle),
num_samples_(num_samples),
seed_(seed),
even_dist_(even_dist) {}

bool DistributedSamplerObj::ValidateParams() {
if (num_shards_ <= 0) {
@@ -118,7 +123,8 @@ bool DistributedSamplerObj::ValidateParams() {
}

std::shared_ptr<Sampler> DistributedSamplerObj::Build() {
return std::make_shared<dataset::DistributedSampler>(num_samples_, num_shards_, shard_id_, shuffle_, seed_);
return std::make_shared<dataset::DistributedSampler>(num_samples_, num_shards_, shard_id_, shuffle_, seed_,
even_dist_);
}

// PKSampler


+ 12
- 3
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc View File

@@ -24,13 +24,14 @@
namespace mindspore {
namespace dataset {
DistributedSampler::DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
uint32_t seed)
uint32_t seed, bool even_dist)
: Sampler(num_samples, std::numeric_limits<int64_t>::max()),
cnt_(0),
seed_(seed == std::numeric_limits<uint32_t>::max() ? GetSeed() : seed),
device_id_(dev_id),
num_devices_(num_dev),
shuffle_(shuffle) {}
shuffle_(shuffle),
even_dist_(even_dist) {}

Status DistributedSampler::InitSampler() {
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
@@ -43,7 +44,15 @@ Status DistributedSampler::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(device_id_ < num_devices_ && device_id_ >= 0 && num_rows_ > 0 && num_samples_ > 0,
"fail to init DistributedSampler");
rnd_.seed(seed_++);
samples_per_buffer_ = (num_rows_ + num_devices_ - 1) / num_devices_; // equals to ceil(num_rows/num_devices)
if (even_dist_) {
samples_per_buffer_ = (num_rows_ + num_devices_ - 1) / num_devices_; // equals to ceil(num_rows/num_devices)
} else {
int64_t mod = num_rows_ % num_devices_;
samples_per_buffer_ = num_rows_ / num_devices_;
if (mod > device_id_) {
samples_per_buffer_++;
}
}
samples_per_buffer_ = num_samples_ < samples_per_buffer_ ? num_samples_ : samples_per_buffer_;
if (shuffle_ == true) {
shuffle_vec_.reserve(num_rows_);


+ 19
- 12
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h View File

@@ -27,26 +27,32 @@ namespace mindspore {
namespace dataset {
class DistributedSampler : public Sampler {
public:
// @param num_samples
// @param int64_t num_dev
// @param int64_t dev_id
// @param bool shuffle
/// \brief Constructor
/// \param[in] num_samples The total number of rows in the dataset
/// \param[in] num_dev Total number of shards for the distributed sampler
/// \param[in] dev_id Device id of the shard
/// \param[in] shuffle Option to shuffle
/// \param seed Seed parameter to shuffle, default to max unsigned int (different seed in sampler will
/// result in different samples being picked
/// \param even_dist The option to indicate whether or not each shard returns the same number of rows.
/// This option is not exposed in the python API. Current behavior is that the remainder will always
/// be handled by the first n shards, n being the corresponding device id.
DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
uint32_t seed = std::numeric_limits<uint32_t>::max());
uint32_t seed = std::numeric_limits<uint32_t>::max(), bool even_dist = true);

// default destructor
/// \brief default destructor
~DistributedSampler() = default;

// @param std::unique_ptr<DataBuffer> * pBuffer
// @param int32_t workerId
// @return - The error code return
/// \param std::unique_ptr<DataBuffer> * pBuffer
/// \param int32_t workerId
/// \return Status code
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;

// Init sampler, called by base class or python
/// Init sampler, called by base class or python
Status InitSampler() override;

// for next epoch of sampleIds
// @return - The error code return
/// \brief for next epoch of sampleIds
/// \return Status code
Status ResetSampler() override;

void Print(std::ostream &out, bool show_all) const override;
@@ -59,6 +65,7 @@ class DistributedSampler : public Sampler {
bool shuffle_;
std::mt19937 rnd_;
std::vector<int64_t> shuffle_vec_;
bool even_dist_;
};
} // namespace dataset
} // namespace mindspore


+ 7
- 2
mindspore/ccsrc/minddata/dataset/include/samplers.h View File

@@ -52,9 +52,12 @@ class WeightedRandomSamplerObj;
/// \param[in] shuffle - If true, the indices are shuffled.
/// \param[in] num_samples - The number of samples to draw (default to all elements).
/// \param[in] seed - The seed in use when shuffle is true.
/// \param[in] even_dist - If true, each shard would return the same number of rows (default to true).
/// If false the total rows returned by all the shards would not have overlap.
/// \return Shared pointer to the current Sampler.
std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle = true,
int64_t num_samples = 0, uint32_t seed = 1);
int64_t num_samples = 0, uint32_t seed = 1,
bool even_dist = true);

/// Function to create a PK Sampler.
/// \notes Samples K elements for each P class in the dataset.
@@ -100,7 +103,8 @@ std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(const std::vecto
/* ####################################### Derived Sampler classes ################################# */
class DistributedSamplerObj : public SamplerObj {
public:
DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, uint32_t seed);
DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, uint32_t seed,
bool even_dist);

~DistributedSamplerObj() = default;

@@ -114,6 +118,7 @@ class DistributedSamplerObj : public SamplerObj {
bool shuffle_;
int64_t num_samples_;
uint32_t seed_;
bool even_dist_;
};

class PKSamplerObj : public SamplerObj {


+ 3
- 1
tests/ut/cpp/dataset/CMakeLists.txt View File

@@ -92,7 +92,9 @@ SET(DE_UT_SRCS
tensor_op_fusion_pass_test.cc
sliding_window_op_test.cc
epoch_ctrl_op_test.cc
swap_red_blue_test.cc
sentence_piece_vocab_op_test.cc
swap_red_blue_test.cc
distributed_sampler_test.cc
)

if (ENABLE_PYTHON)


+ 123
- 0
tests/ut/cpp/dataset/distributed_sampler_test.cc View File

@@ -0,0 +1,123 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "gtest/gtest.h"

#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
#include "utils/log_adapter.h"

#include <vector>
#include <unordered_set>

using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;

class MindDataTestDistributedSampler : public UT::Common {
public:
class DummyRandomAccessOp : public RandomAccessOp {
public:
DummyRandomAccessOp(uint64_t num_rows) {
// row count is in base class as protected member
// GetNumRowsInDataset does not need an override, the default from base class is fine.
num_rows_ = num_rows;
}
};
};

TEST_F(MindDataTestDistributedSampler, TestTwoShardsOne) {
// num samples to draw.
uint64_t num_samples = 7;

// create sampler with replacement = true
DistributedSampler m_sampler(num_samples, 2, 0, false, 0, false);
DummyRandomAccessOp dummyRandomAccessOp(num_samples);
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

std::unique_ptr<DataBuffer> db;
TensorRow row;
std::vector<uint64_t> out;
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
db->PopRow(&row);
for (const auto &t : row) {
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
out.push_back(*it);
}
}

ASSERT_EQ(4, out.size());

ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
ASSERT_EQ(db->eoe(), true);
}

TEST_F(MindDataTestDistributedSampler, TestTwoShardsTwo) {
// num samples to draw.
uint64_t num_samples = 7;

// create sampler with replacement = true
DistributedSampler m_sampler(num_samples, 2, 1, false, 0, false);
DummyRandomAccessOp dummyRandomAccessOp(num_samples);
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

std::unique_ptr<DataBuffer> db;
TensorRow row;
std::vector<uint64_t> out;
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
db->PopRow(&row);
for (const auto &t : row) {
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
out.push_back(*it);
}
}

ASSERT_EQ(3, out.size());

ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
ASSERT_EQ(db->eoe(), true);
}

TEST_F(MindDataTestDistributedSampler, TestThreeShards) {
// num samples to draw.
uint64_t num_samples = 2;

// create sampler with replacement = true
DistributedSampler m_sampler(num_samples, 3, 2, false, 0, false);
DummyRandomAccessOp dummyRandomAccessOp(num_samples);
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

std::unique_ptr<DataBuffer> db;
TensorRow row;
std::vector<uint64_t> out;
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
db->PopRow(&row);
for (const auto &t : row) {
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
out.push_back(*it);
}
}

ASSERT_EQ(0, out.size());

ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
ASSERT_EQ(db->eoe(), true);
}


Loading…
Cancel
Save