|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334 |
- /**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- #include "minddata/dataset/include/samplers.h"
- #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
- #include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
- #include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
- #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
- #include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
- #include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
- #include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h"
-
- #ifndef ENABLE_ANDROID
- #include "minddata/mindrecord/include/shard_distributed_sample.h"
- #include "minddata/mindrecord/include/shard_operator.h"
- #include "minddata/mindrecord/include/shard_pk_sample.h"
- #include "minddata/mindrecord/include/shard_sample.h"
- #include "minddata/mindrecord/include/shard_sequential_sample.h"
- #include "minddata/mindrecord/include/shard_shuffle.h"
- #include "minddata/dataset/util/random.h"
- #endif
-
- namespace mindspore {
- namespace dataset {
- namespace api {
-
// Evaluates the Status expression `_s` exactly once. If the result reports an error, the status is
// logged at ERROR level and the enclosing function returns nullptr; otherwise execution continues.
// The temporary uses a long, macro-specific name: identifiers containing a double underscore (such
// as the previous `__rc`) are reserved for the implementation, and a distinctive name also makes it
// unlikely to shadow a variable referenced by the `_s` expression.
#define RETURN_NULL_IF_ERROR(_s)                  \
  do {                                            \
    Status return_null_if_error_rc_ = (_s);       \
    if (return_null_if_error_rc_.IsError()) {     \
      MS_LOG(ERROR) << return_null_if_error_rc_;  \
      return nullptr;                             \
    }                                             \
  } while (false)
-
- // Constructor
- SamplerObj::SamplerObj() {}
-
- /// Function to create a Distributed Sampler.
- std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle,
- int64_t num_samples, uint32_t seed, int64_t offset,
- bool even_dist) {
- auto sampler =
- std::make_shared<DistributedSamplerObj>(num_shards, shard_id, shuffle, num_samples, seed, offset, even_dist);
- // Input validation
- if (!sampler->ValidateParams()) {
- return nullptr;
- }
- return sampler;
- }
-
- /// Function to create a PK Sampler.
- std::shared_ptr<PKSamplerObj> PKSampler(int64_t num_val, bool shuffle, int64_t num_samples) {
- auto sampler = std::make_shared<PKSamplerObj>(num_val, shuffle, num_samples);
- // Input validation
- if (!sampler->ValidateParams()) {
- return nullptr;
- }
- return sampler;
- }
-
- /// Function to create a Random Sampler.
- std::shared_ptr<RandomSamplerObj> RandomSampler(bool replacement, int64_t num_samples) {
- auto sampler = std::make_shared<RandomSamplerObj>(replacement, num_samples);
- // Input validation
- if (!sampler->ValidateParams()) {
- return nullptr;
- }
- return sampler;
- }
-
- /// Function to create a Sequential Sampler.
- std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index, int64_t num_samples) {
- auto sampler = std::make_shared<SequentialSamplerObj>(start_index, num_samples);
- // Input validation
- if (!sampler->ValidateParams()) {
- return nullptr;
- }
- return sampler;
- }
-
- /// Function to create a Subset Random Sampler.
- std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples) {
- auto sampler = std::make_shared<SubsetRandomSamplerObj>(std::move(indices), num_samples);
- // Input validation
- if (!sampler->ValidateParams()) {
- return nullptr;
- }
- return sampler;
- }
-
- /// Function to create a Weighted Random Sampler.
- std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(std::vector<double> weights, int64_t num_samples,
- bool replacement) {
- auto sampler = std::make_shared<WeightedRandomSamplerObj>(std::move(weights), num_samples, replacement);
- // Input validation
- if (!sampler->ValidateParams()) {
- return nullptr;
- }
- return sampler;
- }
-
- /* ####################################### Derived Sampler classes ################################# */
-
- // DistributedSampler
- DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples,
- uint32_t seed, int64_t offset, bool even_dist)
- : num_shards_(num_shards),
- shard_id_(shard_id),
- shuffle_(shuffle),
- num_samples_(num_samples),
- seed_(seed),
- offset_(offset),
- even_dist_(even_dist) {}
-
- bool DistributedSamplerObj::ValidateParams() {
- if (num_shards_ <= 0) {
- MS_LOG(ERROR) << "DistributedSampler: invalid num_shards: " << num_shards_;
- return false;
- }
-
- if (shard_id_ < 0 || shard_id_ >= num_shards_) {
- MS_LOG(ERROR) << "DistributedSampler: invalid input, shard_id: " << shard_id_ << ", num_shards: " << num_shards_;
- return false;
- }
-
- if (num_samples_ < 0) {
- MS_LOG(ERROR) << "DistributedSampler: invalid num_samples: " << num_samples_;
- return false;
- }
-
- return true;
- }
-
- std::shared_ptr<Sampler> DistributedSamplerObj::Build() {
- // runtime sampler object
- auto sampler = std::make_shared<dataset::DistributedSampler>(num_samples_, num_shards_, shard_id_, shuffle_, seed_,
- offset_, even_dist_);
- return sampler;
- }
-
- #ifndef ENABLE_ANDROID
- std::shared_ptr<mindrecord::ShardOperator> DistributedSamplerObj::BuildForMindDataset() {
- // runtime mindrecord sampler object
- auto mind_sampler = std::make_shared<mindrecord::ShardDistributedSample>(num_shards_, shard_id_, shuffle_, seed_,
- num_samples_, offset_);
- return mind_sampler;
- }
- #endif
-
- // PKSampler
- PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples)
- : num_val_(num_val), shuffle_(shuffle), num_samples_(num_samples) {}
-
- bool PKSamplerObj::ValidateParams() {
- if (num_val_ <= 0) {
- MS_LOG(ERROR) << "PKSampler: invalid num_val: " << num_val_;
- return false;
- }
-
- if (num_samples_ < 0) {
- MS_LOG(ERROR) << "PKSampler: invalid num_samples: " << num_samples_;
- return false;
- }
- return true;
- }
-
- std::shared_ptr<Sampler> PKSamplerObj::Build() {
- // runtime sampler object
- auto sampler = std::make_shared<dataset::PKSampler>(num_samples_, num_val_, shuffle_);
-
- return sampler;
- }
-
- #ifndef ENABLE_ANDROID
- std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
- // runtime mindrecord sampler object
- std::shared_ptr<mindrecord::ShardOperator> mind_sampler;
- if (shuffle_ == true) {
- mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, std::numeric_limits<int64_t>::max(),
- GetSeed(), num_samples_);
- } else {
- mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, num_samples_);
- }
-
- return mind_sampler;
- }
- #endif
-
- // RandomSampler
- RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples)
- : replacement_(replacement), num_samples_(num_samples) {}
-
- bool RandomSamplerObj::ValidateParams() {
- if (num_samples_ < 0) {
- MS_LOG(ERROR) << "RandomSampler: invalid num_samples: " << num_samples_;
- return false;
- }
- return true;
- }
-
- std::shared_ptr<Sampler> RandomSamplerObj::Build() {
- // runtime sampler object
- bool reshuffle_each_epoch = true;
- auto sampler = std::make_shared<dataset::RandomSampler>(num_samples_, replacement_, reshuffle_each_epoch);
-
- return sampler;
- }
-
- #ifndef ENABLE_ANDROID
- std::shared_ptr<mindrecord::ShardOperator> RandomSamplerObj::BuildForMindDataset() {
- // runtime mindrecord sampler object
- bool reshuffle_each_epoch_ = true;
- auto mind_sampler =
- std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples_, replacement_, reshuffle_each_epoch_);
-
- return mind_sampler;
- }
- #endif
-
- // SequentialSampler
- SequentialSamplerObj::SequentialSamplerObj(int64_t start_index, int64_t num_samples)
- : start_index_(start_index), num_samples_(num_samples) {}
-
- bool SequentialSamplerObj::ValidateParams() {
- if (num_samples_ < 0) {
- MS_LOG(ERROR) << "SequentialSampler: invalid num_samples: " << num_samples_;
- return false;
- }
-
- if (start_index_ < 0) {
- MS_LOG(ERROR) << "SequentialSampler: invalid start_index: " << start_index_;
- return false;
- }
-
- return true;
- }
-
- std::shared_ptr<Sampler> SequentialSamplerObj::Build() {
- // runtime sampler object
- auto sampler = std::make_shared<dataset::SequentialSampler>(num_samples_, start_index_);
-
- return sampler;
- }
-
- #ifndef ENABLE_ANDROID
- std::shared_ptr<mindrecord::ShardOperator> SequentialSamplerObj::BuildForMindDataset() {
- // runtime mindrecord sampler object
- auto mind_sampler = std::make_shared<mindrecord::ShardSequentialSample>(num_samples_, start_index_);
-
- return mind_sampler;
- }
- #endif
-
- // SubsetRandomSampler
- SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples)
- : indices_(std::move(indices)), num_samples_(num_samples) {}
-
- bool SubsetRandomSamplerObj::ValidateParams() {
- if (num_samples_ < 0) {
- MS_LOG(ERROR) << "SubsetRandomSampler: invalid num_samples: " << num_samples_;
- return false;
- }
-
- return true;
- }
-
- std::shared_ptr<Sampler> SubsetRandomSamplerObj::Build() {
- // runtime sampler object
- auto sampler = std::make_shared<dataset::SubsetRandomSampler>(num_samples_, indices_);
-
- return sampler;
- }
-
- #ifndef ENABLE_ANDROID
- std::shared_ptr<mindrecord::ShardOperator> SubsetRandomSamplerObj::BuildForMindDataset() {
- // runtime mindrecord sampler object
- auto mind_sampler = std::make_shared<mindrecord::ShardSample>(indices_, GetSeed());
-
- return mind_sampler;
- }
- #endif
-
- // WeightedRandomSampler
- WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement)
- : weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {}
-
- bool WeightedRandomSamplerObj::ValidateParams() {
- if (weights_.empty()) {
- MS_LOG(ERROR) << "WeightedRandomSampler: weights vector must not be empty";
- return false;
- }
- int32_t zero_elem = 0;
- for (int32_t i = 0; i < weights_.size(); ++i) {
- if (weights_[i] < 0) {
- MS_LOG(ERROR) << "WeightedRandomSampler: weights vector must not contain negative number, got: " << weights_[i];
- return false;
- }
- if (weights_[i] == 0.0) {
- zero_elem++;
- }
- }
- if (zero_elem == weights_.size()) {
- MS_LOG(ERROR) << "WeightedRandomSampler: elements of weights vector must not be all zero";
- return false;
- }
- if (num_samples_ < 0) {
- MS_LOG(ERROR) << "WeightedRandomSampler: invalid num_samples: " << num_samples_;
- return false;
- }
- return true;
- }
-
- std::shared_ptr<Sampler> WeightedRandomSamplerObj::Build() {
- auto sampler = std::make_shared<dataset::WeightedRandomSampler>(num_samples_, weights_, replacement_);
- return sampler;
- }
-
- } // namespace api
- } // namespace dataset
- } // namespace mindspore
|