From: @mhmotallebi Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -97,6 +97,7 @@ add_dependencies(text-ir-kernels core) | |||
| add_dependencies(cpp-API core) | |||
| add_dependencies(engine-ir-datasetops core) | |||
| add_dependencies(engine-ir-datasetops-source core) | |||
| add_dependencies(engine-ir-datasetops-source-samplers core) | |||
| add_dependencies(engine-ir-cache core) | |||
| add_dependencies(kernels-ir core) | |||
| add_dependencies(kernels-ir-data core) | |||
| @@ -135,6 +136,7 @@ set(submodules | |||
| $<TARGET_OBJECTS:cpp-API> | |||
| $<TARGET_OBJECTS:engine-ir-datasetops> | |||
| $<TARGET_OBJECTS:engine-ir-datasetops-source> | |||
| $<TARGET_OBJECTS:engine-ir-datasetops-source-samplers> | |||
| $<TARGET_OBJECTS:engine-ir-cache> | |||
| $<TARGET_OBJECTS:kernels-soft-dvpp-image> | |||
| $<TARGET_OBJECTS:soft-dvpp-utils> | |||
| @@ -42,7 +42,7 @@ | |||
| #endif | |||
| // Sampler headers (in alphabetical order) | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||
| @@ -23,7 +23,7 @@ | |||
| #include "minddata/dataset/callback/py_ds_callback.h" | |||
| #include "minddata/dataset/core/constants.h" | |||
| #include "minddata/dataset/core/global_context.h" | |||
| #include "minddata/dataset/include/datasets.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -15,69 +15,11 @@ | |||
| */ | |||
| #include "minddata/dataset/include/samplers.h" | |||
| #include "minddata/dataset/core/config_manager.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/subset_sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/mindrecord/include/shard_distributed_sample.h" | |||
| #include "minddata/mindrecord/include/shard_operator.h" | |||
| #include "minddata/mindrecord/include/shard_pk_sample.h" | |||
| #include "minddata/mindrecord/include/shard_sample.h" | |||
| #include "minddata/mindrecord/include/shard_sequential_sample.h" | |||
| #include "minddata/mindrecord/include/shard_shuffle.h" | |||
| #include "minddata/dataset/util/random.h" | |||
| #endif | |||
| #include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| #define RETURN_NULL_IF_ERROR(_s) \ | |||
| do { \ | |||
| Status __rc = (_s); \ | |||
| if (__rc.IsError()) { \ | |||
| MS_LOG(ERROR) << __rc; \ | |||
| return nullptr; \ | |||
| } \ | |||
| } while (false) | |||
| // Constructor | |||
| SamplerObj::SamplerObj() {} | |||
| void SamplerObj::BuildChildren(std::shared_ptr<SamplerRT> sampler) { | |||
| for (auto child : children_) { | |||
| auto sampler_rt = child->SamplerBuild(); | |||
| sampler->AddChild(sampler_rt); | |||
| } | |||
| } | |||
| Status SamplerObj::AddChildSampler(std::shared_ptr<SamplerObj> child) { | |||
| if (child == nullptr) { | |||
| return Status::OK(); | |||
| } | |||
| // Only samplers can be added, not any other DatasetOp. | |||
| std::shared_ptr<SamplerObj> sampler = std::dynamic_pointer_cast<SamplerObj>(child); | |||
| if (!sampler) { | |||
| RETURN_STATUS_UNEXPECTED("Cannot add child, child is not a sampler object."); | |||
| } | |||
| // Samplers can have at most 1 child. | |||
| if (!children_.empty()) { | |||
| RETURN_STATUS_UNEXPECTED("Cannot add child sampler, this sampler already has a child."); | |||
| } | |||
| children_.push_back(child); | |||
| return Status::OK(); | |||
| } | |||
| /// Function to create a Distributed Sampler. | |||
| std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle, | |||
| int64_t num_samples, uint32_t seed, int64_t offset, | |||
| @@ -152,421 +94,5 @@ std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(std::vector<doub | |||
| return sampler; | |||
| } | |||
| /* ####################################### Derived Sampler classes ################################# */ | |||
| // DistributedSampler | |||
| DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, | |||
| uint32_t seed, int64_t offset, bool even_dist) | |||
| : num_shards_(num_shards), | |||
| shard_id_(shard_id), | |||
| shuffle_(shuffle), | |||
| num_samples_(num_samples), | |||
| seed_(seed), | |||
| offset_(offset), | |||
| even_dist_(even_dist) { | |||
| // Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion | |||
| // is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't | |||
| // 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once | |||
| // PreBuildSampler is phased out, this can be cleaned up. | |||
| GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_); | |||
| } | |||
| Status DistributedSamplerObj::ValidateParams() { | |||
| if (num_shards_ <= 0) { | |||
| RETURN_STATUS_UNEXPECTED("DistributedSampler: num_shards must be greater than 0, but got: " + | |||
| std::to_string(num_shards_)); | |||
| } | |||
| if (shard_id_ < 0 || shard_id_ >= num_shards_) { | |||
| RETURN_STATUS_UNEXPECTED("DistributedSampler: shard_id must be in range [0, " + std::to_string(num_shards_) + | |||
| "), but got: " + std::to_string(shard_id_)); | |||
| } | |||
| if (num_samples_ < 0) { | |||
| RETURN_STATUS_UNEXPECTED("DistributedSampler: num_samples must be greater than or equal to 0, but got: " + | |||
| std::to_string(num_samples_)); | |||
| } | |||
| if (offset_ > num_shards_) { | |||
| RETURN_STATUS_UNEXPECTED("DistributedSampler: offset must be no more than num_shards(" + | |||
| std::to_string(num_shards_) + "), but got: " + std::to_string(offset_)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerRT> DistributedSamplerObj::SamplerBuild() { | |||
| // runtime sampler object | |||
| auto sampler = std::make_shared<dataset::DistributedSamplerRT>(num_samples_, num_shards_, shard_id_, shuffle_, seed_, | |||
| offset_, even_dist_); | |||
| BuildChildren(sampler); | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> DistributedSamplerObj::BuildForMindDataset() { | |||
| // runtime mindrecord sampler object | |||
| auto mind_sampler = std::make_shared<mindrecord::ShardDistributedSample>(num_shards_, shard_id_, shuffle_, seed_, | |||
| num_samples_, offset_); | |||
| return mind_sampler; | |||
| } | |||
| #endif | |||
| Status DistributedSamplerObj::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["sampler_name"] = "DistributedSampler"; | |||
| args["num_shards"] = num_shards_; | |||
| args["shard_id"] = shard_id_; | |||
| args["shuffle"] = shuffle_; | |||
| args["num_samples"] = num_samples_; | |||
| args["offset"] = offset_; | |||
| if (!children_.empty()) { | |||
| std::vector<nlohmann::json> children_args; | |||
| for (auto child : children_) { | |||
| nlohmann::json child_arg; | |||
| RETURN_IF_NOT_OK(child->to_json(&child_arg)); | |||
| children_args.push_back(child_arg); | |||
| } | |||
| args["child_sampler"] = children_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| // PKSampler | |||
| PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples) | |||
| : num_val_(num_val), shuffle_(shuffle), num_samples_(num_samples) {} | |||
| Status PKSamplerObj::ValidateParams() { | |||
| if (num_val_ <= 0) { | |||
| RETURN_STATUS_UNEXPECTED("PKSampler: num_val must be greater than 0, but got: " + std::to_string(num_val_)); | |||
| } | |||
| if (num_samples_ < 0) { | |||
| RETURN_STATUS_UNEXPECTED("PKSampler: num_samples must be greater than or equal to 0, but got: " + | |||
| std::to_string(num_samples_)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status PKSamplerObj::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["sampler_name"] = "PKSampler"; | |||
| args["num_val"] = num_val_; | |||
| args["shuffle"] = shuffle_; | |||
| args["num_samples"] = num_samples_; | |||
| if (!children_.empty()) { | |||
| std::vector<nlohmann::json> children_args; | |||
| for (auto child : children_) { | |||
| nlohmann::json child_arg; | |||
| RETURN_IF_NOT_OK(child->to_json(&child_arg)); | |||
| children_args.push_back(child_arg); | |||
| } | |||
| args["child_sampler"] = children_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerRT> PKSamplerObj::SamplerBuild() { | |||
| // runtime sampler object | |||
| auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_); | |||
| BuildChildren(sampler); | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() { | |||
| // runtime mindrecord sampler object | |||
| std::shared_ptr<mindrecord::ShardOperator> mind_sampler; | |||
| if (shuffle_ == true) { | |||
| mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, std::numeric_limits<int64_t>::max(), | |||
| GetSeed(), num_samples_); | |||
| } else { | |||
| mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, num_samples_); | |||
| } | |||
| return mind_sampler; | |||
| } | |||
| #endif | |||
| // PreBuiltOperation | |||
| PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler) : sp_(std::move(sampler)) {} | |||
| #ifndef ENABLE_ANDROID | |||
| PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler) | |||
| : sp_minddataset_(std::move(sampler)) {} | |||
| #endif | |||
| Status PreBuiltSamplerObj::ValidateParams() { return Status::OK(); } | |||
| std::shared_ptr<SamplerRT> PreBuiltSamplerObj::SamplerBuild() { | |||
| BuildChildren(sp_); | |||
| return sp_; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; } | |||
| #endif | |||
| std::shared_ptr<SamplerObj> PreBuiltSamplerObj::SamplerCopy() { | |||
| #ifndef ENABLE_ANDROID | |||
| if (sp_minddataset_ != nullptr) { | |||
| auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_minddataset_); | |||
| for (auto child : children_) { | |||
| sampler->AddChildSampler(child); | |||
| } | |||
| return sampler; | |||
| } | |||
| #endif | |||
| auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_); | |||
| for (auto child : children_) { | |||
| sampler->AddChildSampler(child); | |||
| } | |||
| return sampler; | |||
| } | |||
| Status PreBuiltSamplerObj::to_json(nlohmann::json *out_json) { | |||
| RETURN_IF_NOT_OK(sp_->to_json(out_json)); | |||
| return Status::OK(); | |||
| } | |||
| // RandomSampler | |||
| RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch) | |||
| : replacement_(replacement), num_samples_(num_samples), reshuffle_each_epoch_(reshuffle_each_epoch) {} | |||
| Status RandomSamplerObj::ValidateParams() { | |||
| if (num_samples_ < 0) { | |||
| RETURN_STATUS_UNEXPECTED("RandomSampler: num_samples must be greater than or equal to 0, but got: " + | |||
| std::to_string(num_samples_)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status RandomSamplerObj::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["sampler_name"] = "RandomSampler"; | |||
| args["replacement"] = replacement_; | |||
| args["num_samples"] = num_samples_; | |||
| args["reshuffle_each_epoch"] = reshuffle_each_epoch_; | |||
| if (!children_.empty()) { | |||
| std::vector<nlohmann::json> children_args; | |||
| for (auto child : children_) { | |||
| nlohmann::json child_arg; | |||
| RETURN_IF_NOT_OK(child->to_json(&child_arg)); | |||
| children_args.push_back(child_arg); | |||
| } | |||
| args["child_sampler"] = children_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerRT> RandomSamplerObj::SamplerBuild() { | |||
| // runtime sampler object | |||
| auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch_); | |||
| BuildChildren(sampler); | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> RandomSamplerObj::BuildForMindDataset() { | |||
| // runtime mindrecord sampler object | |||
| auto mind_sampler = | |||
| std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples_, replacement_, reshuffle_each_epoch_); | |||
| return mind_sampler; | |||
| } | |||
| #endif | |||
| // SequentialSampler | |||
| SequentialSamplerObj::SequentialSamplerObj(int64_t start_index, int64_t num_samples) | |||
| : start_index_(start_index), num_samples_(num_samples) {} | |||
| Status SequentialSamplerObj::ValidateParams() { | |||
| if (num_samples_ < 0) { | |||
| RETURN_STATUS_UNEXPECTED("SequentialSampler: num_samples must be greater than or equal to 0, but got: " + | |||
| std::to_string(num_samples_)); | |||
| } | |||
| if (start_index_ < 0) { | |||
| RETURN_STATUS_UNEXPECTED("SequentialSampler: start_index_ must be greater than or equal to 0, but got: " + | |||
| std::to_string(start_index_)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status SequentialSamplerObj::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["sampler_name"] = "SequentialSampler"; | |||
| args["start_index"] = start_index_; | |||
| args["num_samples"] = num_samples_; | |||
| if (!children_.empty()) { | |||
| std::vector<nlohmann::json> children_args; | |||
| for (auto child : children_) { | |||
| nlohmann::json child_arg; | |||
| RETURN_IF_NOT_OK(child->to_json(&child_arg)); | |||
| children_args.push_back(child_arg); | |||
| } | |||
| args["child_sampler"] = children_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerRT> SequentialSamplerObj::SamplerBuild() { | |||
| // runtime sampler object | |||
| auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_); | |||
| BuildChildren(sampler); | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> SequentialSamplerObj::BuildForMindDataset() { | |||
| // runtime mindrecord sampler object | |||
| auto mind_sampler = std::make_shared<mindrecord::ShardSequentialSample>(num_samples_, start_index_); | |||
| return mind_sampler; | |||
| } | |||
| #endif | |||
| // SubsetSampler | |||
| SubsetSamplerObj::SubsetSamplerObj(std::vector<int64_t> indices, int64_t num_samples) | |||
| : indices_(std::move(indices)), num_samples_(num_samples) {} | |||
| Status SubsetSamplerObj::ValidateParams() { | |||
| if (num_samples_ < 0) { | |||
| RETURN_STATUS_UNEXPECTED("SubsetRandomSampler: num_samples must be greater than or equal to 0, but got: " + | |||
| std::to_string(num_samples_)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerRT> SubsetSamplerObj::SamplerBuild() { | |||
| // runtime sampler object | |||
| auto sampler = std::make_shared<dataset::SubsetSamplerRT>(num_samples_, indices_); | |||
| BuildChildren(sampler); | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> SubsetSamplerObj::BuildForMindDataset() { | |||
| // runtime mindrecord sampler object | |||
| auto mind_sampler = std::make_shared<mindrecord::ShardSample>(indices_); | |||
| return mind_sampler; | |||
| } | |||
| #endif | |||
| Status SubsetSamplerObj::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["sampler_name"] = "SubsetSampler"; | |||
| args["indices"] = indices_; | |||
| args["num_samples"] = num_samples_; | |||
| if (!children_.empty()) { | |||
| std::vector<nlohmann::json> children_args; | |||
| for (auto child : children_) { | |||
| nlohmann::json child_arg; | |||
| RETURN_IF_NOT_OK(child->to_json(&child_arg)); | |||
| children_args.push_back(child_arg); | |||
| } | |||
| args["child_sampler"] = children_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| // SubsetRandomSampler | |||
| SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples) | |||
| : SubsetSamplerObj(std::move(indices), num_samples) {} | |||
| std::shared_ptr<SamplerRT> SubsetRandomSamplerObj::SamplerBuild() { | |||
| // runtime sampler object | |||
| auto sampler = std::make_shared<dataset::SubsetRandomSamplerRT>(num_samples_, indices_); | |||
| BuildChildren(sampler); | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> SubsetRandomSamplerObj::BuildForMindDataset() { | |||
| // runtime mindrecord sampler object | |||
| auto mind_sampler = std::make_shared<mindrecord::ShardSample>(indices_, GetSeed()); | |||
| return mind_sampler; | |||
| } | |||
| #endif | |||
| Status SubsetRandomSamplerObj::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["sampler_name"] = "SubsetRandomSampler"; | |||
| args["indices"] = indices_; | |||
| args["num_samples"] = num_samples_; | |||
| if (!children_.empty()) { | |||
| std::vector<nlohmann::json> children_args; | |||
| for (auto child : children_) { | |||
| nlohmann::json child_arg; | |||
| RETURN_IF_NOT_OK(child->to_json(&child_arg)); | |||
| children_args.push_back(child_arg); | |||
| } | |||
| args["child_sampler"] = children_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| // WeightedRandomSampler | |||
| WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement) | |||
| : weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {} | |||
| Status WeightedRandomSamplerObj::ValidateParams() { | |||
| if (weights_.empty()) { | |||
| RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not be empty"); | |||
| } | |||
| int32_t zero_elem = 0; | |||
| for (int32_t i = 0; i < weights_.size(); ++i) { | |||
| if (weights_[i] < 0) { | |||
| RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not contain negative number, got: " + | |||
| std::to_string(weights_[i])); | |||
| } | |||
| if (weights_[i] == 0.0) { | |||
| zero_elem++; | |||
| } | |||
| } | |||
| if (zero_elem == weights_.size()) { | |||
| RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: elements of weights vector must not be all zero"); | |||
| } | |||
| if (num_samples_ < 0) { | |||
| RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: num_samples must be greater than or equal to 0, but got: " + | |||
| std::to_string(num_samples_)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status WeightedRandomSamplerObj::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["sampler_name"] = "WeightedRandomSampler"; | |||
| args["weights"] = weights_; | |||
| args["num_samples"] = num_samples_; | |||
| args["replacement"] = replacement_; | |||
| if (!children_.empty()) { | |||
| std::vector<nlohmann::json> children_args; | |||
| for (auto child : children_) { | |||
| nlohmann::json child_arg; | |||
| RETURN_IF_NOT_OK(child->to_json(&child_arg)); | |||
| children_args.push_back(child_arg); | |||
| } | |||
| args["child_sampler"] = children_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::SamplerBuild() { | |||
| auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_); | |||
| BuildChildren(sampler); | |||
| return sampler; | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -1,5 +1,6 @@ | |||
| file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | |||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | |||
| add_subdirectory(samplers) | |||
| set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES | |||
| album_node.cc | |||
| @@ -0,0 +1,8 @@ | |||
| file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | |||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | |||
| set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SRC_FILES | |||
| samplers_ir.cc | |||
| ) | |||
| add_library(engine-ir-datasetops-source-samplers OBJECT ${DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SRC_FILES}) | |||
| @@ -0,0 +1,490 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h" | |||
| #include "minddata/dataset/core/config_manager.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/subset_sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/util/random.h" | |||
| #include "minddata/mindrecord/include/shard_distributed_sample.h" | |||
| #include "minddata/mindrecord/include/shard_operator.h" | |||
| #include "minddata/mindrecord/include/shard_pk_sample.h" | |||
| #include "minddata/mindrecord/include/shard_sample.h" | |||
| #include "minddata/mindrecord/include/shard_sequential_sample.h" | |||
| #include "minddata/mindrecord/include/shard_shuffle.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor | |||
| SamplerObj::SamplerObj() {} | |||
| void SamplerObj::BuildChildren(std::shared_ptr<SamplerRT> sampler) { | |||
| for (auto child : children_) { | |||
| auto sampler_rt = child->SamplerBuild(); | |||
| sampler->AddChild(sampler_rt); | |||
| } | |||
| } | |||
| Status SamplerObj::AddChildSampler(std::shared_ptr<SamplerObj> child) { | |||
| if (child == nullptr) { | |||
| return Status::OK(); | |||
| } | |||
| // Only samplers can be added, not any other DatasetOp. | |||
| std::shared_ptr<SamplerObj> sampler = std::dynamic_pointer_cast<SamplerObj>(child); | |||
| if (!sampler) { | |||
| RETURN_STATUS_UNEXPECTED("Cannot add child, child is not a sampler object."); | |||
| } | |||
| // Samplers can have at most 1 child. | |||
| if (!children_.empty()) { | |||
| RETURN_STATUS_UNEXPECTED("Cannot add child sampler, this sampler already has a child."); | |||
| } | |||
| children_.push_back(child); | |||
| return Status::OK(); | |||
| } | |||
| /* ####################################### Derived Sampler classes ################################# */ | |||
| // DistributedSampler | |||
| DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, | |||
| uint32_t seed, int64_t offset, bool even_dist) | |||
| : num_shards_(num_shards), | |||
| shard_id_(shard_id), | |||
| shuffle_(shuffle), | |||
| num_samples_(num_samples), | |||
| seed_(seed), | |||
| offset_(offset), | |||
| even_dist_(even_dist) { | |||
| // Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion | |||
| // is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't | |||
| // 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once | |||
| // PreBuildSampler is phased out, this can be cleaned up. | |||
| GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_); | |||
| } | |||
| Status DistributedSamplerObj::ValidateParams() { | |||
| if (num_shards_ <= 0) { | |||
| RETURN_STATUS_UNEXPECTED("DistributedSampler: num_shards must be greater than 0, but got: " + | |||
| std::to_string(num_shards_)); | |||
| } | |||
| if (shard_id_ < 0 || shard_id_ >= num_shards_) { | |||
| RETURN_STATUS_UNEXPECTED("DistributedSampler: shard_id must be in range [0, " + std::to_string(num_shards_) + | |||
| "), but got: " + std::to_string(shard_id_)); | |||
| } | |||
| if (num_samples_ < 0) { | |||
| RETURN_STATUS_UNEXPECTED("DistributedSampler: num_samples must be greater than or equal to 0, but got: " + | |||
| std::to_string(num_samples_)); | |||
| } | |||
| if (offset_ > num_shards_) { | |||
| RETURN_STATUS_UNEXPECTED("DistributedSampler: offset must be no more than num_shards(" + | |||
| std::to_string(num_shards_) + "), but got: " + std::to_string(offset_)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerRT> DistributedSamplerObj::SamplerBuild() { | |||
| // runtime sampler object | |||
| auto sampler = std::make_shared<dataset::DistributedSamplerRT>(num_samples_, num_shards_, shard_id_, shuffle_, seed_, | |||
| offset_, even_dist_); | |||
| BuildChildren(sampler); | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> DistributedSamplerObj::BuildForMindDataset() { | |||
| // runtime mindrecord sampler object | |||
| auto mind_sampler = std::make_shared<mindrecord::ShardDistributedSample>(num_shards_, shard_id_, shuffle_, seed_, | |||
| num_samples_, offset_); | |||
| return mind_sampler; | |||
| } | |||
| #endif | |||
| Status DistributedSamplerObj::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["sampler_name"] = "DistributedSampler"; | |||
| args["num_shards"] = num_shards_; | |||
| args["shard_id"] = shard_id_; | |||
| args["shuffle"] = shuffle_; | |||
| args["num_samples"] = num_samples_; | |||
| args["offset"] = offset_; | |||
| if (!children_.empty()) { | |||
| std::vector<nlohmann::json> children_args; | |||
| for (auto child : children_) { | |||
| nlohmann::json child_arg; | |||
| RETURN_IF_NOT_OK(child->to_json(&child_arg)); | |||
| children_args.push_back(child_arg); | |||
| } | |||
| args["child_sampler"] = children_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| // PKSampler | |||
| PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples) | |||
| : num_val_(num_val), shuffle_(shuffle), num_samples_(num_samples) {} | |||
| Status PKSamplerObj::ValidateParams() { | |||
| if (num_val_ <= 0) { | |||
| RETURN_STATUS_UNEXPECTED("PKSampler: num_val must be greater than 0, but got: " + std::to_string(num_val_)); | |||
| } | |||
| if (num_samples_ < 0) { | |||
| RETURN_STATUS_UNEXPECTED("PKSampler: num_samples must be greater than or equal to 0, but got: " + | |||
| std::to_string(num_samples_)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status PKSamplerObj::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["sampler_name"] = "PKSampler"; | |||
| args["num_val"] = num_val_; | |||
| args["shuffle"] = shuffle_; | |||
| args["num_samples"] = num_samples_; | |||
| if (!children_.empty()) { | |||
| std::vector<nlohmann::json> children_args; | |||
| for (auto child : children_) { | |||
| nlohmann::json child_arg; | |||
| RETURN_IF_NOT_OK(child->to_json(&child_arg)); | |||
| children_args.push_back(child_arg); | |||
| } | |||
| args["child_sampler"] = children_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerRT> PKSamplerObj::SamplerBuild() { | |||
| // runtime sampler object | |||
| auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_); | |||
| BuildChildren(sampler); | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() { | |||
| // runtime mindrecord sampler object | |||
| std::shared_ptr<mindrecord::ShardOperator> mind_sampler; | |||
| if (shuffle_ == true) { | |||
| mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, std::numeric_limits<int64_t>::max(), | |||
| GetSeed(), num_samples_); | |||
| } else { | |||
| mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, num_samples_); | |||
| } | |||
| return mind_sampler; | |||
| } | |||
| #endif | |||
| // PreBuiltOperation | |||
| PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler) : sp_(std::move(sampler)) {} | |||
| #ifndef ENABLE_ANDROID | |||
| PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler) | |||
| : sp_minddataset_(std::move(sampler)) {} | |||
| #endif | |||
| Status PreBuiltSamplerObj::ValidateParams() { return Status::OK(); } | |||
| std::shared_ptr<SamplerRT> PreBuiltSamplerObj::SamplerBuild() { | |||
| BuildChildren(sp_); | |||
| return sp_; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; } | |||
| #endif | |||
| std::shared_ptr<SamplerObj> PreBuiltSamplerObj::SamplerCopy() { | |||
| #ifndef ENABLE_ANDROID | |||
| if (sp_minddataset_ != nullptr) { | |||
| auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_minddataset_); | |||
| for (auto child : children_) { | |||
| sampler->AddChildSampler(child); | |||
| } | |||
| return sampler; | |||
| } | |||
| #endif | |||
| auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_); | |||
| for (auto child : children_) { | |||
| sampler->AddChildSampler(child); | |||
| } | |||
| return sampler; | |||
| } | |||
| Status PreBuiltSamplerObj::to_json(nlohmann::json *out_json) { | |||
| RETURN_IF_NOT_OK(sp_->to_json(out_json)); | |||
| return Status::OK(); | |||
| } | |||
| // RandomSampler | |||
| RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch) | |||
| : replacement_(replacement), num_samples_(num_samples), reshuffle_each_epoch_(reshuffle_each_epoch) {} | |||
| Status RandomSamplerObj::ValidateParams() { | |||
| if (num_samples_ < 0) { | |||
| RETURN_STATUS_UNEXPECTED("RandomSampler: num_samples must be greater than or equal to 0, but got: " + | |||
| std::to_string(num_samples_)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status RandomSamplerObj::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["sampler_name"] = "RandomSampler"; | |||
| args["replacement"] = replacement_; | |||
| args["num_samples"] = num_samples_; | |||
| args["reshuffle_each_epoch"] = reshuffle_each_epoch_; | |||
| if (!children_.empty()) { | |||
| std::vector<nlohmann::json> children_args; | |||
| for (auto child : children_) { | |||
| nlohmann::json child_arg; | |||
| RETURN_IF_NOT_OK(child->to_json(&child_arg)); | |||
| children_args.push_back(child_arg); | |||
| } | |||
| args["child_sampler"] = children_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerRT> RandomSamplerObj::SamplerBuild() { | |||
| // runtime sampler object | |||
| auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch_); | |||
| BuildChildren(sampler); | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> RandomSamplerObj::BuildForMindDataset() { | |||
| // runtime mindrecord sampler object | |||
| auto mind_sampler = | |||
| std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples_, replacement_, reshuffle_each_epoch_); | |||
| return mind_sampler; | |||
| } | |||
| #endif | |||
| // SequentialSampler | |||
| SequentialSamplerObj::SequentialSamplerObj(int64_t start_index, int64_t num_samples) | |||
| : start_index_(start_index), num_samples_(num_samples) {} | |||
| Status SequentialSamplerObj::ValidateParams() { | |||
| if (num_samples_ < 0) { | |||
| RETURN_STATUS_UNEXPECTED("SequentialSampler: num_samples must be greater than or equal to 0, but got: " + | |||
| std::to_string(num_samples_)); | |||
| } | |||
| if (start_index_ < 0) { | |||
| RETURN_STATUS_UNEXPECTED("SequentialSampler: start_index_ must be greater than or equal to 0, but got: " + | |||
| std::to_string(start_index_)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status SequentialSamplerObj::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["sampler_name"] = "SequentialSampler"; | |||
| args["start_index"] = start_index_; | |||
| args["num_samples"] = num_samples_; | |||
| if (!children_.empty()) { | |||
| std::vector<nlohmann::json> children_args; | |||
| for (auto child : children_) { | |||
| nlohmann::json child_arg; | |||
| RETURN_IF_NOT_OK(child->to_json(&child_arg)); | |||
| children_args.push_back(child_arg); | |||
| } | |||
| args["child_sampler"] = children_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerRT> SequentialSamplerObj::SamplerBuild() { | |||
| // runtime sampler object | |||
| auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_); | |||
| BuildChildren(sampler); | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> SequentialSamplerObj::BuildForMindDataset() { | |||
| // runtime mindrecord sampler object | |||
| auto mind_sampler = std::make_shared<mindrecord::ShardSequentialSample>(num_samples_, start_index_); | |||
| return mind_sampler; | |||
| } | |||
| #endif | |||
| // SubsetSampler | |||
| SubsetSamplerObj::SubsetSamplerObj(std::vector<int64_t> indices, int64_t num_samples) | |||
| : indices_(std::move(indices)), num_samples_(num_samples) {} | |||
| Status SubsetSamplerObj::ValidateParams() { | |||
| if (num_samples_ < 0) { | |||
| RETURN_STATUS_UNEXPECTED("SubsetRandomSampler: num_samples must be greater than or equal to 0, but got: " + | |||
| std::to_string(num_samples_)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerRT> SubsetSamplerObj::SamplerBuild() { | |||
| // runtime sampler object | |||
| auto sampler = std::make_shared<dataset::SubsetSamplerRT>(num_samples_, indices_); | |||
| BuildChildren(sampler); | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> SubsetSamplerObj::BuildForMindDataset() { | |||
| // runtime mindrecord sampler object | |||
| auto mind_sampler = std::make_shared<mindrecord::ShardSample>(indices_); | |||
| return mind_sampler; | |||
| } | |||
| #endif | |||
| Status SubsetSamplerObj::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["sampler_name"] = "SubsetSampler"; | |||
| args["indices"] = indices_; | |||
| args["num_samples"] = num_samples_; | |||
| if (!children_.empty()) { | |||
| std::vector<nlohmann::json> children_args; | |||
| for (auto child : children_) { | |||
| nlohmann::json child_arg; | |||
| RETURN_IF_NOT_OK(child->to_json(&child_arg)); | |||
| children_args.push_back(child_arg); | |||
| } | |||
| args["child_sampler"] = children_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| // SubsetRandomSampler | |||
| SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples) | |||
| : SubsetSamplerObj(std::move(indices), num_samples) {} | |||
| std::shared_ptr<SamplerRT> SubsetRandomSamplerObj::SamplerBuild() { | |||
| // runtime sampler object | |||
| auto sampler = std::make_shared<dataset::SubsetRandomSamplerRT>(num_samples_, indices_); | |||
| BuildChildren(sampler); | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> SubsetRandomSamplerObj::BuildForMindDataset() { | |||
| // runtime mindrecord sampler object | |||
| auto mind_sampler = std::make_shared<mindrecord::ShardSample>(indices_, GetSeed()); | |||
| return mind_sampler; | |||
| } | |||
| #endif | |||
| Status SubsetRandomSamplerObj::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["sampler_name"] = "SubsetRandomSampler"; | |||
| args["indices"] = indices_; | |||
| args["num_samples"] = num_samples_; | |||
| if (!children_.empty()) { | |||
| std::vector<nlohmann::json> children_args; | |||
| for (auto child : children_) { | |||
| nlohmann::json child_arg; | |||
| RETURN_IF_NOT_OK(child->to_json(&child_arg)); | |||
| children_args.push_back(child_arg); | |||
| } | |||
| args["child_sampler"] = children_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| // WeightedRandomSampler | |||
| WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement) | |||
| : weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {} | |||
| Status WeightedRandomSamplerObj::ValidateParams() { | |||
| if (weights_.empty()) { | |||
| RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not be empty"); | |||
| } | |||
| int32_t zero_elem = 0; | |||
| for (int32_t i = 0; i < weights_.size(); ++i) { | |||
| if (weights_[i] < 0) { | |||
| RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not contain negative number, got: " + | |||
| std::to_string(weights_[i])); | |||
| } | |||
| if (weights_[i] == 0.0) { | |||
| zero_elem++; | |||
| } | |||
| } | |||
| if (zero_elem == weights_.size()) { | |||
| RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: elements of weights vector must not be all zero"); | |||
| } | |||
| if (num_samples_ < 0) { | |||
| RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: num_samples must be greater than or equal to 0, but got: " + | |||
| std::to_string(num_samples_)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status WeightedRandomSamplerObj::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["sampler_name"] = "WeightedRandomSampler"; | |||
| args["weights"] = weights_; | |||
| args["num_samples"] = num_samples_; | |||
| args["replacement"] = replacement_; | |||
| if (!children_.empty()) { | |||
| std::vector<nlohmann::json> children_args; | |||
| for (auto child : children_) { | |||
| nlohmann::json child_arg; | |||
| RETURN_IF_NOT_OK(child->to_json(&child_arg)); | |||
| children_args.push_back(child_arg); | |||
| } | |||
| args["child_sampler"] = children_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::SamplerBuild() { | |||
| auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_); | |||
| BuildChildren(sampler); | |||
| return sampler; | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,344 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SAMPLERS_IR_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SAMPLERS_IR_H_ | |||
| #include <limits> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include <nlohmann/json.hpp> | |||
| #include "include/api/status.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/mindrecord/include/shard_operator.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Internal Sampler class forward declaration | |||
| class SamplerRT; | |||
| class SamplerObj { | |||
| public: | |||
| /// \brief Constructor | |||
| SamplerObj(); | |||
| /// \brief Destructor | |||
| ~SamplerObj() = default; | |||
| /// \brief Pure virtual function for derived class to implement parameters validation | |||
| /// \return The Status code of the function. It returns OK status if parameters are valid. | |||
| virtual Status ValidateParams() = 0; | |||
| /// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object | |||
| /// \return Shared pointers to the newly created Sampler | |||
| virtual std::shared_ptr<SamplerRT> SamplerBuild() = 0; | |||
| /// \brief Pure virtual function to copy a SamplerObj class | |||
| /// \return Shared pointers to the newly copied SamplerObj | |||
| virtual std::shared_ptr<SamplerObj> SamplerCopy() = 0; | |||
| /// \brief Function for derived class to get the shard id of sampler | |||
| /// \return The shard id of the derived sampler | |||
| virtual int64_t ShardId() { return 0; } | |||
| /// \brief Adds a child to the sampler | |||
| /// \param[in] child The sampler to be added as child | |||
| /// \return the Status code returned | |||
| Status AddChildSampler(std::shared_ptr<SamplerObj> child); | |||
| virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); } | |||
| std::vector<std::shared_ptr<SamplerObj>> GetChild() { return children_; } | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object, | |||
| /// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler | |||
| /// \return Shared pointers to the newly created Sampler | |||
| virtual std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() { return nullptr; } | |||
| #endif | |||
| protected: | |||
| /// \brief A function that calls build on the children of this sampler | |||
| /// \param[in] sampler The samplerRT object built from this sampler | |||
| void BuildChildren(std::shared_ptr<SamplerRT> sampler); | |||
| std::vector<std::shared_ptr<SamplerObj>> children_; | |||
| }; | |||
| /* ####################################### Derived Sampler classes ################################# */ | |||
| class DistributedSamplerObj : public SamplerObj { | |||
| public: | |||
| DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, uint32_t seed, | |||
| int64_t offset, bool even_dist); | |||
| ~DistributedSamplerObj() = default; | |||
| std::shared_ptr<SamplerRT> SamplerBuild() override; | |||
| std::shared_ptr<SamplerObj> SamplerCopy() override { | |||
| auto sampler = std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_, | |||
| offset_, even_dist_); | |||
| for (auto child : children_) { | |||
| sampler->AddChildSampler(child); | |||
| } | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| Status ValidateParams() override; | |||
| /// \brief Function to get the shard id of sampler | |||
| /// \return The shard id of sampler | |||
| int64_t ShardId() override { return shard_id_; } | |||
| private: | |||
| int64_t num_shards_; | |||
| int64_t shard_id_; | |||
| bool shuffle_; | |||
| int64_t num_samples_; | |||
| uint32_t seed_; | |||
| int64_t offset_; | |||
| bool even_dist_; | |||
| }; | |||
| class PKSamplerObj : public SamplerObj { | |||
| public: | |||
| PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples); | |||
| ~PKSamplerObj() = default; | |||
| std::shared_ptr<SamplerRT> SamplerBuild() override; | |||
| std::shared_ptr<SamplerObj> SamplerCopy() override { | |||
| auto sampler = std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_); | |||
| for (auto child : children_) { | |||
| sampler->AddChildSampler(child); | |||
| } | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| Status ValidateParams() override; | |||
| private: | |||
| int64_t num_val_; | |||
| bool shuffle_; | |||
| int64_t num_samples_; | |||
| }; | |||
| class PreBuiltSamplerObj : public SamplerObj { | |||
| public: | |||
| explicit PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler); | |||
| #ifndef ENABLE_ANDROID | |||
| explicit PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler); | |||
| #endif | |||
| ~PreBuiltSamplerObj() = default; | |||
| std::shared_ptr<SamplerRT> SamplerBuild() override; | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| std::shared_ptr<SamplerObj> SamplerCopy() override; | |||
| Status ValidateParams() override; | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| std::shared_ptr<SamplerRT> sp_; | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> sp_minddataset_; | |||
| #endif | |||
| }; | |||
| class RandomSamplerObj : public SamplerObj { | |||
| public: | |||
| RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch = true); | |||
| ~RandomSamplerObj() = default; | |||
| std::shared_ptr<SamplerRT> SamplerBuild() override; | |||
| std::shared_ptr<SamplerObj> SamplerCopy() override { | |||
| auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_, reshuffle_each_epoch_); | |||
| for (auto child : children_) { | |||
| sampler->AddChildSampler(child); | |||
| } | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| Status ValidateParams() override; | |||
| private: | |||
| bool replacement_; | |||
| int64_t num_samples_; | |||
| bool reshuffle_each_epoch_; | |||
| }; | |||
| class SequentialSamplerObj : public SamplerObj { | |||
| public: | |||
| SequentialSamplerObj(int64_t start_index, int64_t num_samples); | |||
| ~SequentialSamplerObj() = default; | |||
| std::shared_ptr<SamplerRT> SamplerBuild() override; | |||
| std::shared_ptr<SamplerObj> SamplerCopy() override { | |||
| auto sampler = std::make_shared<SequentialSamplerObj>(start_index_, num_samples_); | |||
| for (auto child : children_) { | |||
| sampler->AddChildSampler(child); | |||
| } | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| Status ValidateParams() override; | |||
| private: | |||
| int64_t start_index_; | |||
| int64_t num_samples_; | |||
| }; | |||
| class SubsetSamplerObj : public SamplerObj { | |||
| public: | |||
| SubsetSamplerObj(std::vector<int64_t> indices, int64_t num_samples); | |||
| ~SubsetSamplerObj() = default; | |||
| std::shared_ptr<SamplerRT> SamplerBuild() override; | |||
| std::shared_ptr<SamplerObj> SamplerCopy() override { | |||
| auto sampler = std::make_shared<SubsetSamplerObj>(indices_, num_samples_); | |||
| for (auto child : children_) { | |||
| sampler->AddChildSampler(child); | |||
| } | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| Status ValidateParams() override; | |||
| protected: | |||
| const std::vector<int64_t> indices_; | |||
| int64_t num_samples_; | |||
| }; | |||
| class SubsetRandomSamplerObj : public SubsetSamplerObj { | |||
| public: | |||
| SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples); | |||
| ~SubsetRandomSamplerObj() = default; | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| std::shared_ptr<SamplerRT> SamplerBuild() override; | |||
| std::shared_ptr<SamplerObj> SamplerCopy() override { | |||
| auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_); | |||
| for (auto child : children_) { | |||
| sampler->AddChildSampler(child); | |||
| } | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| private: | |||
| }; | |||
| class WeightedRandomSamplerObj : public SamplerObj { | |||
| public: | |||
| explicit WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples = 0, bool replacement = true); | |||
| ~WeightedRandomSamplerObj() = default; | |||
| std::shared_ptr<SamplerRT> SamplerBuild() override; | |||
| std::shared_ptr<SamplerObj> SamplerCopy() override { | |||
| auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_); | |||
| for (auto child : children_) { | |||
| sampler->AddChildSampler(child); | |||
| } | |||
| return sampler; | |||
| } | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| Status ValidateParams() override; | |||
| private: | |||
| const std::vector<double> weights_; | |||
| int64_t num_samples_; | |||
| bool replacement_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SAMPLERS_IR_H_ | |||
| @@ -16,6 +16,7 @@ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_SERDES_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_SERDES_H_ | |||
| #include <fstream> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| @@ -18,72 +18,14 @@ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <nlohmann/json.hpp> | |||
| #include "include/api/status.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/mindrecord/include/shard_column.h" | |||
| #include "minddata/mindrecord/include/shard_error.h" | |||
| #include "minddata/mindrecord/include/shard_operator.h" | |||
| #include "minddata/mindrecord/include/shard_reader.h" | |||
| #endif | |||
| // FIXME - This internal IR header will be removed when external API classes are provided | |||
| #include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Internal Sampler class forward declaration | |||
| class SamplerRT; | |||
| class SamplerObj { | |||
| public: | |||
| /// \brief Constructor | |||
| SamplerObj(); | |||
| /// \brief Destructor | |||
| ~SamplerObj() = default; | |||
| /// \brief Pure virtual function for derived class to implement parameters validation | |||
| /// \return The Status code of the function. It returns OK status if parameters are valid. | |||
| virtual Status ValidateParams() = 0; | |||
| /// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object | |||
| /// \return Shared pointers to the newly created Sampler | |||
| virtual std::shared_ptr<SamplerRT> SamplerBuild() = 0; | |||
| /// \brief Pure virtual function to copy a SamplerObj class | |||
| /// \return Shared pointers to the newly copied SamplerObj | |||
| virtual std::shared_ptr<SamplerObj> SamplerCopy() = 0; | |||
| /// \brief Function for derived class to get the shard id of sampler | |||
| /// \return The shard id of the derived sampler | |||
| virtual int64_t ShardId() { return 0; } | |||
| /// \brief Adds a child to the sampler | |||
| /// \param[in] child The sampler to be added as child | |||
| /// \return the Status code returned | |||
| Status AddChildSampler(std::shared_ptr<SamplerObj> child); | |||
| virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); } | |||
| std::vector<std::shared_ptr<SamplerObj>> GetChild() { return children_; } | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object, | |||
| /// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler | |||
| /// \return Shared pointers to the newly created Sampler | |||
| virtual std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() { return nullptr; } | |||
| #endif | |||
| protected: | |||
| /// \brief A function that calls build on the children of this sampler | |||
| /// \param[in] sampler The samplerRT object built from this sampler | |||
| void BuildChildren(std::shared_ptr<SamplerRT> sampler); | |||
| std::vector<std::shared_ptr<SamplerObj>> children_; | |||
| }; | |||
| class DistributedSamplerObj; | |||
| class PKSamplerObj; | |||
| class PreBuiltSamplerObj; | |||
| @@ -155,261 +97,6 @@ std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(std::vector<int64_t> | |||
| std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(std::vector<double> weights, int64_t num_samples = 0, | |||
| bool replacement = true); | |||
| /* ####################################### Derived Sampler classes ################################# */ | |||
| class DistributedSamplerObj : public SamplerObj { | |||
| public: | |||
| DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, uint32_t seed, | |||
| int64_t offset, bool even_dist); | |||
| ~DistributedSamplerObj() = default; | |||
| std::shared_ptr<SamplerRT> SamplerBuild() override; | |||
| std::shared_ptr<SamplerObj> SamplerCopy() override { | |||
| auto sampler = std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_, | |||
| offset_, even_dist_); | |||
| for (auto child : children_) { | |||
| sampler->AddChildSampler(child); | |||
| } | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| Status ValidateParams() override; | |||
| /// \brief Function to get the shard id of sampler | |||
| /// \return The shard id of sampler | |||
| int64_t ShardId() override { return shard_id_; } | |||
| private: | |||
| int64_t num_shards_; | |||
| int64_t shard_id_; | |||
| bool shuffle_; | |||
| int64_t num_samples_; | |||
| uint32_t seed_; | |||
| int64_t offset_; | |||
| bool even_dist_; | |||
| }; | |||
| class PKSamplerObj : public SamplerObj { | |||
| public: | |||
| PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples); | |||
| ~PKSamplerObj() = default; | |||
| std::shared_ptr<SamplerRT> SamplerBuild() override; | |||
| std::shared_ptr<SamplerObj> SamplerCopy() override { | |||
| auto sampler = std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_); | |||
| for (auto child : children_) { | |||
| sampler->AddChildSampler(child); | |||
| } | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| Status ValidateParams() override; | |||
| private: | |||
| int64_t num_val_; | |||
| bool shuffle_; | |||
| int64_t num_samples_; | |||
| }; | |||
| class PreBuiltSamplerObj : public SamplerObj { | |||
| public: | |||
| explicit PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler); | |||
| #ifndef ENABLE_ANDROID | |||
| explicit PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler); | |||
| #endif | |||
| ~PreBuiltSamplerObj() = default; | |||
| std::shared_ptr<SamplerRT> SamplerBuild() override; | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| std::shared_ptr<SamplerObj> SamplerCopy() override; | |||
| Status ValidateParams() override; | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| std::shared_ptr<SamplerRT> sp_; | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> sp_minddataset_; | |||
| #endif | |||
| }; | |||
| class RandomSamplerObj : public SamplerObj { | |||
| public: | |||
| RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch = true); | |||
| ~RandomSamplerObj() = default; | |||
| std::shared_ptr<SamplerRT> SamplerBuild() override; | |||
| std::shared_ptr<SamplerObj> SamplerCopy() override { | |||
| auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_, reshuffle_each_epoch_); | |||
| for (auto child : children_) { | |||
| sampler->AddChildSampler(child); | |||
| } | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| Status ValidateParams() override; | |||
| private: | |||
| bool replacement_; | |||
| int64_t num_samples_; | |||
| bool reshuffle_each_epoch_; | |||
| }; | |||
| class SequentialSamplerObj : public SamplerObj { | |||
| public: | |||
| SequentialSamplerObj(int64_t start_index, int64_t num_samples); | |||
| ~SequentialSamplerObj() = default; | |||
| std::shared_ptr<SamplerRT> SamplerBuild() override; | |||
| std::shared_ptr<SamplerObj> SamplerCopy() override { | |||
| auto sampler = std::make_shared<SequentialSamplerObj>(start_index_, num_samples_); | |||
| for (auto child : children_) { | |||
| sampler->AddChildSampler(child); | |||
| } | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| Status ValidateParams() override; | |||
| private: | |||
| int64_t start_index_; | |||
| int64_t num_samples_; | |||
| }; | |||
| class SubsetSamplerObj : public SamplerObj { | |||
| public: | |||
| SubsetSamplerObj(std::vector<int64_t> indices, int64_t num_samples); | |||
| ~SubsetSamplerObj() = default; | |||
| std::shared_ptr<SamplerRT> SamplerBuild() override; | |||
| std::shared_ptr<SamplerObj> SamplerCopy() override { | |||
| auto sampler = std::make_shared<SubsetSamplerObj>(indices_, num_samples_); | |||
| for (auto child : children_) { | |||
| sampler->AddChildSampler(child); | |||
| } | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| Status ValidateParams() override; | |||
| protected: | |||
| const std::vector<int64_t> indices_; | |||
| int64_t num_samples_; | |||
| }; | |||
| class SubsetRandomSamplerObj : public SubsetSamplerObj { | |||
| public: | |||
| SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples); | |||
| ~SubsetRandomSamplerObj() = default; | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| std::shared_ptr<SamplerRT> SamplerBuild() override; | |||
| std::shared_ptr<SamplerObj> SamplerCopy() override { | |||
| auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_); | |||
| for (auto child : children_) { | |||
| sampler->AddChildSampler(child); | |||
| } | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| private: | |||
| }; | |||
| class WeightedRandomSamplerObj : public SamplerObj { | |||
| public: | |||
| explicit WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples = 0, bool replacement = true); | |||
| ~WeightedRandomSamplerObj() = default; | |||
| std::shared_ptr<SamplerRT> SamplerBuild() override; | |||
| std::shared_ptr<SamplerObj> SamplerCopy() override { | |||
| auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_); | |||
| for (auto child : children_) { | |||
| sampler->AddChildSampler(child); | |||
| } | |||
| return sampler; | |||
| } | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| Status ValidateParams() override; | |||
| private: | |||
| const std::vector<double> weights_; | |||
| int64_t num_samples_; | |||
| bool replacement_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_ | |||
| @@ -136,6 +136,7 @@ if(BUILD_MINDDATA STREQUAL "full") | |||
| ${MINDDATA_DIR}/engine/ir/datasetops/shuffle_node.cc | |||
| ${MINDDATA_DIR}/engine/ir/datasetops/source/album_node.cc | |||
| ${MINDDATA_DIR}/engine/ir/datasetops/source/mnist_node.cc | |||
| ${MINDDATA_DIR}/engine/ir/datasetops/source/samplers/samplers_ir.cc | |||
| ${MINDDATA_DIR}/engine/datasetops/dataset_op.cc | |||
| ${MINDDATA_DIR}/engine/datasetops/repeat_op.cc | |||
| ${MINDDATA_DIR}/engine/datasetops/epoch_ctrl_op.cc | |||