| @@ -48,6 +48,34 @@ namespace dataset { | |||||
| // Constructor | // Constructor | ||||
| SamplerObj::SamplerObj() {} | SamplerObj::SamplerObj() {} | ||||
| void SamplerObj::BuildChildren(std::shared_ptr<SamplerRT> sampler) { | |||||
| for (auto child : children_) { | |||||
| auto sampler_rt = child->Build(); | |||||
| sampler->AddChild(sampler_rt); | |||||
| } | |||||
| } | |||||
| Status SamplerObj::AddChild(std::shared_ptr<SamplerObj> child) { | |||||
| if (child == nullptr) { | |||||
| return Status::OK(); | |||||
| } | |||||
| // Only samplers can be added, not any other DatasetOp. | |||||
| std::shared_ptr<SamplerObj> sampler = std::dynamic_pointer_cast<SamplerObj>(child); | |||||
| if (!sampler) { | |||||
| RETURN_STATUS_UNEXPECTED("Cannot add child, child is not a sampler object."); | |||||
| } | |||||
| // Samplers can have at most 1 child. | |||||
| if (!children_.empty()) { | |||||
| RETURN_STATUS_UNEXPECTED("Cannot add child sampler, this sampler already has a child."); | |||||
| } | |||||
| children_.push_back(child); | |||||
| return Status::OK(); | |||||
| } | |||||
| /// Function to create a Distributed Sampler. | /// Function to create a Distributed Sampler. | ||||
| std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle, | std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle, | ||||
| int64_t num_samples, uint32_t seed, int64_t offset, | int64_t num_samples, uint32_t seed, int64_t offset, | ||||
| @@ -55,7 +83,7 @@ std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, in | |||||
| auto sampler = | auto sampler = | ||||
| std::make_shared<DistributedSamplerObj>(num_shards, shard_id, shuffle, num_samples, seed, offset, even_dist); | std::make_shared<DistributedSamplerObj>(num_shards, shard_id, shuffle, num_samples, seed, offset, even_dist); | ||||
| // Input validation | // Input validation | ||||
| if (!sampler->ValidateParams()) { | |||||
| if (sampler->ValidateParams().IsError()) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return sampler; | return sampler; | ||||
| @@ -65,7 +93,7 @@ std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, in | |||||
| std::shared_ptr<PKSamplerObj> PKSampler(int64_t num_val, bool shuffle, int64_t num_samples) { | std::shared_ptr<PKSamplerObj> PKSampler(int64_t num_val, bool shuffle, int64_t num_samples) { | ||||
| auto sampler = std::make_shared<PKSamplerObj>(num_val, shuffle, num_samples); | auto sampler = std::make_shared<PKSamplerObj>(num_val, shuffle, num_samples); | ||||
| // Input validation | // Input validation | ||||
| if (!sampler->ValidateParams()) { | |||||
| if (sampler->ValidateParams().IsError()) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return sampler; | return sampler; | ||||
| @@ -75,7 +103,7 @@ std::shared_ptr<PKSamplerObj> PKSampler(int64_t num_val, bool shuffle, int64_t n | |||||
| std::shared_ptr<RandomSamplerObj> RandomSampler(bool replacement, int64_t num_samples) { | std::shared_ptr<RandomSamplerObj> RandomSampler(bool replacement, int64_t num_samples) { | ||||
| auto sampler = std::make_shared<RandomSamplerObj>(replacement, num_samples); | auto sampler = std::make_shared<RandomSamplerObj>(replacement, num_samples); | ||||
| // Input validation | // Input validation | ||||
| if (!sampler->ValidateParams()) { | |||||
| if (sampler->ValidateParams().IsError()) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return sampler; | return sampler; | ||||
| @@ -85,7 +113,7 @@ std::shared_ptr<RandomSamplerObj> RandomSampler(bool replacement, int64_t num_sa | |||||
| std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index, int64_t num_samples) { | std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index, int64_t num_samples) { | ||||
| auto sampler = std::make_shared<SequentialSamplerObj>(start_index, num_samples); | auto sampler = std::make_shared<SequentialSamplerObj>(start_index, num_samples); | ||||
| // Input validation | // Input validation | ||||
| if (!sampler->ValidateParams()) { | |||||
| if (sampler->ValidateParams().IsError()) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return sampler; | return sampler; | ||||
| @@ -95,7 +123,7 @@ std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index, int | |||||
| std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples) { | std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples) { | ||||
| auto sampler = std::make_shared<SubsetRandomSamplerObj>(std::move(indices), num_samples); | auto sampler = std::make_shared<SubsetRandomSamplerObj>(std::move(indices), num_samples); | ||||
| // Input validation | // Input validation | ||||
| if (!sampler->ValidateParams()) { | |||||
| if (sampler->ValidateParams().IsError()) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return sampler; | return sampler; | ||||
| @@ -106,7 +134,7 @@ std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(std::vector<doub | |||||
| bool replacement) { | bool replacement) { | ||||
| auto sampler = std::make_shared<WeightedRandomSamplerObj>(std::move(weights), num_samples, replacement); | auto sampler = std::make_shared<WeightedRandomSamplerObj>(std::move(weights), num_samples, replacement); | ||||
| // Input validation | // Input validation | ||||
| if (!sampler->ValidateParams()) { | |||||
| if (sampler->ValidateParams().IsError()) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return sampler; | return sampler; | ||||
| @@ -125,35 +153,33 @@ DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_i | |||||
| offset_(offset), | offset_(offset), | ||||
| even_dist_(even_dist) {} | even_dist_(even_dist) {} | ||||
| bool DistributedSamplerObj::ValidateParams() { | |||||
| Status DistributedSamplerObj::ValidateParams() { | |||||
| if (num_shards_ <= 0) { | if (num_shards_ <= 0) { | ||||
| MS_LOG(ERROR) << "DistributedSampler: invalid num_shards: " << num_shards_; | |||||
| return false; | |||||
| RETURN_STATUS_UNEXPECTED("DistributedSampler: invalid num_shards: " + std::to_string(num_shards_)); | |||||
| } | } | ||||
| if (shard_id_ < 0 || shard_id_ >= num_shards_) { | if (shard_id_ < 0 || shard_id_ >= num_shards_) { | ||||
| MS_LOG(ERROR) << "DistributedSampler: invalid input, shard_id: " << shard_id_ << ", num_shards: " << num_shards_; | |||||
| return false; | |||||
| RETURN_STATUS_UNEXPECTED("DistributedSampler: invalid input, shard_id: " + std::to_string(shard_id_) + | |||||
| ", num_shards: " + std::to_string(num_shards_)); | |||||
| } | } | ||||
| if (num_samples_ < 0) { | if (num_samples_ < 0) { | ||||
| MS_LOG(ERROR) << "DistributedSampler: invalid num_samples: " << num_samples_; | |||||
| return false; | |||||
| RETURN_STATUS_UNEXPECTED("DistributedSampler: invalid num_samples: " + std::to_string(num_samples_)); | |||||
| } | } | ||||
| if (offset_ > num_shards_) { | if (offset_ > num_shards_) { | ||||
| MS_LOG(ERROR) << "DistributedSampler: invalid offset: " << offset_ | |||||
| << ", which should be no more than num_shards: " << num_shards_; | |||||
| return false; | |||||
| RETURN_STATUS_UNEXPECTED("DistributedSampler: invalid offset: " + std::to_string(offset_) + | |||||
| ", which should be no more than num_shards: " + std::to_string(num_shards_)); | |||||
| } | } | ||||
| return true; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| std::shared_ptr<SamplerRT> DistributedSamplerObj::Build() { | std::shared_ptr<SamplerRT> DistributedSamplerObj::Build() { | ||||
| // runtime sampler object | // runtime sampler object | ||||
| auto sampler = std::make_shared<dataset::DistributedSamplerRT>(num_samples_, num_shards_, shard_id_, shuffle_, seed_, | auto sampler = std::make_shared<dataset::DistributedSamplerRT>(num_samples_, num_shards_, shard_id_, shuffle_, seed_, | ||||
| offset_, even_dist_); | offset_, even_dist_); | ||||
| BuildChildren(sampler); | |||||
| return sampler; | return sampler; | ||||
| } | } | ||||
| @@ -170,23 +196,21 @@ std::shared_ptr<mindrecord::ShardOperator> DistributedSamplerObj::BuildForMindDa | |||||
| PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples) | PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples) | ||||
| : num_val_(num_val), shuffle_(shuffle), num_samples_(num_samples) {} | : num_val_(num_val), shuffle_(shuffle), num_samples_(num_samples) {} | ||||
| bool PKSamplerObj::ValidateParams() { | |||||
| Status PKSamplerObj::ValidateParams() { | |||||
| if (num_val_ <= 0) { | if (num_val_ <= 0) { | ||||
| MS_LOG(ERROR) << "PKSampler: invalid num_val: " << num_val_; | |||||
| return false; | |||||
| RETURN_STATUS_UNEXPECTED("PKSampler: invalid num_val: " + std::to_string(num_val_)); | |||||
| } | } | ||||
| if (num_samples_ < 0) { | if (num_samples_ < 0) { | ||||
| MS_LOG(ERROR) << "PKSampler: invalid num_samples: " << num_samples_; | |||||
| return false; | |||||
| RETURN_STATUS_UNEXPECTED("PKSampler: invalid num_samples: " + std::to_string(num_samples_)); | |||||
| } | } | ||||
| return true; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| std::shared_ptr<SamplerRT> PKSamplerObj::Build() { | std::shared_ptr<SamplerRT> PKSamplerObj::Build() { | ||||
| // runtime sampler object | // runtime sampler object | ||||
| auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_); | auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_); | ||||
| BuildChildren(sampler); | |||||
| return sampler; | return sampler; | ||||
| } | } | ||||
| @@ -198,9 +222,12 @@ PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator | |||||
| : sp_minddataset_(std::move(sampler)) {} | : sp_minddataset_(std::move(sampler)) {} | ||||
| #endif | #endif | ||||
| bool PreBuiltSamplerObj::ValidateParams() { return true; } | |||||
| Status PreBuiltSamplerObj::ValidateParams() { return Status::OK(); } | |||||
| std::shared_ptr<SamplerRT> PreBuiltSamplerObj::Build() { return sp_; } | |||||
| std::shared_ptr<SamplerRT> PreBuiltSamplerObj::Build() { | |||||
| BuildChildren(sp_); | |||||
| return sp_; | |||||
| } | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; } | std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; } | ||||
| @@ -208,9 +235,19 @@ std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDatas | |||||
| std::shared_ptr<SamplerObj> PreBuiltSamplerObj::Copy() { | std::shared_ptr<SamplerObj> PreBuiltSamplerObj::Copy() { | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| if (sp_minddataset_ != nullptr) return std::make_shared<PreBuiltSamplerObj>(sp_minddataset_); | |||||
| if (sp_minddataset_ != nullptr) { | |||||
| auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_minddataset_); | |||||
| for (auto child : children_) { | |||||
| sampler->AddChild(child); | |||||
| } | |||||
| return sampler; | |||||
| } | |||||
| #endif | #endif | ||||
| return std::make_shared<PreBuiltSamplerObj>(sp_); | |||||
| auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_); | |||||
| for (auto child : children_) { | |||||
| sampler->AddChild(child); | |||||
| } | |||||
| return sampler; | |||||
| } | } | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| @@ -232,19 +269,18 @@ std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() { | |||||
| RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples) | RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples) | ||||
| : replacement_(replacement), num_samples_(num_samples) {} | : replacement_(replacement), num_samples_(num_samples) {} | ||||
| bool RandomSamplerObj::ValidateParams() { | |||||
| Status RandomSamplerObj::ValidateParams() { | |||||
| if (num_samples_ < 0) { | if (num_samples_ < 0) { | ||||
| MS_LOG(ERROR) << "RandomSampler: invalid num_samples: " << num_samples_; | |||||
| return false; | |||||
| RETURN_STATUS_UNEXPECTED("RandomSampler: invalid num_samples: " + std::to_string(num_samples_)); | |||||
| } | } | ||||
| return true; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| std::shared_ptr<SamplerRT> RandomSamplerObj::Build() { | std::shared_ptr<SamplerRT> RandomSamplerObj::Build() { | ||||
| // runtime sampler object | // runtime sampler object | ||||
| bool reshuffle_each_epoch = true; | bool reshuffle_each_epoch = true; | ||||
| auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch); | auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch); | ||||
| BuildChildren(sampler); | |||||
| return sampler; | return sampler; | ||||
| } | } | ||||
| @@ -263,24 +299,22 @@ std::shared_ptr<mindrecord::ShardOperator> RandomSamplerObj::BuildForMindDataset | |||||
| SequentialSamplerObj::SequentialSamplerObj(int64_t start_index, int64_t num_samples) | SequentialSamplerObj::SequentialSamplerObj(int64_t start_index, int64_t num_samples) | ||||
| : start_index_(start_index), num_samples_(num_samples) {} | : start_index_(start_index), num_samples_(num_samples) {} | ||||
| bool SequentialSamplerObj::ValidateParams() { | |||||
| Status SequentialSamplerObj::ValidateParams() { | |||||
| if (num_samples_ < 0) { | if (num_samples_ < 0) { | ||||
| MS_LOG(ERROR) << "SequentialSampler: invalid num_samples: " << num_samples_; | |||||
| return false; | |||||
| RETURN_STATUS_UNEXPECTED("SequentialSampler: invalid num_samples: " + std::to_string(num_samples_)); | |||||
| } | } | ||||
| if (start_index_ < 0) { | if (start_index_ < 0) { | ||||
| MS_LOG(ERROR) << "SequentialSampler: invalid start_index: " << start_index_; | |||||
| return false; | |||||
| RETURN_STATUS_UNEXPECTED("SequentialSampler: invalid start_index: " + std::to_string(start_index_)); | |||||
| } | } | ||||
| return true; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| std::shared_ptr<SamplerRT> SequentialSamplerObj::Build() { | std::shared_ptr<SamplerRT> SequentialSamplerObj::Build() { | ||||
| // runtime sampler object | // runtime sampler object | ||||
| auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_); | auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_); | ||||
| BuildChildren(sampler); | |||||
| return sampler; | return sampler; | ||||
| } | } | ||||
| @@ -297,19 +331,18 @@ std::shared_ptr<mindrecord::ShardOperator> SequentialSamplerObj::BuildForMindDat | |||||
| SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples) | SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples) | ||||
| : indices_(std::move(indices)), num_samples_(num_samples) {} | : indices_(std::move(indices)), num_samples_(num_samples) {} | ||||
| bool SubsetRandomSamplerObj::ValidateParams() { | |||||
| Status SubsetRandomSamplerObj::ValidateParams() { | |||||
| if (num_samples_ < 0) { | if (num_samples_ < 0) { | ||||
| MS_LOG(ERROR) << "SubsetRandomSampler: invalid num_samples: " << num_samples_; | |||||
| return false; | |||||
| RETURN_STATUS_UNEXPECTED("SubsetRandomSampler: invalid num_samples: " + std::to_string(num_samples_)); | |||||
| } | } | ||||
| return true; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| std::shared_ptr<SamplerRT> SubsetRandomSamplerObj::Build() { | std::shared_ptr<SamplerRT> SubsetRandomSamplerObj::Build() { | ||||
| // runtime sampler object | // runtime sampler object | ||||
| auto sampler = std::make_shared<dataset::SubsetRandomSamplerRT>(num_samples_, indices_); | auto sampler = std::make_shared<dataset::SubsetRandomSamplerRT>(num_samples_, indices_); | ||||
| BuildChildren(sampler); | |||||
| return sampler; | return sampler; | ||||
| } | } | ||||
| @@ -326,34 +359,32 @@ std::shared_ptr<mindrecord::ShardOperator> SubsetRandomSamplerObj::BuildForMindD | |||||
| WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement) | WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement) | ||||
| : weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {} | : weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {} | ||||
| bool WeightedRandomSamplerObj::ValidateParams() { | |||||
| Status WeightedRandomSamplerObj::ValidateParams() { | |||||
| if (weights_.empty()) { | if (weights_.empty()) { | ||||
| MS_LOG(ERROR) << "WeightedRandomSampler: weights vector must not be empty"; | |||||
| return false; | |||||
| RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not be empty"); | |||||
| } | } | ||||
| int32_t zero_elem = 0; | int32_t zero_elem = 0; | ||||
| for (int32_t i = 0; i < weights_.size(); ++i) { | for (int32_t i = 0; i < weights_.size(); ++i) { | ||||
| if (weights_[i] < 0) { | if (weights_[i] < 0) { | ||||
| MS_LOG(ERROR) << "WeightedRandomSampler: weights vector must not contain negative number, got: " << weights_[i]; | |||||
| return false; | |||||
| RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not contain negative number, got: " + | |||||
| std::to_string(weights_[i])); | |||||
| } | } | ||||
| if (weights_[i] == 0.0) { | if (weights_[i] == 0.0) { | ||||
| zero_elem++; | zero_elem++; | ||||
| } | } | ||||
| } | } | ||||
| if (zero_elem == weights_.size()) { | if (zero_elem == weights_.size()) { | ||||
| MS_LOG(ERROR) << "WeightedRandomSampler: elements of weights vector must not be all zero"; | |||||
| return false; | |||||
| RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: elements of weights vector must not be all zero"); | |||||
| } | } | ||||
| if (num_samples_ < 0) { | if (num_samples_ < 0) { | ||||
| MS_LOG(ERROR) << "WeightedRandomSampler: invalid num_samples: " << num_samples_; | |||||
| return false; | |||||
| RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: invalid num_samples: " + std::to_string(num_samples_)); | |||||
| } | } | ||||
| return true; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::Build() { | std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::Build() { | ||||
| auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_); | auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_); | ||||
| BuildChildren(sampler); | |||||
| return sampler; | return sampler; | ||||
| } | } | ||||
| @@ -37,6 +37,9 @@ DistributedSamplerRT::DistributedSamplerRT(int64_t num_samples, int64_t num_dev, | |||||
| non_empty_(true) {} | non_empty_(true) {} | ||||
| Status DistributedSamplerRT::InitSampler() { | Status DistributedSamplerRT::InitSampler() { | ||||
| if (is_initialized) { | |||||
| return Status::OK(); | |||||
| } | |||||
| // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | ||||
| // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. | // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. | ||||
| if (num_samples_ == 0 || num_samples_ > num_rows_) { | if (num_samples_ == 0 || num_samples_ > num_rows_) { | ||||
| @@ -72,6 +75,7 @@ Status DistributedSamplerRT::InitSampler() { | |||||
| } | } | ||||
| if (!samples_per_buffer_) non_empty_ = false; | if (!samples_per_buffer_) non_empty_ = false; | ||||
| is_initialized = true; | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -28,6 +28,9 @@ PKSamplerRT::PKSamplerRT(int64_t num_samples, int64_t val, bool shuffle, int64_t | |||||
| samples_per_class_(val) {} | samples_per_class_(val) {} | ||||
| Status PKSamplerRT::InitSampler() { | Status PKSamplerRT::InitSampler() { | ||||
| if (is_initialized) { | |||||
| return Status::OK(); | |||||
| } | |||||
| labels_.reserve(label_to_ids_.size()); | labels_.reserve(label_to_ids_.size()); | ||||
| for (const auto &pair : label_to_ids_) { | for (const auto &pair : label_to_ids_) { | ||||
| if (!pair.second.empty()) { | if (!pair.second.empty()) { | ||||
| @@ -58,6 +61,7 @@ Status PKSamplerRT::InitSampler() { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED( | CHECK_FAIL_RETURN_UNEXPECTED( | ||||
| num_samples_ > 0, "Invalid parameter, num_class or K (num samples per class) must be greater than 0, but got " + | num_samples_ > 0, "Invalid parameter, num_class or K (num samples per class) must be greater than 0, but got " + | ||||
| std::to_string(num_samples_)); | std::to_string(num_samples_)); | ||||
| is_initialized = true; | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -65,6 +65,9 @@ Status PythonSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| } | } | ||||
| Status PythonSamplerRT::InitSampler() { | Status PythonSamplerRT::InitSampler() { | ||||
| if (is_initialized) { | |||||
| return Status::OK(); | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED( | CHECK_FAIL_RETURN_UNEXPECTED( | ||||
| num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_)); | num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_)); | ||||
| // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | ||||
| @@ -83,6 +86,8 @@ Status PythonSamplerRT::InitSampler() { | |||||
| return Status(StatusCode::kPyFuncException, e.what()); | return Status(StatusCode::kPyFuncException, e.what()); | ||||
| } | } | ||||
| } | } | ||||
| is_initialized = true; | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -69,6 +69,9 @@ Status RandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| } | } | ||||
| Status RandomSamplerRT::InitSampler() { | Status RandomSamplerRT::InitSampler() { | ||||
| if (is_initialized) { | |||||
| return Status::OK(); | |||||
| } | |||||
| // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | ||||
| // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. | // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. | ||||
| if (num_samples_ == 0 || num_samples_ > num_rows_) { | if (num_samples_ == 0 || num_samples_ > num_rows_) { | ||||
| @@ -91,6 +94,7 @@ Status RandomSamplerRT::InitSampler() { | |||||
| dist = std::make_unique<std::uniform_int_distribution<int64_t>>(0, num_rows_ - 1); | dist = std::make_unique<std::uniform_int_distribution<int64_t>>(0, num_rows_ - 1); | ||||
| } | } | ||||
| is_initialized = true; | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -34,7 +34,11 @@ Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const { | |||||
| } | } | ||||
| SamplerRT::SamplerRT(int64_t num_samples, int64_t samples_per_buffer) | SamplerRT::SamplerRT(int64_t num_samples, int64_t samples_per_buffer) | ||||
| : num_rows_(0), num_samples_(num_samples), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {} | |||||
| : num_rows_(0), | |||||
| num_samples_(num_samples), | |||||
| samples_per_buffer_(samples_per_buffer), | |||||
| col_desc_(nullptr), | |||||
| is_initialized(false) {} | |||||
| Status SamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) { | Status SamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) { | ||||
| std::shared_ptr<SamplerRT> child_sampler; | std::shared_ptr<SamplerRT> child_sampler; | ||||
| @@ -160,6 +160,7 @@ class SamplerRT { | |||||
| // amount. | // amount. | ||||
| int64_t num_samples_; | int64_t num_samples_; | ||||
| bool is_initialized; | |||||
| int64_t samples_per_buffer_; | int64_t samples_per_buffer_; | ||||
| std::unique_ptr<ColDescriptor> col_desc_; | std::unique_ptr<ColDescriptor> col_desc_; | ||||
| std::vector<std::shared_ptr<SamplerRT>> child_; // Child nodes | std::vector<std::shared_ptr<SamplerRT>> child_; // Child nodes | ||||
| @@ -63,6 +63,9 @@ Status SequentialSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffe | |||||
| } | } | ||||
| Status SequentialSamplerRT::InitSampler() { | Status SequentialSamplerRT::InitSampler() { | ||||
| if (is_initialized) { | |||||
| return Status::OK(); | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, | CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, | ||||
| "Invalid parameter, start_index must be greater than or equal to 0, but got " + | "Invalid parameter, start_index must be greater than or equal to 0, but got " + | ||||
| std::to_string(start_index_) + ".\n"); | std::to_string(start_index_) + ".\n"); | ||||
| @@ -82,6 +85,8 @@ Status SequentialSamplerRT::InitSampler() { | |||||
| num_samples_ > 0 && samples_per_buffer_ > 0, | num_samples_ > 0 && samples_per_buffer_ > 0, | ||||
| "Invalid parameter, samples_per_buffer must be greater than 0, but got " + std::to_string(samples_per_buffer_)); | "Invalid parameter, samples_per_buffer must be greater than 0, but got " + std::to_string(samples_per_buffer_)); | ||||
| samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; | samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; | ||||
| is_initialized = true; | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -32,6 +32,9 @@ SubsetRandomSamplerRT::SubsetRandomSamplerRT(int64_t num_samples, const std::vec | |||||
| // Initialized this Sampler. | // Initialized this Sampler. | ||||
| Status SubsetRandomSamplerRT::InitSampler() { | Status SubsetRandomSamplerRT::InitSampler() { | ||||
| if (is_initialized) { | |||||
| return Status::OK(); | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED( | CHECK_FAIL_RETURN_UNEXPECTED( | ||||
| num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_) + ".\n"); | num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_) + ".\n"); | ||||
| @@ -51,6 +54,7 @@ Status SubsetRandomSamplerRT::InitSampler() { | |||||
| // We will shuffle the full set of id's, but only select the first num_samples_ of them later. | // We will shuffle the full set of id's, but only select the first num_samples_ of them later. | ||||
| std::shuffle(indices_.begin(), indices_.end(), rand_gen_); | std::shuffle(indices_.begin(), indices_.end(), rand_gen_); | ||||
| is_initialized = true; | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -37,6 +37,9 @@ WeightedRandomSamplerRT::WeightedRandomSamplerRT(int64_t num_samples, const std: | |||||
| // Initialized this Sampler. | // Initialized this Sampler. | ||||
| Status WeightedRandomSamplerRT::InitSampler() { | Status WeightedRandomSamplerRT::InitSampler() { | ||||
| if (is_initialized) { | |||||
| return Status::OK(); | |||||
| } | |||||
| // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | ||||
| // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. | // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. | ||||
| if (num_samples_ == 0 || num_samples_ > num_rows_) { | if (num_samples_ == 0 || num_samples_ > num_rows_) { | ||||
| @@ -75,6 +78,7 @@ Status WeightedRandomSamplerRT::InitSampler() { | |||||
| discrete_dist_ = std::make_unique<std::discrete_distribution<int64_t>>(weights_.begin(), weights_.end()); | discrete_dist_ = std::make_unique<std::discrete_distribution<int64_t>>(weights_.begin(), weights_.end()); | ||||
| } | } | ||||
| is_initialized = true; | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -22,7 +22,10 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| #include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| #include "minddata/mindrecord/include/shard_column.h" | |||||
| #include "minddata/mindrecord/include/shard_error.h" | |||||
| #include "minddata/mindrecord/include/shard_reader.h" | |||||
| #endif | #endif | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -40,8 +43,8 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> { | |||||
| ~SamplerObj() = default; | ~SamplerObj() = default; | ||||
| /// \brief Pure virtual function for derived class to implement parameters validation | /// \brief Pure virtual function for derived class to implement parameters validation | ||||
| /// \return bool true if all the parameters are valid | |||||
| virtual bool ValidateParams() = 0; | |||||
| /// \return The Status code of the function. It returns OK status if parameters are valid. | |||||
| virtual Status ValidateParams() = 0; | |||||
| /// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object | /// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object | ||||
| /// \return Shared pointers to the newly created Sampler | /// \return Shared pointers to the newly created Sampler | ||||
| @@ -55,12 +58,24 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> { | |||||
| /// \return The shard id of the derived sampler | /// \return The shard id of the derived sampler | ||||
| virtual int64_t ShardId() { return 0; } | virtual int64_t ShardId() { return 0; } | ||||
| /// \brief Adds a child to the sampler | |||||
| /// \param[in] child The sampler to be added as child | |||||
| /// \return the Status code returned | |||||
| Status AddChild(std::shared_ptr<SamplerObj> child); | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| /// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object, | /// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object, | ||||
| /// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler | /// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler | ||||
| /// \return Shared pointers to the newly created Sampler | /// \return Shared pointers to the newly created Sampler | ||||
| virtual std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() { return nullptr; } | virtual std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() { return nullptr; } | ||||
| #endif | #endif | ||||
| protected: | |||||
| /// \brief A function that calls build on the children of this sampler | |||||
| /// \param[in] sampler The samplerRT object built from this sampler | |||||
| void BuildChildren(std::shared_ptr<SamplerRT> sampler); | |||||
| std::vector<std::shared_ptr<SamplerObj>> children_; | |||||
| }; | }; | ||||
| class DistributedSamplerObj; | class DistributedSamplerObj; | ||||
| @@ -137,15 +152,19 @@ class DistributedSamplerObj : public SamplerObj { | |||||
| std::shared_ptr<SamplerRT> Build() override; | std::shared_ptr<SamplerRT> Build() override; | ||||
| /// \brief Create a new DistributedSamplerObj with the same parameters and attach this sampler's children to it | |||||
| /// \note The children are shared with the copy (same pointers), they are not cloned. | |||||
| /// \return Shared pointer to the newly created sampler | |||||
| std::shared_ptr<SamplerObj> Copy() override { | std::shared_ptr<SamplerObj> Copy() override { | ||||
| return std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_, offset_, | |||||
| even_dist_); | |||||
| auto sampler = std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_, | |||||
| offset_, even_dist_); | |||||
| // Iterate by const reference so each step does not copy the shared_ptr (extra atomic ref-count traffic). | |||||
| // NOTE(review): the Status returned by AddChild is discarded here - confirm that is intended. | |||||
| for (const auto &child : children_) { | |||||
| sampler->AddChild(child); | |||||
| } | |||||
| return sampler; | |||||
| } | } | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | ||||
| #endif | #endif | ||||
| bool ValidateParams() override; | |||||
| Status ValidateParams() override; | |||||
| /// \brief Function to get the shard id of sampler | /// \brief Function to get the shard id of sampler | ||||
| /// \return The shard id of sampler | /// \return The shard id of sampler | ||||
| @@ -170,14 +189,18 @@ class PKSamplerObj : public SamplerObj { | |||||
| std::shared_ptr<SamplerRT> Build() override; | std::shared_ptr<SamplerRT> Build() override; | ||||
| /// \brief Create a new PKSamplerObj with the same parameters and attach this sampler's children to it | |||||
| /// \note The children are shared with the copy (same pointers), they are not cloned. | |||||
| /// \return Shared pointer to the newly created sampler | |||||
| std::shared_ptr<SamplerObj> Copy() override { | std::shared_ptr<SamplerObj> Copy() override { | ||||
| return std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_); | |||||
| auto sampler = std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_); | |||||
| // Iterate by const reference so each step does not copy the shared_ptr (extra atomic ref-count traffic). | |||||
| // NOTE(review): the Status returned by AddChild is discarded here - confirm that is intended. | |||||
| for (const auto &child : children_) { | |||||
| sampler->AddChild(child); | |||||
| } | |||||
| return sampler; | |||||
| } | } | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | ||||
| #endif | #endif | ||||
| bool ValidateParams() override; | |||||
| Status ValidateParams() override; | |||||
| private: | private: | ||||
| int64_t num_val_; | int64_t num_val_; | ||||
| @@ -202,7 +225,7 @@ class PreBuiltSamplerObj : public SamplerObj { | |||||
| std::shared_ptr<SamplerObj> Copy() override; | std::shared_ptr<SamplerObj> Copy() override; | ||||
| bool ValidateParams() override; | |||||
| Status ValidateParams() override; | |||||
| private: | private: | ||||
| std::shared_ptr<SamplerRT> sp_; | std::shared_ptr<SamplerRT> sp_; | ||||
| @@ -219,13 +242,19 @@ class RandomSamplerObj : public SamplerObj { | |||||
| std::shared_ptr<SamplerRT> Build() override; | std::shared_ptr<SamplerRT> Build() override; | ||||
| std::shared_ptr<SamplerObj> Copy() override { return std::make_shared<RandomSamplerObj>(replacement_, num_samples_); } | |||||
| /// \brief Create a new RandomSamplerObj with the same parameters and attach this sampler's children to it | |||||
| /// \note The children are shared with the copy (same pointers), they are not cloned. | |||||
| /// \return Shared pointer to the newly created sampler | |||||
| std::shared_ptr<SamplerObj> Copy() override { | |||||
| auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_); | |||||
| // Iterate by const reference so each step does not copy the shared_ptr (extra atomic ref-count traffic). | |||||
| // NOTE(review): the Status returned by AddChild is discarded here - confirm that is intended. | |||||
| for (const auto &child : children_) { | |||||
| sampler->AddChild(child); | |||||
| } | |||||
| return sampler; | |||||
| } | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | ||||
| #endif | #endif | ||||
| bool ValidateParams() override; | |||||
| Status ValidateParams() override; | |||||
| private: | private: | ||||
| bool replacement_; | bool replacement_; | ||||
| @@ -241,14 +270,18 @@ class SequentialSamplerObj : public SamplerObj { | |||||
| std::shared_ptr<SamplerRT> Build() override; | std::shared_ptr<SamplerRT> Build() override; | ||||
| /// \brief Create a new SequentialSamplerObj with the same parameters and attach this sampler's children to it | |||||
| /// \note The children are shared with the copy (same pointers), they are not cloned. | |||||
| /// \return Shared pointer to the newly created sampler | |||||
| std::shared_ptr<SamplerObj> Copy() override { | std::shared_ptr<SamplerObj> Copy() override { | ||||
| return std::make_shared<SequentialSamplerObj>(start_index_, num_samples_); | |||||
| auto sampler = std::make_shared<SequentialSamplerObj>(start_index_, num_samples_); | |||||
| // Iterate by const reference so each step does not copy the shared_ptr (extra atomic ref-count traffic). | |||||
| // NOTE(review): the Status returned by AddChild is discarded here - confirm that is intended. | |||||
| for (const auto &child : children_) { | |||||
| sampler->AddChild(child); | |||||
| } | |||||
| return sampler; | |||||
| } | } | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | ||||
| #endif | #endif | ||||
| bool ValidateParams() override; | |||||
| Status ValidateParams() override; | |||||
| private: | private: | ||||
| int64_t start_index_; | int64_t start_index_; | ||||
| @@ -264,14 +297,18 @@ class SubsetRandomSamplerObj : public SamplerObj { | |||||
| std::shared_ptr<SamplerRT> Build() override; | std::shared_ptr<SamplerRT> Build() override; | ||||
| /// \brief Create a new SubsetRandomSamplerObj with the same parameters and attach this sampler's children to it | |||||
| /// \note The children are shared with the copy (same pointers), they are not cloned. | |||||
| /// \return Shared pointer to the newly created sampler | |||||
| std::shared_ptr<SamplerObj> Copy() override { | std::shared_ptr<SamplerObj> Copy() override { | ||||
| return std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_); | |||||
| auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_); | |||||
| // Iterate by const reference so each step does not copy the shared_ptr (extra atomic ref-count traffic). | |||||
| // NOTE(review): the Status returned by AddChild is discarded here - confirm that is intended. | |||||
| for (const auto &child : children_) { | |||||
| sampler->AddChild(child); | |||||
| } | |||||
| return sampler; | |||||
| } | } | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | ||||
| #endif | #endif | ||||
| bool ValidateParams() override; | |||||
| Status ValidateParams() override; | |||||
| private: | private: | ||||
| const std::vector<int64_t> indices_; | const std::vector<int64_t> indices_; | ||||
| @@ -287,10 +324,14 @@ class WeightedRandomSamplerObj : public SamplerObj { | |||||
| std::shared_ptr<SamplerRT> Build() override; | std::shared_ptr<SamplerRT> Build() override; | ||||
| /// \brief Create a new WeightedRandomSamplerObj with the same parameters and attach this sampler's children to it | |||||
| /// \note The children are shared with the copy (same pointers), they are not cloned. | |||||
| /// \return Shared pointer to the newly created sampler | |||||
| std::shared_ptr<SamplerObj> Copy() override { | std::shared_ptr<SamplerObj> Copy() override { | ||||
| return std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_); | |||||
| auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_); | |||||
| // Iterate by const reference so each step does not copy the shared_ptr (extra atomic ref-count traffic). | |||||
| // NOTE(review): the Status returned by AddChild is discarded here - confirm that is intended. | |||||
| for (const auto &child : children_) { | |||||
| sampler->AddChild(child); | |||||
| } | |||||
| return sampler; | |||||
| } | } | ||||
| bool ValidateParams() override; | |||||
| Status ValidateParams() override; | |||||
| private: | private: | ||||
| const std::vector<double> weights_; | const std::vector<double> weights_; | ||||
| @@ -208,6 +208,37 @@ TEST_F(MindDataTestPipeline, TestDistributedSamplerSuccess) { | |||||
| iter->Stop(); | iter->Stop(); | ||||
| } | } | ||||
| // Checks that a sampler with a child sampler attached can drive an ImageFolder pipeline. | |||||
| TEST_F(MindDataTestPipeline, TestSamplerAddChild) { | |||||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSamplerAddChild."; | |||||
| // Parent sampler: 1 shard, shard id 0, no shuffle, 5 samples. | |||||
| auto sampler = DistributedSampler(1, 0, false, 5, 0, -1, true); | |||||
| EXPECT_NE(sampler, nullptr); | |||||
| // Check the child before attaching it: AddChild silently accepts a null child, | |||||
| // so a failed SequentialSampler() would otherwise go unnoticed at the call site. | |||||
| auto child_sampler = SequentialSampler(); | |||||
| EXPECT_NE(child_sampler, nullptr); | |||||
| // NOTE(review): the Status returned by AddChild is not checked here - consider asserting on it. | |||||
| sampler->AddChild(child_sampler); | |||||
| // Create an ImageFolder Dataset | |||||
| std::string folder_path = datasets_root_path_ + "/testPK/data/"; | |||||
| std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, sampler); | |||||
| EXPECT_NE(ds, nullptr); | |||||
| // Iterate the dataset and get each row | |||||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||||
| EXPECT_NE(iter, nullptr); | |||||
| std::unordered_map<std::string, std::shared_ptr<Tensor>> row; | |||||
| iter->GetNextRow(&row); | |||||
| uint64_t i = 0; | |||||
| while (row.size() != 0) { | |||||
| i++; | |||||
| iter->GetNextRow(&row); | |||||
| } | |||||
| // The row counter was previously computed but never verified; the parent sampler asks for 5 samples. | |||||
| EXPECT_EQ(i, 5); | |||||
| EXPECT_EQ(ds->GetDatasetSize(), 5); | |||||
| iter->Stop(); | |||||
| } | |||||
| TEST_F(MindDataTestPipeline, TestDistributedSamplerFail) { | TEST_F(MindDataTestPipeline, TestDistributedSamplerFail) { | ||||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDistributedSamplerFail."; | MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDistributedSamplerFail."; | ||||
| // Test invalid offset setting of distributed_sampler | // Test invalid offset setting of distributed_sampler | ||||