| @@ -48,6 +48,34 @@ namespace dataset { | |||
| // Constructor | |||
| SamplerObj::SamplerObj() {} | |||
| void SamplerObj::BuildChildren(std::shared_ptr<SamplerRT> sampler) { | |||
| for (auto child : children_) { | |||
| auto sampler_rt = child->Build(); | |||
| sampler->AddChild(sampler_rt); | |||
| } | |||
| } | |||
| Status SamplerObj::AddChild(std::shared_ptr<SamplerObj> child) { | |||
| if (child == nullptr) { | |||
| return Status::OK(); | |||
| } | |||
| // Only samplers can be added, not any other DatasetOp. | |||
| std::shared_ptr<SamplerObj> sampler = std::dynamic_pointer_cast<SamplerObj>(child); | |||
| if (!sampler) { | |||
| RETURN_STATUS_UNEXPECTED("Cannot add child, child is not a sampler object."); | |||
| } | |||
| // Samplers can have at most 1 child. | |||
| if (!children_.empty()) { | |||
| RETURN_STATUS_UNEXPECTED("Cannot add child sampler, this sampler already has a child."); | |||
| } | |||
| children_.push_back(child); | |||
| return Status::OK(); | |||
| } | |||
| /// Function to create a Distributed Sampler. | |||
| std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle, | |||
| int64_t num_samples, uint32_t seed, int64_t offset, | |||
| @@ -55,7 +83,7 @@ std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, in | |||
| auto sampler = | |||
| std::make_shared<DistributedSamplerObj>(num_shards, shard_id, shuffle, num_samples, seed, offset, even_dist); | |||
| // Input validation | |||
| if (!sampler->ValidateParams()) { | |||
| if (sampler->ValidateParams().IsError()) { | |||
| return nullptr; | |||
| } | |||
| return sampler; | |||
| @@ -65,7 +93,7 @@ std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, in | |||
| std::shared_ptr<PKSamplerObj> PKSampler(int64_t num_val, bool shuffle, int64_t num_samples) { | |||
| auto sampler = std::make_shared<PKSamplerObj>(num_val, shuffle, num_samples); | |||
| // Input validation | |||
| if (!sampler->ValidateParams()) { | |||
| if (sampler->ValidateParams().IsError()) { | |||
| return nullptr; | |||
| } | |||
| return sampler; | |||
| @@ -75,7 +103,7 @@ std::shared_ptr<PKSamplerObj> PKSampler(int64_t num_val, bool shuffle, int64_t n | |||
| std::shared_ptr<RandomSamplerObj> RandomSampler(bool replacement, int64_t num_samples) { | |||
| auto sampler = std::make_shared<RandomSamplerObj>(replacement, num_samples); | |||
| // Input validation | |||
| if (!sampler->ValidateParams()) { | |||
| if (sampler->ValidateParams().IsError()) { | |||
| return nullptr; | |||
| } | |||
| return sampler; | |||
| @@ -85,7 +113,7 @@ std::shared_ptr<RandomSamplerObj> RandomSampler(bool replacement, int64_t num_sa | |||
| std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index, int64_t num_samples) { | |||
| auto sampler = std::make_shared<SequentialSamplerObj>(start_index, num_samples); | |||
| // Input validation | |||
| if (!sampler->ValidateParams()) { | |||
| if (sampler->ValidateParams().IsError()) { | |||
| return nullptr; | |||
| } | |||
| return sampler; | |||
| @@ -95,7 +123,7 @@ std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index, int | |||
| std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples) { | |||
| auto sampler = std::make_shared<SubsetRandomSamplerObj>(std::move(indices), num_samples); | |||
| // Input validation | |||
| if (!sampler->ValidateParams()) { | |||
| if (sampler->ValidateParams().IsError()) { | |||
| return nullptr; | |||
| } | |||
| return sampler; | |||
| @@ -106,7 +134,7 @@ std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(std::vector<doub | |||
| bool replacement) { | |||
| auto sampler = std::make_shared<WeightedRandomSamplerObj>(std::move(weights), num_samples, replacement); | |||
| // Input validation | |||
| if (!sampler->ValidateParams()) { | |||
| if (sampler->ValidateParams().IsError()) { | |||
| return nullptr; | |||
| } | |||
| return sampler; | |||
| @@ -125,35 +153,33 @@ DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_i | |||
| offset_(offset), | |||
| even_dist_(even_dist) {} | |||
| bool DistributedSamplerObj::ValidateParams() { | |||
| Status DistributedSamplerObj::ValidateParams() { | |||
| if (num_shards_ <= 0) { | |||
| MS_LOG(ERROR) << "DistributedSampler: invalid num_shards: " << num_shards_; | |||
| return false; | |||
| RETURN_STATUS_UNEXPECTED("DistributedSampler: invalid num_shards: " + std::to_string(num_shards_)); | |||
| } | |||
| if (shard_id_ < 0 || shard_id_ >= num_shards_) { | |||
| MS_LOG(ERROR) << "DistributedSampler: invalid input, shard_id: " << shard_id_ << ", num_shards: " << num_shards_; | |||
| return false; | |||
| RETURN_STATUS_UNEXPECTED("DistributedSampler: invalid input, shard_id: " + std::to_string(shard_id_) + | |||
| ", num_shards: " + std::to_string(num_shards_)); | |||
| } | |||
| if (num_samples_ < 0) { | |||
| MS_LOG(ERROR) << "DistributedSampler: invalid num_samples: " << num_samples_; | |||
| return false; | |||
| RETURN_STATUS_UNEXPECTED("DistributedSampler: invalid num_samples: " + std::to_string(num_samples_)); | |||
| } | |||
| if (offset_ > num_shards_) { | |||
| MS_LOG(ERROR) << "DistributedSampler: invalid offset: " << offset_ | |||
| << ", which should be no more than num_shards: " << num_shards_; | |||
| return false; | |||
| RETURN_STATUS_UNEXPECTED("DistributedSampler: invalid offset: " + std::to_string(offset_) + | |||
| ", which should be no more than num_shards: " + std::to_string(num_shards_)); | |||
| } | |||
| return true; | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerRT> DistributedSamplerObj::Build() { | |||
| // runtime sampler object | |||
| auto sampler = std::make_shared<dataset::DistributedSamplerRT>(num_samples_, num_shards_, shard_id_, shuffle_, seed_, | |||
| offset_, even_dist_); | |||
| BuildChildren(sampler); | |||
| return sampler; | |||
| } | |||
| @@ -170,23 +196,21 @@ std::shared_ptr<mindrecord::ShardOperator> DistributedSamplerObj::BuildForMindDa | |||
| PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples) | |||
| : num_val_(num_val), shuffle_(shuffle), num_samples_(num_samples) {} | |||
| bool PKSamplerObj::ValidateParams() { | |||
| Status PKSamplerObj::ValidateParams() { | |||
| if (num_val_ <= 0) { | |||
| MS_LOG(ERROR) << "PKSampler: invalid num_val: " << num_val_; | |||
| return false; | |||
| RETURN_STATUS_UNEXPECTED("PKSampler: invalid num_val: " + std::to_string(num_val_)); | |||
| } | |||
| if (num_samples_ < 0) { | |||
| MS_LOG(ERROR) << "PKSampler: invalid num_samples: " << num_samples_; | |||
| return false; | |||
| RETURN_STATUS_UNEXPECTED("PKSampler: invalid num_samples: " + std::to_string(num_samples_)); | |||
| } | |||
| return true; | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerRT> PKSamplerObj::Build() { | |||
| // runtime sampler object | |||
| auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_); | |||
| BuildChildren(sampler); | |||
| return sampler; | |||
| } | |||
| @@ -198,9 +222,12 @@ PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator | |||
| : sp_minddataset_(std::move(sampler)) {} | |||
| #endif | |||
| bool PreBuiltSamplerObj::ValidateParams() { return true; } | |||
| Status PreBuiltSamplerObj::ValidateParams() { return Status::OK(); } | |||
| std::shared_ptr<SamplerRT> PreBuiltSamplerObj::Build() { return sp_; } | |||
| std::shared_ptr<SamplerRT> PreBuiltSamplerObj::Build() { | |||
| BuildChildren(sp_); | |||
| return sp_; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; } | |||
| @@ -208,9 +235,19 @@ std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDatas | |||
| std::shared_ptr<SamplerObj> PreBuiltSamplerObj::Copy() { | |||
| #ifndef ENABLE_ANDROID | |||
| if (sp_minddataset_ != nullptr) return std::make_shared<PreBuiltSamplerObj>(sp_minddataset_); | |||
| if (sp_minddataset_ != nullptr) { | |||
| auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_minddataset_); | |||
| for (auto child : children_) { | |||
| sampler->AddChild(child); | |||
| } | |||
| return sampler; | |||
| } | |||
| #endif | |||
| return std::make_shared<PreBuiltSamplerObj>(sp_); | |||
| auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_); | |||
| for (auto child : children_) { | |||
| sampler->AddChild(child); | |||
| } | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| @@ -232,19 +269,18 @@ std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() { | |||
| RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples) | |||
| : replacement_(replacement), num_samples_(num_samples) {} | |||
| bool RandomSamplerObj::ValidateParams() { | |||
| Status RandomSamplerObj::ValidateParams() { | |||
| if (num_samples_ < 0) { | |||
| MS_LOG(ERROR) << "RandomSampler: invalid num_samples: " << num_samples_; | |||
| return false; | |||
| RETURN_STATUS_UNEXPECTED("RandomSampler: invalid num_samples: " + std::to_string(num_samples_)); | |||
| } | |||
| return true; | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerRT> RandomSamplerObj::Build() { | |||
| // runtime sampler object | |||
| bool reshuffle_each_epoch = true; | |||
| auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch); | |||
| BuildChildren(sampler); | |||
| return sampler; | |||
| } | |||
| @@ -263,24 +299,22 @@ std::shared_ptr<mindrecord::ShardOperator> RandomSamplerObj::BuildForMindDataset | |||
| SequentialSamplerObj::SequentialSamplerObj(int64_t start_index, int64_t num_samples) | |||
| : start_index_(start_index), num_samples_(num_samples) {} | |||
| bool SequentialSamplerObj::ValidateParams() { | |||
| Status SequentialSamplerObj::ValidateParams() { | |||
| if (num_samples_ < 0) { | |||
| MS_LOG(ERROR) << "SequentialSampler: invalid num_samples: " << num_samples_; | |||
| return false; | |||
| RETURN_STATUS_UNEXPECTED("SequentialSampler: invalid num_samples: " + std::to_string(num_samples_)); | |||
| } | |||
| if (start_index_ < 0) { | |||
| MS_LOG(ERROR) << "SequentialSampler: invalid start_index: " << start_index_; | |||
| return false; | |||
| RETURN_STATUS_UNEXPECTED("SequentialSampler: invalid start_index: " + std::to_string(start_index_)); | |||
| } | |||
| return true; | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerRT> SequentialSamplerObj::Build() { | |||
| // runtime sampler object | |||
| auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_); | |||
| BuildChildren(sampler); | |||
| return sampler; | |||
| } | |||
| @@ -297,19 +331,18 @@ std::shared_ptr<mindrecord::ShardOperator> SequentialSamplerObj::BuildForMindDat | |||
| SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples) | |||
| : indices_(std::move(indices)), num_samples_(num_samples) {} | |||
| bool SubsetRandomSamplerObj::ValidateParams() { | |||
| Status SubsetRandomSamplerObj::ValidateParams() { | |||
| if (num_samples_ < 0) { | |||
| MS_LOG(ERROR) << "SubsetRandomSampler: invalid num_samples: " << num_samples_; | |||
| return false; | |||
| RETURN_STATUS_UNEXPECTED("SubsetRandomSampler: invalid num_samples: " + std::to_string(num_samples_)); | |||
| } | |||
| return true; | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerRT> SubsetRandomSamplerObj::Build() { | |||
| // runtime sampler object | |||
| auto sampler = std::make_shared<dataset::SubsetRandomSamplerRT>(num_samples_, indices_); | |||
| BuildChildren(sampler); | |||
| return sampler; | |||
| } | |||
| @@ -326,34 +359,32 @@ std::shared_ptr<mindrecord::ShardOperator> SubsetRandomSamplerObj::BuildForMindD | |||
| WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement) | |||
| : weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {} | |||
| bool WeightedRandomSamplerObj::ValidateParams() { | |||
| Status WeightedRandomSamplerObj::ValidateParams() { | |||
| if (weights_.empty()) { | |||
| MS_LOG(ERROR) << "WeightedRandomSampler: weights vector must not be empty"; | |||
| return false; | |||
| RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not be empty"); | |||
| } | |||
| int32_t zero_elem = 0; | |||
| for (int32_t i = 0; i < weights_.size(); ++i) { | |||
| if (weights_[i] < 0) { | |||
| MS_LOG(ERROR) << "WeightedRandomSampler: weights vector must not contain negative number, got: " << weights_[i]; | |||
| return false; | |||
| RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not contain negative number, got: " + | |||
| std::to_string(weights_[i])); | |||
| } | |||
| if (weights_[i] == 0.0) { | |||
| zero_elem++; | |||
| } | |||
| } | |||
| if (zero_elem == weights_.size()) { | |||
| MS_LOG(ERROR) << "WeightedRandomSampler: elements of weights vector must not be all zero"; | |||
| return false; | |||
| RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: elements of weights vector must not be all zero"); | |||
| } | |||
| if (num_samples_ < 0) { | |||
| MS_LOG(ERROR) << "WeightedRandomSampler: invalid num_samples: " << num_samples_; | |||
| return false; | |||
| RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: invalid num_samples: " + std::to_string(num_samples_)); | |||
| } | |||
| return true; | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::Build() { | |||
| auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_); | |||
| BuildChildren(sampler); | |||
| return sampler; | |||
| } | |||
| @@ -37,6 +37,9 @@ DistributedSamplerRT::DistributedSamplerRT(int64_t num_samples, int64_t num_dev, | |||
| non_empty_(true) {} | |||
| Status DistributedSamplerRT::InitSampler() { | |||
| if (is_initialized) { | |||
| return Status::OK(); | |||
| } | |||
| // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | |||
| // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. | |||
| if (num_samples_ == 0 || num_samples_ > num_rows_) { | |||
| @@ -72,6 +75,7 @@ Status DistributedSamplerRT::InitSampler() { | |||
| } | |||
| if (!samples_per_buffer_) non_empty_ = false; | |||
| is_initialized = true; | |||
| return Status::OK(); | |||
| } | |||
| @@ -28,6 +28,9 @@ PKSamplerRT::PKSamplerRT(int64_t num_samples, int64_t val, bool shuffle, int64_t | |||
| samples_per_class_(val) {} | |||
| Status PKSamplerRT::InitSampler() { | |||
| if (is_initialized) { | |||
| return Status::OK(); | |||
| } | |||
| labels_.reserve(label_to_ids_.size()); | |||
| for (const auto &pair : label_to_ids_) { | |||
| if (!pair.second.empty()) { | |||
| @@ -58,6 +61,7 @@ Status PKSamplerRT::InitSampler() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||
| num_samples_ > 0, "Invalid parameter, num_class or K (num samples per class) must be greater than 0, but got " + | |||
| std::to_string(num_samples_)); | |||
| is_initialized = true; | |||
| return Status::OK(); | |||
| } | |||
| @@ -65,6 +65,9 @@ Status PythonSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| } | |||
| Status PythonSamplerRT::InitSampler() { | |||
| if (is_initialized) { | |||
| return Status::OK(); | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||
| num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_)); | |||
| // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | |||
| @@ -83,6 +86,8 @@ Status PythonSamplerRT::InitSampler() { | |||
| return Status(StatusCode::kPyFuncException, e.what()); | |||
| } | |||
| } | |||
| is_initialized = true; | |||
| return Status::OK(); | |||
| } | |||
| @@ -69,6 +69,9 @@ Status RandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| } | |||
| Status RandomSamplerRT::InitSampler() { | |||
| if (is_initialized) { | |||
| return Status::OK(); | |||
| } | |||
| // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | |||
| // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. | |||
| if (num_samples_ == 0 || num_samples_ > num_rows_) { | |||
| @@ -91,6 +94,7 @@ Status RandomSamplerRT::InitSampler() { | |||
| dist = std::make_unique<std::uniform_int_distribution<int64_t>>(0, num_rows_ - 1); | |||
| } | |||
| is_initialized = true; | |||
| return Status::OK(); | |||
| } | |||
| @@ -34,7 +34,11 @@ Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const { | |||
| } | |||
| SamplerRT::SamplerRT(int64_t num_samples, int64_t samples_per_buffer) | |||
| : num_rows_(0), num_samples_(num_samples), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {} | |||
| : num_rows_(0), | |||
| num_samples_(num_samples), | |||
| samples_per_buffer_(samples_per_buffer), | |||
| col_desc_(nullptr), | |||
| is_initialized(false) {} | |||
| Status SamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) { | |||
| std::shared_ptr<SamplerRT> child_sampler; | |||
| @@ -160,6 +160,7 @@ class SamplerRT { | |||
| // amount. | |||
| int64_t num_samples_; | |||
| bool is_initialized; | |||
| int64_t samples_per_buffer_; | |||
| std::unique_ptr<ColDescriptor> col_desc_; | |||
| std::vector<std::shared_ptr<SamplerRT>> child_; // Child nodes | |||
| @@ -63,6 +63,9 @@ Status SequentialSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffe | |||
| } | |||
| Status SequentialSamplerRT::InitSampler() { | |||
| if (is_initialized) { | |||
| return Status::OK(); | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, | |||
| "Invalid parameter, start_index must be greater than or equal to 0, but got " + | |||
| std::to_string(start_index_) + ".\n"); | |||
| @@ -82,6 +85,8 @@ Status SequentialSamplerRT::InitSampler() { | |||
| num_samples_ > 0 && samples_per_buffer_ > 0, | |||
| "Invalid parameter, samples_per_buffer must be greater than 0, but got " + std::to_string(samples_per_buffer_)); | |||
| samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; | |||
| is_initialized = true; | |||
| return Status::OK(); | |||
| } | |||
| @@ -32,6 +32,9 @@ SubsetRandomSamplerRT::SubsetRandomSamplerRT(int64_t num_samples, const std::vec | |||
| // Initialized this Sampler. | |||
| Status SubsetRandomSamplerRT::InitSampler() { | |||
| if (is_initialized) { | |||
| return Status::OK(); | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||
| num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_) + ".\n"); | |||
| @@ -51,6 +54,7 @@ Status SubsetRandomSamplerRT::InitSampler() { | |||
| // We will shuffle the full set of id's, but only select the first num_samples_ of them later. | |||
| std::shuffle(indices_.begin(), indices_.end(), rand_gen_); | |||
| is_initialized = true; | |||
| return Status::OK(); | |||
| } | |||
| @@ -37,6 +37,9 @@ WeightedRandomSamplerRT::WeightedRandomSamplerRT(int64_t num_samples, const std: | |||
| // Initialized this Sampler. | |||
| Status WeightedRandomSamplerRT::InitSampler() { | |||
| if (is_initialized) { | |||
| return Status::OK(); | |||
| } | |||
| // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | |||
| // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. | |||
| if (num_samples_ == 0 || num_samples_ > num_rows_) { | |||
| @@ -75,6 +78,7 @@ Status WeightedRandomSamplerRT::InitSampler() { | |||
| discrete_dist_ = std::make_unique<std::discrete_distribution<int64_t>>(weights_.begin(), weights_.end()); | |||
| } | |||
| is_initialized = true; | |||
| return Status::OK(); | |||
| } | |||
| @@ -22,7 +22,10 @@ | |||
| #include <vector> | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| #include "minddata/mindrecord/include/shard_column.h" | |||
| #include "minddata/mindrecord/include/shard_error.h" | |||
| #include "minddata/mindrecord/include/shard_reader.h" | |||
| #endif | |||
| namespace mindspore { | |||
| @@ -40,8 +43,8 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> { | |||
| ~SamplerObj() = default; | |||
| /// \brief Pure virtual function for derived class to implement parameters validation | |||
| /// \return bool true if all the parameters are valid | |||
| virtual bool ValidateParams() = 0; | |||
| /// \return The Status code of the function. It returns OK status if parameters are valid. | |||
| virtual Status ValidateParams() = 0; | |||
| /// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object | |||
| /// \return Shared pointers to the newly created Sampler | |||
| @@ -55,12 +58,24 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> { | |||
| /// \return The shard id of the derived sampler | |||
| virtual int64_t ShardId() { return 0; } | |||
| /// \brief Adds a child to the sampler | |||
| /// \param[in] child The sampler to be added as child | |||
| /// \return the Status code returned | |||
| Status AddChild(std::shared_ptr<SamplerObj> child); | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object, | |||
| /// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler | |||
| /// \return Shared pointers to the newly created Sampler | |||
| virtual std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() { return nullptr; } | |||
| #endif | |||
| protected: | |||
| /// \brief A function that calls build on the children of this sampler | |||
| /// \param[in] sampler The samplerRT object built from this sampler | |||
| void BuildChildren(std::shared_ptr<SamplerRT> sampler); | |||
| std::vector<std::shared_ptr<SamplerObj>> children_; | |||
| }; | |||
| class DistributedSamplerObj; | |||
| @@ -137,15 +152,19 @@ class DistributedSamplerObj : public SamplerObj { | |||
| std::shared_ptr<SamplerRT> Build() override; | |||
| std::shared_ptr<SamplerObj> Copy() override { | |||
| return std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_, offset_, | |||
| even_dist_); | |||
| auto sampler = std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_, | |||
| offset_, even_dist_); | |||
| for (auto child : children_) { | |||
| sampler->AddChild(child); | |||
| } | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| bool ValidateParams() override; | |||
| Status ValidateParams() override; | |||
| /// \brief Function to get the shard id of sampler | |||
| /// \return The shard id of sampler | |||
| @@ -170,14 +189,18 @@ class PKSamplerObj : public SamplerObj { | |||
| std::shared_ptr<SamplerRT> Build() override; | |||
| std::shared_ptr<SamplerObj> Copy() override { | |||
| return std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_); | |||
| auto sampler = std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_); | |||
| for (auto child : children_) { | |||
| sampler->AddChild(child); | |||
| } | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| bool ValidateParams() override; | |||
| Status ValidateParams() override; | |||
| private: | |||
| int64_t num_val_; | |||
| @@ -202,7 +225,7 @@ class PreBuiltSamplerObj : public SamplerObj { | |||
| std::shared_ptr<SamplerObj> Copy() override; | |||
| bool ValidateParams() override; | |||
| Status ValidateParams() override; | |||
| private: | |||
| std::shared_ptr<SamplerRT> sp_; | |||
| @@ -219,13 +242,19 @@ class RandomSamplerObj : public SamplerObj { | |||
| std::shared_ptr<SamplerRT> Build() override; | |||
| std::shared_ptr<SamplerObj> Copy() override { return std::make_shared<RandomSamplerObj>(replacement_, num_samples_); } | |||
| std::shared_ptr<SamplerObj> Copy() override { | |||
| auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_); | |||
| for (auto child : children_) { | |||
| sampler->AddChild(child); | |||
| } | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| bool ValidateParams() override; | |||
| Status ValidateParams() override; | |||
| private: | |||
| bool replacement_; | |||
| @@ -241,14 +270,18 @@ class SequentialSamplerObj : public SamplerObj { | |||
| std::shared_ptr<SamplerRT> Build() override; | |||
| std::shared_ptr<SamplerObj> Copy() override { | |||
| return std::make_shared<SequentialSamplerObj>(start_index_, num_samples_); | |||
| auto sampler = std::make_shared<SequentialSamplerObj>(start_index_, num_samples_); | |||
| for (auto child : children_) { | |||
| sampler->AddChild(child); | |||
| } | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| bool ValidateParams() override; | |||
| Status ValidateParams() override; | |||
| private: | |||
| int64_t start_index_; | |||
| @@ -264,14 +297,18 @@ class SubsetRandomSamplerObj : public SamplerObj { | |||
| std::shared_ptr<SamplerRT> Build() override; | |||
| std::shared_ptr<SamplerObj> Copy() override { | |||
| return std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_); | |||
| auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_); | |||
| for (auto child : children_) { | |||
| sampler->AddChild(child); | |||
| } | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| bool ValidateParams() override; | |||
| Status ValidateParams() override; | |||
| private: | |||
| const std::vector<int64_t> indices_; | |||
| @@ -287,10 +324,14 @@ class WeightedRandomSamplerObj : public SamplerObj { | |||
| std::shared_ptr<SamplerRT> Build() override; | |||
| std::shared_ptr<SamplerObj> Copy() override { | |||
| return std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_); | |||
| auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_); | |||
| for (auto child : children_) { | |||
| sampler->AddChild(child); | |||
| } | |||
| return sampler; | |||
| } | |||
| bool ValidateParams() override; | |||
| Status ValidateParams() override; | |||
| private: | |||
| const std::vector<double> weights_; | |||
| @@ -208,6 +208,37 @@ TEST_F(MindDataTestPipeline, TestDistributedSamplerSuccess) { | |||
| iter->Stop(); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestSamplerAddChild) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSamplerAddChild."; | |||
| auto sampler = DistributedSampler(1, 0, false, 5, 0, -1, true); | |||
| EXPECT_NE(sampler, nullptr); | |||
| auto child_sampler = SequentialSampler(); | |||
| sampler->AddChild(child_sampler); | |||
| EXPECT_NE(child_sampler, nullptr); | |||
| // Create an ImageFolder Dataset | |||
| std::string folder_path = datasets_root_path_ + "/testPK/data/"; | |||
| std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, sampler); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Iterate the dataset and get each row | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| std::unordered_map<std::string, std::shared_ptr<Tensor>> row; | |||
| iter->GetNextRow(&row); | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| i++; | |||
| iter->GetNextRow(&row); | |||
| } | |||
| EXPECT_EQ(ds->GetDatasetSize(), 5); | |||
| iter->Stop(); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestDistributedSamplerFail) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDistributedSamplerFail."; | |||
| // Test invalid offset setting of distributed_sampler | |||