Browse Source

Changed SamplerObj validate params to return status and added AddChild to it

tags/v1.2.0-rc1
Mahdi 5 years ago
parent
commit
0f2b5d8cac
12 changed files with 211 additions and 73 deletions
  1. +86
    -55
      mindspore/ccsrc/minddata/dataset/api/samplers.cc
  2. +4
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
  3. +4
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.cc
  4. +5
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.cc
  5. +4
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.cc
  6. +5
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc
  7. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h
  8. +5
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc
  9. +4
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc
  10. +4
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc
  11. +58
    -17
      mindspore/ccsrc/minddata/dataset/include/samplers.h
  12. +31
    -0
      tests/ut/cpp/dataset/c_api_samplers_test.cc

+ 86
- 55
mindspore/ccsrc/minddata/dataset/api/samplers.cc View File

@@ -48,6 +48,34 @@ namespace dataset {
// Constructor // Constructor
SamplerObj::SamplerObj() {} SamplerObj::SamplerObj() {}


void SamplerObj::BuildChildren(std::shared_ptr<SamplerRT> sampler) {
for (auto child : children_) {
auto sampler_rt = child->Build();
sampler->AddChild(sampler_rt);
}
}

/// \brief Adds a child sampler to this sampler.
/// \param[in] child The sampler to be added as child; a null pointer is accepted and ignored.
/// \return Status::OK() if the child was added or was null; an error status if this sampler already
///     has a child.
Status SamplerObj::AddChild(std::shared_ptr<SamplerObj> child) {
  // A null child means "nothing to add" and is not treated as an error.
  if (child == nullptr) {
    return Status::OK();
  }

  // The parameter type already guarantees that a non-null child is a SamplerObj, so no run-time
  // type check is needed here.

  // Samplers can have at most 1 child.
  if (!children_.empty()) {
    RETURN_STATUS_UNEXPECTED("Cannot add child sampler, this sampler already has a child.");
  }

  // `child` was received by value, so it can be moved into the container instead of copied.
  children_.push_back(std::move(child));

  return Status::OK();
}

/// Function to create a Distributed Sampler. /// Function to create a Distributed Sampler.
std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle, std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle,
int64_t num_samples, uint32_t seed, int64_t offset, int64_t num_samples, uint32_t seed, int64_t offset,
@@ -55,7 +83,7 @@ std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, in
auto sampler = auto sampler =
std::make_shared<DistributedSamplerObj>(num_shards, shard_id, shuffle, num_samples, seed, offset, even_dist); std::make_shared<DistributedSamplerObj>(num_shards, shard_id, shuffle, num_samples, seed, offset, even_dist);
// Input validation // Input validation
if (!sampler->ValidateParams()) {
if (sampler->ValidateParams().IsError()) {
return nullptr; return nullptr;
} }
return sampler; return sampler;
@@ -65,7 +93,7 @@ std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, in
std::shared_ptr<PKSamplerObj> PKSampler(int64_t num_val, bool shuffle, int64_t num_samples) { std::shared_ptr<PKSamplerObj> PKSampler(int64_t num_val, bool shuffle, int64_t num_samples) {
auto sampler = std::make_shared<PKSamplerObj>(num_val, shuffle, num_samples); auto sampler = std::make_shared<PKSamplerObj>(num_val, shuffle, num_samples);
// Input validation // Input validation
if (!sampler->ValidateParams()) {
if (sampler->ValidateParams().IsError()) {
return nullptr; return nullptr;
} }
return sampler; return sampler;
@@ -75,7 +103,7 @@ std::shared_ptr<PKSamplerObj> PKSampler(int64_t num_val, bool shuffle, int64_t n
std::shared_ptr<RandomSamplerObj> RandomSampler(bool replacement, int64_t num_samples) { std::shared_ptr<RandomSamplerObj> RandomSampler(bool replacement, int64_t num_samples) {
auto sampler = std::make_shared<RandomSamplerObj>(replacement, num_samples); auto sampler = std::make_shared<RandomSamplerObj>(replacement, num_samples);
// Input validation // Input validation
if (!sampler->ValidateParams()) {
if (sampler->ValidateParams().IsError()) {
return nullptr; return nullptr;
} }
return sampler; return sampler;
@@ -85,7 +113,7 @@ std::shared_ptr<RandomSamplerObj> RandomSampler(bool replacement, int64_t num_sa
std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index, int64_t num_samples) { std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index, int64_t num_samples) {
auto sampler = std::make_shared<SequentialSamplerObj>(start_index, num_samples); auto sampler = std::make_shared<SequentialSamplerObj>(start_index, num_samples);
// Input validation // Input validation
if (!sampler->ValidateParams()) {
if (sampler->ValidateParams().IsError()) {
return nullptr; return nullptr;
} }
return sampler; return sampler;
@@ -95,7 +123,7 @@ std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index, int
std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples) { std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples) {
auto sampler = std::make_shared<SubsetRandomSamplerObj>(std::move(indices), num_samples); auto sampler = std::make_shared<SubsetRandomSamplerObj>(std::move(indices), num_samples);
// Input validation // Input validation
if (!sampler->ValidateParams()) {
if (sampler->ValidateParams().IsError()) {
return nullptr; return nullptr;
} }
return sampler; return sampler;
@@ -106,7 +134,7 @@ std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(std::vector<doub
bool replacement) { bool replacement) {
auto sampler = std::make_shared<WeightedRandomSamplerObj>(std::move(weights), num_samples, replacement); auto sampler = std::make_shared<WeightedRandomSamplerObj>(std::move(weights), num_samples, replacement);
// Input validation // Input validation
if (!sampler->ValidateParams()) {
if (sampler->ValidateParams().IsError()) {
return nullptr; return nullptr;
} }
return sampler; return sampler;
@@ -125,35 +153,33 @@ DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_i
offset_(offset), offset_(offset),
even_dist_(even_dist) {} even_dist_(even_dist) {}


bool DistributedSamplerObj::ValidateParams() {
Status DistributedSamplerObj::ValidateParams() {
if (num_shards_ <= 0) { if (num_shards_ <= 0) {
MS_LOG(ERROR) << "DistributedSampler: invalid num_shards: " << num_shards_;
return false;
RETURN_STATUS_UNEXPECTED("DistributedSampler: invalid num_shards: " + std::to_string(num_shards_));
} }


if (shard_id_ < 0 || shard_id_ >= num_shards_) { if (shard_id_ < 0 || shard_id_ >= num_shards_) {
MS_LOG(ERROR) << "DistributedSampler: invalid input, shard_id: " << shard_id_ << ", num_shards: " << num_shards_;
return false;
RETURN_STATUS_UNEXPECTED("DistributedSampler: invalid input, shard_id: " + std::to_string(shard_id_) +
", num_shards: " + std::to_string(num_shards_));
} }


if (num_samples_ < 0) { if (num_samples_ < 0) {
MS_LOG(ERROR) << "DistributedSampler: invalid num_samples: " << num_samples_;
return false;
RETURN_STATUS_UNEXPECTED("DistributedSampler: invalid num_samples: " + std::to_string(num_samples_));
} }


if (offset_ > num_shards_) { if (offset_ > num_shards_) {
MS_LOG(ERROR) << "DistributedSampler: invalid offset: " << offset_
<< ", which should be no more than num_shards: " << num_shards_;
return false;
RETURN_STATUS_UNEXPECTED("DistributedSampler: invalid offset: " + std::to_string(offset_) +
", which should be no more than num_shards: " + std::to_string(num_shards_));
} }


return true;
return Status::OK();
} }


std::shared_ptr<SamplerRT> DistributedSamplerObj::Build() { std::shared_ptr<SamplerRT> DistributedSamplerObj::Build() {
// runtime sampler object // runtime sampler object
auto sampler = std::make_shared<dataset::DistributedSamplerRT>(num_samples_, num_shards_, shard_id_, shuffle_, seed_, auto sampler = std::make_shared<dataset::DistributedSamplerRT>(num_samples_, num_shards_, shard_id_, shuffle_, seed_,
offset_, even_dist_); offset_, even_dist_);
BuildChildren(sampler);
return sampler; return sampler;
} }


@@ -170,23 +196,21 @@ std::shared_ptr<mindrecord::ShardOperator> DistributedSamplerObj::BuildForMindDa
PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples) PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples)
: num_val_(num_val), shuffle_(shuffle), num_samples_(num_samples) {} : num_val_(num_val), shuffle_(shuffle), num_samples_(num_samples) {}


bool PKSamplerObj::ValidateParams() {
Status PKSamplerObj::ValidateParams() {
if (num_val_ <= 0) { if (num_val_ <= 0) {
MS_LOG(ERROR) << "PKSampler: invalid num_val: " << num_val_;
return false;
RETURN_STATUS_UNEXPECTED("PKSampler: invalid num_val: " + std::to_string(num_val_));
} }


if (num_samples_ < 0) { if (num_samples_ < 0) {
MS_LOG(ERROR) << "PKSampler: invalid num_samples: " << num_samples_;
return false;
RETURN_STATUS_UNEXPECTED("PKSampler: invalid num_samples: " + std::to_string(num_samples_));
} }
return true;
return Status::OK();
} }


std::shared_ptr<SamplerRT> PKSamplerObj::Build() { std::shared_ptr<SamplerRT> PKSamplerObj::Build() {
// runtime sampler object // runtime sampler object
auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_); auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_);
BuildChildren(sampler);
return sampler; return sampler;
} }


@@ -198,9 +222,12 @@ PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator
: sp_minddataset_(std::move(sampler)) {} : sp_minddataset_(std::move(sampler)) {}
#endif #endif


bool PreBuiltSamplerObj::ValidateParams() { return true; }
Status PreBuiltSamplerObj::ValidateParams() { return Status::OK(); }


std::shared_ptr<SamplerRT> PreBuiltSamplerObj::Build() { return sp_; }
/// \brief Attaches the built children to the stored runtime sampler and returns it.
/// \return The stored runtime sampler. It is null when this object was constructed from a mindrecord
///     operator, because that constructor only initializes sp_minddataset_.
std::shared_ptr<SamplerRT> PreBuiltSamplerObj::Build() {
  // Only attach children when there is a runtime sampler to attach them to.
  // NOTE(review): every call attaches the children to the same sp_ again — confirm that the runtime
  // sampler tolerates Build() being invoked more than once.
  if (sp_ != nullptr) {
    BuildChildren(sp_);
  }
  return sp_;
}


#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; } std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; }
@@ -208,9 +235,19 @@ std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDatas


std::shared_ptr<SamplerObj> PreBuiltSamplerObj::Copy() { std::shared_ptr<SamplerObj> PreBuiltSamplerObj::Copy() {
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
if (sp_minddataset_ != nullptr) return std::make_shared<PreBuiltSamplerObj>(sp_minddataset_);
if (sp_minddataset_ != nullptr) {
auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_minddataset_);
for (auto child : children_) {
sampler->AddChild(child);
}
return sampler;
}
#endif #endif
return std::make_shared<PreBuiltSamplerObj>(sp_);
auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_);
for (auto child : children_) {
sampler->AddChild(child);
}
return sampler;
} }


#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
@@ -232,19 +269,18 @@ std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples) RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples)
: replacement_(replacement), num_samples_(num_samples) {} : replacement_(replacement), num_samples_(num_samples) {}


bool RandomSamplerObj::ValidateParams() {
Status RandomSamplerObj::ValidateParams() {
if (num_samples_ < 0) { if (num_samples_ < 0) {
MS_LOG(ERROR) << "RandomSampler: invalid num_samples: " << num_samples_;
return false;
RETURN_STATUS_UNEXPECTED("RandomSampler: invalid num_samples: " + std::to_string(num_samples_));
} }
return true;
return Status::OK();
} }


std::shared_ptr<SamplerRT> RandomSamplerObj::Build() { std::shared_ptr<SamplerRT> RandomSamplerObj::Build() {
// runtime sampler object // runtime sampler object
bool reshuffle_each_epoch = true; bool reshuffle_each_epoch = true;
auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch); auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch);
BuildChildren(sampler);
return sampler; return sampler;
} }


@@ -263,24 +299,22 @@ std::shared_ptr<mindrecord::ShardOperator> RandomSamplerObj::BuildForMindDataset
SequentialSamplerObj::SequentialSamplerObj(int64_t start_index, int64_t num_samples) SequentialSamplerObj::SequentialSamplerObj(int64_t start_index, int64_t num_samples)
: start_index_(start_index), num_samples_(num_samples) {} : start_index_(start_index), num_samples_(num_samples) {}


bool SequentialSamplerObj::ValidateParams() {
Status SequentialSamplerObj::ValidateParams() {
if (num_samples_ < 0) { if (num_samples_ < 0) {
MS_LOG(ERROR) << "SequentialSampler: invalid num_samples: " << num_samples_;
return false;
RETURN_STATUS_UNEXPECTED("SequentialSampler: invalid num_samples: " + std::to_string(num_samples_));
} }


if (start_index_ < 0) { if (start_index_ < 0) {
MS_LOG(ERROR) << "SequentialSampler: invalid start_index: " << start_index_;
return false;
RETURN_STATUS_UNEXPECTED("SequentialSampler: invalid start_index: " + std::to_string(start_index_));
} }


return true;
return Status::OK();
} }


std::shared_ptr<SamplerRT> SequentialSamplerObj::Build() { std::shared_ptr<SamplerRT> SequentialSamplerObj::Build() {
// runtime sampler object // runtime sampler object
auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_); auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_);
BuildChildren(sampler);
return sampler; return sampler;
} }


@@ -297,19 +331,18 @@ std::shared_ptr<mindrecord::ShardOperator> SequentialSamplerObj::BuildForMindDat
SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples) SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples)
: indices_(std::move(indices)), num_samples_(num_samples) {} : indices_(std::move(indices)), num_samples_(num_samples) {}


bool SubsetRandomSamplerObj::ValidateParams() {
Status SubsetRandomSamplerObj::ValidateParams() {
if (num_samples_ < 0) { if (num_samples_ < 0) {
MS_LOG(ERROR) << "SubsetRandomSampler: invalid num_samples: " << num_samples_;
return false;
RETURN_STATUS_UNEXPECTED("SubsetRandomSampler: invalid num_samples: " + std::to_string(num_samples_));
} }


return true;
return Status::OK();
} }


std::shared_ptr<SamplerRT> SubsetRandomSamplerObj::Build() { std::shared_ptr<SamplerRT> SubsetRandomSamplerObj::Build() {
// runtime sampler object // runtime sampler object
auto sampler = std::make_shared<dataset::SubsetRandomSamplerRT>(num_samples_, indices_); auto sampler = std::make_shared<dataset::SubsetRandomSamplerRT>(num_samples_, indices_);
BuildChildren(sampler);
return sampler; return sampler;
} }


@@ -326,34 +359,32 @@ std::shared_ptr<mindrecord::ShardOperator> SubsetRandomSamplerObj::BuildForMindD
WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement) WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement)
: weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {} : weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {}


bool WeightedRandomSamplerObj::ValidateParams() {
Status WeightedRandomSamplerObj::ValidateParams() {
if (weights_.empty()) { if (weights_.empty()) {
MS_LOG(ERROR) << "WeightedRandomSampler: weights vector must not be empty";
return false;
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not be empty");
} }
int32_t zero_elem = 0; int32_t zero_elem = 0;
for (int32_t i = 0; i < weights_.size(); ++i) { for (int32_t i = 0; i < weights_.size(); ++i) {
if (weights_[i] < 0) { if (weights_[i] < 0) {
MS_LOG(ERROR) << "WeightedRandomSampler: weights vector must not contain negative number, got: " << weights_[i];
return false;
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not contain negative number, got: " +
std::to_string(weights_[i]));
} }
if (weights_[i] == 0.0) { if (weights_[i] == 0.0) {
zero_elem++; zero_elem++;
} }
} }
if (zero_elem == weights_.size()) { if (zero_elem == weights_.size()) {
MS_LOG(ERROR) << "WeightedRandomSampler: elements of weights vector must not be all zero";
return false;
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: elements of weights vector must not be all zero");
} }
if (num_samples_ < 0) { if (num_samples_ < 0) {
MS_LOG(ERROR) << "WeightedRandomSampler: invalid num_samples: " << num_samples_;
return false;
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: invalid num_samples: " + std::to_string(num_samples_));
} }
return true;
return Status::OK();
} }


std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::Build() { std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::Build() {
auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_); auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_);
BuildChildren(sampler);
return sampler; return sampler;
} }




+ 4
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc View File

@@ -37,6 +37,9 @@ DistributedSamplerRT::DistributedSamplerRT(int64_t num_samples, int64_t num_dev,
non_empty_(true) {} non_empty_(true) {}


Status DistributedSamplerRT::InitSampler() { Status DistributedSamplerRT::InitSampler() {
if (is_initialized) {
return Status::OK();
}
// Special value of 0 for num_samples means that the user wants to sample the entire set of data. // Special value of 0 for num_samples means that the user wants to sample the entire set of data.
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
if (num_samples_ == 0 || num_samples_ > num_rows_) { if (num_samples_ == 0 || num_samples_ > num_rows_) {
@@ -72,6 +75,7 @@ Status DistributedSamplerRT::InitSampler() {
} }
if (!samples_per_buffer_) non_empty_ = false; if (!samples_per_buffer_) non_empty_ = false;


is_initialized = true;
return Status::OK(); return Status::OK();
} }




+ 4
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.cc View File

@@ -28,6 +28,9 @@ PKSamplerRT::PKSamplerRT(int64_t num_samples, int64_t val, bool shuffle, int64_t
samples_per_class_(val) {} samples_per_class_(val) {}


Status PKSamplerRT::InitSampler() { Status PKSamplerRT::InitSampler() {
if (is_initialized) {
return Status::OK();
}
labels_.reserve(label_to_ids_.size()); labels_.reserve(label_to_ids_.size());
for (const auto &pair : label_to_ids_) { for (const auto &pair : label_to_ids_) {
if (!pair.second.empty()) { if (!pair.second.empty()) {
@@ -58,6 +61,7 @@ Status PKSamplerRT::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED( CHECK_FAIL_RETURN_UNEXPECTED(
num_samples_ > 0, "Invalid parameter, num_class or K (num samples per class) must be greater than 0, but got " + num_samples_ > 0, "Invalid parameter, num_class or K (num samples per class) must be greater than 0, but got " +
std::to_string(num_samples_)); std::to_string(num_samples_));
is_initialized = true;
return Status::OK(); return Status::OK();
} }




+ 5
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.cc View File

@@ -65,6 +65,9 @@ Status PythonSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
} }


Status PythonSamplerRT::InitSampler() { Status PythonSamplerRT::InitSampler() {
if (is_initialized) {
return Status::OK();
}
CHECK_FAIL_RETURN_UNEXPECTED( CHECK_FAIL_RETURN_UNEXPECTED(
num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_)); num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_));
// Special value of 0 for num_samples means that the user wants to sample the entire set of data. // Special value of 0 for num_samples means that the user wants to sample the entire set of data.
@@ -83,6 +86,8 @@ Status PythonSamplerRT::InitSampler() {
return Status(StatusCode::kPyFuncException, e.what()); return Status(StatusCode::kPyFuncException, e.what());
} }
} }

is_initialized = true;
return Status::OK(); return Status::OK();
} }




+ 4
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.cc View File

@@ -69,6 +69,9 @@ Status RandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
} }


Status RandomSamplerRT::InitSampler() { Status RandomSamplerRT::InitSampler() {
if (is_initialized) {
return Status::OK();
}
// Special value of 0 for num_samples means that the user wants to sample the entire set of data. // Special value of 0 for num_samples means that the user wants to sample the entire set of data.
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
if (num_samples_ == 0 || num_samples_ > num_rows_) { if (num_samples_ == 0 || num_samples_ > num_rows_) {
@@ -91,6 +94,7 @@ Status RandomSamplerRT::InitSampler() {
dist = std::make_unique<std::uniform_int_distribution<int64_t>>(0, num_rows_ - 1); dist = std::make_unique<std::uniform_int_distribution<int64_t>>(0, num_rows_ - 1);
} }


is_initialized = true;
return Status::OK(); return Status::OK();
} }




+ 5
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc View File

@@ -34,7 +34,11 @@ Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const {
} }


SamplerRT::SamplerRT(int64_t num_samples, int64_t samples_per_buffer) SamplerRT::SamplerRT(int64_t num_samples, int64_t samples_per_buffer)
: num_rows_(0), num_samples_(num_samples), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {}
: num_rows_(0),
num_samples_(num_samples),
samples_per_buffer_(samples_per_buffer),
col_desc_(nullptr),
is_initialized(false) {}


Status SamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) { Status SamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) {
std::shared_ptr<SamplerRT> child_sampler; std::shared_ptr<SamplerRT> child_sampler;


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h View File

@@ -160,6 +160,7 @@ class SamplerRT {
// amount. // amount.
int64_t num_samples_; int64_t num_samples_;


bool is_initialized;
int64_t samples_per_buffer_; int64_t samples_per_buffer_;
std::unique_ptr<ColDescriptor> col_desc_; std::unique_ptr<ColDescriptor> col_desc_;
std::vector<std::shared_ptr<SamplerRT>> child_; // Child nodes std::vector<std::shared_ptr<SamplerRT>> child_; // Child nodes


+ 5
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc View File

@@ -63,6 +63,9 @@ Status SequentialSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffe
} }


Status SequentialSamplerRT::InitSampler() { Status SequentialSamplerRT::InitSampler() {
if (is_initialized) {
return Status::OK();
}
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0,
"Invalid parameter, start_index must be greater than or equal to 0, but got " + "Invalid parameter, start_index must be greater than or equal to 0, but got " +
std::to_string(start_index_) + ".\n"); std::to_string(start_index_) + ".\n");
@@ -82,6 +85,8 @@ Status SequentialSamplerRT::InitSampler() {
num_samples_ > 0 && samples_per_buffer_ > 0, num_samples_ > 0 && samples_per_buffer_ > 0,
"Invalid parameter, samples_per_buffer must be greater than 0, but got " + std::to_string(samples_per_buffer_)); "Invalid parameter, samples_per_buffer must be greater than 0, but got " + std::to_string(samples_per_buffer_));
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;

is_initialized = true;
return Status::OK(); return Status::OK();
} }




+ 4
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc View File

@@ -32,6 +32,9 @@ SubsetRandomSamplerRT::SubsetRandomSamplerRT(int64_t num_samples, const std::vec


// Initialized this Sampler. // Initialized this Sampler.
Status SubsetRandomSamplerRT::InitSampler() { Status SubsetRandomSamplerRT::InitSampler() {
if (is_initialized) {
return Status::OK();
}
CHECK_FAIL_RETURN_UNEXPECTED( CHECK_FAIL_RETURN_UNEXPECTED(
num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_) + ".\n"); num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_) + ".\n");


@@ -51,6 +54,7 @@ Status SubsetRandomSamplerRT::InitSampler() {
// We will shuffle the full set of id's, but only select the first num_samples_ of them later. // We will shuffle the full set of id's, but only select the first num_samples_ of them later.
std::shuffle(indices_.begin(), indices_.end(), rand_gen_); std::shuffle(indices_.begin(), indices_.end(), rand_gen_);


is_initialized = true;
return Status::OK(); return Status::OK();
} }




+ 4
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc View File

@@ -37,6 +37,9 @@ WeightedRandomSamplerRT::WeightedRandomSamplerRT(int64_t num_samples, const std:


// Initialized this Sampler. // Initialized this Sampler.
Status WeightedRandomSamplerRT::InitSampler() { Status WeightedRandomSamplerRT::InitSampler() {
if (is_initialized) {
return Status::OK();
}
// Special value of 0 for num_samples means that the user wants to sample the entire set of data. // Special value of 0 for num_samples means that the user wants to sample the entire set of data.
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
if (num_samples_ == 0 || num_samples_ > num_rows_) { if (num_samples_ == 0 || num_samples_ > num_rows_) {
@@ -75,6 +78,7 @@ Status WeightedRandomSamplerRT::InitSampler() {
discrete_dist_ = std::make_unique<std::discrete_distribution<int64_t>>(weights_.begin(), weights_.end()); discrete_dist_ = std::make_unique<std::discrete_distribution<int64_t>>(weights_.begin(), weights_.end());
} }


is_initialized = true;
return Status::OK(); return Status::OK();
} }




+ 58
- 17
mindspore/ccsrc/minddata/dataset/include/samplers.h View File

@@ -22,7 +22,10 @@
#include <vector> #include <vector>


#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
#include "minddata/dataset/util/status.h"
#include "minddata/mindrecord/include/shard_column.h"
#include "minddata/mindrecord/include/shard_error.h"
#include "minddata/mindrecord/include/shard_reader.h"
#endif #endif


namespace mindspore { namespace mindspore {
@@ -40,8 +43,8 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
~SamplerObj() = default; ~SamplerObj() = default;


/// \brief Pure virtual function for derived class to implement parameters validation /// \brief Pure virtual function for derived class to implement parameters validation
/// \return bool true if all the parameters are valid
virtual bool ValidateParams() = 0;
/// \return The Status code of the function. It returns OK status if parameters are valid.
virtual Status ValidateParams() = 0;


/// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object /// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object
/// \return Shared pointers to the newly created Sampler /// \return Shared pointers to the newly created Sampler
@@ -55,12 +58,24 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
/// \return The shard id of the derived sampler /// \return The shard id of the derived sampler
virtual int64_t ShardId() { return 0; } virtual int64_t ShardId() { return 0; }


/// \brief Adds a child to the sampler
/// \param[in] child The sampler to be added as child
/// \return the Status code returned
Status AddChild(std::shared_ptr<SamplerObj> child);

#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
/// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object, /// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object,
/// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler /// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler
/// \return Shared pointers to the newly created Sampler /// \return Shared pointers to the newly created Sampler
virtual std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() { return nullptr; } virtual std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() { return nullptr; }
#endif #endif

protected:
/// \brief A function that calls build on the children of this sampler
/// \param[in] sampler The samplerRT object built from this sampler
void BuildChildren(std::shared_ptr<SamplerRT> sampler);

std::vector<std::shared_ptr<SamplerObj>> children_;
}; };


class DistributedSamplerObj; class DistributedSamplerObj;
@@ -137,15 +152,19 @@ class DistributedSamplerObj : public SamplerObj {
std::shared_ptr<SamplerRT> Build() override; std::shared_ptr<SamplerRT> Build() override;


std::shared_ptr<SamplerObj> Copy() override { std::shared_ptr<SamplerObj> Copy() override {
return std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_, offset_,
even_dist_);
auto sampler = std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_,
offset_, even_dist_);
for (auto child : children_) {
sampler->AddChild(child);
}
return sampler;
} }


#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif #endif


bool ValidateParams() override;
Status ValidateParams() override;


/// \brief Function to get the shard id of sampler /// \brief Function to get the shard id of sampler
/// \return The shard id of sampler /// \return The shard id of sampler
@@ -170,14 +189,18 @@ class PKSamplerObj : public SamplerObj {
std::shared_ptr<SamplerRT> Build() override; std::shared_ptr<SamplerRT> Build() override;


std::shared_ptr<SamplerObj> Copy() override { std::shared_ptr<SamplerObj> Copy() override {
return std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_);
auto sampler = std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_);
for (auto child : children_) {
sampler->AddChild(child);
}
return sampler;
} }


#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif #endif


bool ValidateParams() override;
Status ValidateParams() override;


private: private:
int64_t num_val_; int64_t num_val_;
@@ -202,7 +225,7 @@ class PreBuiltSamplerObj : public SamplerObj {


std::shared_ptr<SamplerObj> Copy() override; std::shared_ptr<SamplerObj> Copy() override;


bool ValidateParams() override;
Status ValidateParams() override;


private: private:
std::shared_ptr<SamplerRT> sp_; std::shared_ptr<SamplerRT> sp_;
@@ -219,13 +242,19 @@ class RandomSamplerObj : public SamplerObj {


std::shared_ptr<SamplerRT> Build() override; std::shared_ptr<SamplerRT> Build() override;


std::shared_ptr<SamplerObj> Copy() override { return std::make_shared<RandomSamplerObj>(replacement_, num_samples_); }
std::shared_ptr<SamplerObj> Copy() override {
auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_);
for (auto child : children_) {
sampler->AddChild(child);
}
return sampler;
}


#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif #endif


bool ValidateParams() override;
Status ValidateParams() override;


private: private:
bool replacement_; bool replacement_;
@@ -241,14 +270,18 @@ class SequentialSamplerObj : public SamplerObj {
std::shared_ptr<SamplerRT> Build() override; std::shared_ptr<SamplerRT> Build() override;


std::shared_ptr<SamplerObj> Copy() override { std::shared_ptr<SamplerObj> Copy() override {
return std::make_shared<SequentialSamplerObj>(start_index_, num_samples_);
auto sampler = std::make_shared<SequentialSamplerObj>(start_index_, num_samples_);
for (auto child : children_) {
sampler->AddChild(child);
}
return sampler;
} }


#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif #endif


bool ValidateParams() override;
Status ValidateParams() override;


private: private:
int64_t start_index_; int64_t start_index_;
@@ -264,14 +297,18 @@ class SubsetRandomSamplerObj : public SamplerObj {
std::shared_ptr<SamplerRT> Build() override; std::shared_ptr<SamplerRT> Build() override;


std::shared_ptr<SamplerObj> Copy() override { std::shared_ptr<SamplerObj> Copy() override {
return std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_);
auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_);
for (auto child : children_) {
sampler->AddChild(child);
}
return sampler;
} }


#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif #endif


bool ValidateParams() override;
Status ValidateParams() override;


private: private:
const std::vector<int64_t> indices_; const std::vector<int64_t> indices_;
@@ -287,10 +324,14 @@ class WeightedRandomSamplerObj : public SamplerObj {
std::shared_ptr<SamplerRT> Build() override; std::shared_ptr<SamplerRT> Build() override;


std::shared_ptr<SamplerObj> Copy() override { std::shared_ptr<SamplerObj> Copy() override {
return std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
for (auto child : children_) {
sampler->AddChild(child);
}
return sampler;
} }


bool ValidateParams() override;
Status ValidateParams() override;


private: private:
const std::vector<double> weights_; const std::vector<double> weights_;


+ 31
- 0
tests/ut/cpp/dataset/c_api_samplers_test.cc View File

@@ -208,6 +208,37 @@ TEST_F(MindDataTestPipeline, TestDistributedSamplerSuccess) {
iter->Stop(); iter->Stop();
} }


// Verifies that a child sampler can be attached to a parent sampler through
// AddChild and that a pipeline built with the parent still runs to completion.
// The parent is a DistributedSampler, the child a default SequentialSampler, and
// the pipeline is an ImageFolder dataset over testPK/data.
TEST_F(MindDataTestPipeline, TestSamplerAddChild) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSamplerAddChild.";

  // The 4th argument (5) is presumably num_samples, matching the parameter order
  // of DistributedSamplerObj's constructor - TODO confirm against the factory.
  auto sampler = DistributedSampler(1, 0, false, 5, 0, -1, true);
  // Fatal assertion: sampler is dereferenced right below, so a null pointer must
  // abort this test instead of crashing the whole test binary.
  ASSERT_NE(sampler, nullptr);

  // Validate the child before handing it to AddChild, not afterwards.
  auto child_sampler = SequentialSampler();
  ASSERT_NE(child_sampler, nullptr);

  // AddChild reports its outcome through the returned Status; check it rather
  // than silently discarding it.
  EXPECT_TRUE(sampler->AddChild(child_sampler).IsOk());

  // Create an ImageFolder Dataset
  std::string folder_path = datasets_root_path_ + "/testPK/data/";
  std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, sampler);
  ASSERT_NE(ds, nullptr);

  // Iterate the dataset and get each row
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  ASSERT_NE(iter, nullptr);
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);

  // Drain the pipeline; an empty row marks the end of the data.
  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    iter->GetNextRow(&row);
  }
  // The row count is only logged; the size check below relies on GetDatasetSize().
  MS_LOG(INFO) << "TestSamplerAddChild iterated over " << i << " rows.";

  EXPECT_EQ(ds->GetDatasetSize(), 5);
  iter->Stop();
}

TEST_F(MindDataTestPipeline, TestDistributedSamplerFail) { TEST_F(MindDataTestPipeline, TestDistributedSamplerFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDistributedSamplerFail."; MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDistributedSamplerFail.";
// Test invalid offset setting of distributed_sampler // Test invalid offset setting of distributed_sampler


Loading…
Cancel
Save