From 9739d3b048cd2471b6032339def5e54e20f60f6f Mon Sep 17 00:00:00 2001 From: Junhan Hu Date: Sun, 29 Mar 2020 23:21:21 -0400 Subject: [PATCH] Add CPP sampler support for GeneratorDataset --- mindspore/ccsrc/dataset/api/de_pipeline.cc | 2 +- .../ccsrc/dataset/api/python_bindings.cc | 11 +- mindspore/ccsrc/dataset/core/tensor.cc | 2 + .../engine/datasetops/source/celeba_op.cc | 2 +- .../engine/datasetops/source/cifar_op.cc | 2 +- .../datasetops/source/image_folder_op.cc | 2 +- .../engine/datasetops/source/manifest_op.cc | 2 +- .../engine/datasetops/source/mnist_op.cc | 2 +- .../source/sampler/distributed_sampler.cc | 5 +- .../source/sampler/distributed_sampler.h | 6 +- .../datasetops/source/sampler/pk_sampler.cc | 12 +- .../datasetops/source/sampler/pk_sampler.h | 5 +- .../source/sampler/random_sampler.cc | 5 +- .../source/sampler/random_sampler.h | 6 +- .../datasetops/source/sampler/sampler.cc | 49 ++++- .../datasetops/source/sampler/sampler.h | 18 +- .../source/sampler/sequential_sampler.cc | 4 +- .../source/sampler/sequential_sampler.h | 6 +- .../source/sampler/subset_random_sampler.cc | 5 +- .../source/sampler/subset_random_sampler.h | 3 +- .../source/sampler/weighted_random_sampler.cc | 32 +-- .../source/sampler/weighted_random_sampler.h | 5 +- .../engine/datasetops/source/voc_op.cc | 2 +- mindspore/dataset/engine/datasets.py | 194 ++++++++++++++---- mindspore/dataset/engine/samplers.py | 1 - mindspore/dataset/engine/validators.py | 44 ++-- .../cpp/dataset/stand_alone_samplers_test.cc | 4 +- .../cpp/dataset/subset_random_sampler_test.cc | 12 +- .../dataset/weighted_random_sampler_test.cc | 24 +-- tests/ut/python/dataset/test_generator.py | 71 +++++++ tests/ut/python/dataset/test_sampler.py | 21 ++ 31 files changed, 432 insertions(+), 127 deletions(-) diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.cc b/mindspore/ccsrc/dataset/api/de_pipeline.cc index f572db0cdf..b64d40125e 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.cc +++ 
b/mindspore/ccsrc/dataset/api/de_pipeline.cc @@ -517,7 +517,7 @@ Status DEPipeline::ParseGeneratorOp(const py::dict &args, std::shared_ptr(obj)) { std::string err_msg = "Error: generator is invalid or not set."; diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 3d543f946b..0633af4914 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -384,7 +384,16 @@ void bindTensorOps4(py::module *m) { } void bindSamplerOps(py::module *m) { - (void)py::class_>(*m, "Sampler"); + (void)py::class_>(*m, "Sampler") + .def("set_num_rows", [](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); }) + .def("set_num_samples", [](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); }) + .def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); }) + .def("get_indices", [](Sampler &self) { + py::array ret; + THROW_IF_ERROR(self.GetAllIdsThenReset(&ret)); + return ret; + }); + (void)py::class_>(*m, "ShardOperator"); (void)py::class_>(*m, "DistributedSampler") diff --git a/mindspore/ccsrc/dataset/core/tensor.cc b/mindspore/ccsrc/dataset/core/tensor.cc index 8f0eae459a..8fd1f8d48d 100644 --- a/mindspore/ccsrc/dataset/core/tensor.cc +++ b/mindspore/ccsrc/dataset/core/tensor.cc @@ -491,6 +491,8 @@ Status Tensor::GetItemAt(T *o, const std::vector &index) const { // return data as numpy, should return status Status Tensor::GetDataAsNumpy(py::array *data) { + RETURN_UNEXPECTED_IF_NULL(data_); + RETURN_UNEXPECTED_IF_NULL(data); if (type_ == DataType::DE_BOOL) { *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); } else if (type_ == DataType::DE_INT8) { diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc index 0c2e20729e..87a7b3c687 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc 
+++ b/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc @@ -100,7 +100,7 @@ Status CelebAOp::LaunchThreadsAndInitOp() { RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CelebAOp::WorkerEntry, this, std::placeholders::_1))); TaskManager::FindMe()->Post(); RETURN_IF_NOT_OK(ParseImageAttrInfo()); - RETURN_IF_NOT_OK(sampler_->Init(this)); + RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); return Status::OK(); } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc index 3e64c8a3e6..60de5a6bdf 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc @@ -240,7 +240,7 @@ Status CifarOp::Reset() { // hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows Status CifarOp::InitSampler() { - RETURN_IF_NOT_OK(sampler_->Init(this)); + RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); return Status::OK(); } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc index f6cf377666..0ac579a865 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc @@ -258,7 +258,7 @@ Status ImageFolderOp::Reset() { // hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows Status ImageFolderOp::InitSampler() { - RETURN_IF_NOT_OK(sampler_->Init(this)); + RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); return Status::OK(); } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc index 6907647952..0139af4d9d 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc +++ 
b/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc @@ -254,7 +254,7 @@ Status ManifestOp::Reset() { // hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows Status ManifestOp::InitSampler() { - RETURN_IF_NOT_OK(sampler_->Init(this)); + RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); return Status::OK(); } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc index 3431e58aea..71900f8a91 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc @@ -205,7 +205,7 @@ Status MnistOp::Reset() { // hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows Status MnistOp::InitSampler() { - RETURN_IF_NOT_OK(sampler_->Init(this)); + RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); return Status::OK(); } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc index 28a5705648..5b5a9321df 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc @@ -31,8 +31,9 @@ DistributedSampler::DistributedSampler(int64_t num_dev, int64_t dev_id, bool shu num_devices_(num_dev), shuffle_(shuffle) {} -Status DistributedSampler::Init(const RandomAccessOp *op) { - RETURN_IF_NOT_OK(Sampler::Init(op)); +Status DistributedSampler::InitSampler() { + CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples <= 0\n"); + CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n"); CHECK_FAIL_RETURN_UNEXPECTED(device_id_ < num_devices_ && device_id_ >= 0 && num_rows_ > 0 && num_samples_ > 0, "fail to init DistributedSampler"); rnd_.seed(seed_++); diff --git 
a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h index ef25b6bccf..58b469dcc8 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h @@ -41,10 +41,8 @@ class DistributedSampler : public Sampler { // @return - The error code return Status GetNextBuffer(std::unique_ptr *out_buffer) override; - // first handshake between StorageOp and Sampler - // @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds() - // @return - Status Init(const RandomAccessOp *) override; + // Init sampler, called by base class or python + Status InitSampler() override; // for next epoch of sampleIds // @return - The error code return diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc index 8c8c12fce2..8198204437 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc @@ -28,9 +28,7 @@ PKSampler::PKSampler(int64_t val, bool shuffle, int64_t samples_per_buffer) num_pk_samples_(0), samples_per_class_(val) {} -Status PKSampler::Init(const RandomAccessOp *op) { - RETURN_UNEXPECTED_IF_NULL(op); - RETURN_IF_NOT_OK(op->GetClassIds(&label_to_ids_)); +Status PKSampler::InitSampler() { labels_.reserve(label_to_ids_.size()); for (const auto &pair : label_to_ids_) { if (pair.second.empty() == false) { @@ -79,5 +77,13 @@ Status PKSampler::Reset() { rnd_.seed(seed_++); return Status::OK(); } + +Status PKSampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { + RETURN_UNEXPECTED_IF_NULL(op); + RETURN_IF_NOT_OK(op->GetClassIds(&label_to_ids_)); + RETURN_IF_NOT_OK(InitSampler()); + return Status::OK(); +} + } // namespace dataset } 
// namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h index fa2b4ed0c7..14f598a9ce 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h @@ -45,7 +45,10 @@ class PKSampler : public Sampler { // NOT YET FINISHED // first handshake between StorageOp and Sampler // @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds() // @return - Status Init(const RandomAccessOp *op) override; + Status HandshakeRandomAccessOp(const RandomAccessOp *op) override; + + // init sampler, to be called by python or Handshake + Status InitSampler() override; // for next epoch of sampleIds // @return - The error code return diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc index 216f322052..de8cde409f 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc @@ -49,10 +49,9 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr *out_buffer) { return Status::OK(); } -Status RandomSampler::Init(const RandomAccessOp *op) { - RETURN_IF_NOT_OK(Sampler::Init(op)); +Status RandomSampler::InitSampler() { num_samples_ = (user_num_samples_ < num_samples_) ? user_num_samples_ : num_samples_; - CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "Fail to init RandomSampler"); + CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive"); samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? 
num_samples_ : samples_per_buffer_; if (replacement_ == false) { shuffled_ids_.reserve(num_rows_); diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h index 54f26f352b..84a07e9fc6 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h @@ -42,10 +42,8 @@ class RandomSampler : public Sampler { // @return - The error code return Status GetNextBuffer(std::unique_ptr *out_buffer) override; - // first handshake between StorageOp and Sampler - // @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds() - // @return - Status Init(const RandomAccessOp *op) override; + // meant to be called by base class or python + Status InitSampler() override; // for next epoch of sampleIds // @return - The error code return diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc index aa3838f8b5..3c3f5f48e8 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc @@ -20,12 +20,13 @@ namespace dataset { Sampler::Sampler(int64_t samples_per_buffer) : DatasetOp(0), num_rows_(0), num_samples_(0), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {} -Status Sampler::Init(const RandomAccessOp *op) { - CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr && samples_per_buffer_ > 0, "Fail to init Sampler()\n"); +Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { + CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n"); RETURN_IF_NOT_OK(op->GetNumSamples(&num_samples_)); RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_)); // It's up to the derived class to check the validity of the two args // Because some 
sampler only needs one of the arg (weighted_random_sampler) + RETURN_IF_NOT_OK(InitSampler()); // init sampler after callback return Status::OK(); } @@ -42,5 +43,49 @@ Status Sampler::CreateSamplerTensor(std::shared_ptr *sample_ids, int64_t (void)(*sample_ids)->StartAddr(); // allocate memory in case user forgets! return Status::OK(); } + +Status Sampler::GetAllIdsThenReset(py::array *data) { + std::unique_ptr db; + std::shared_ptr sample_ids; + + // check samples_per_buffer is properly set and doesn't overflow + CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ + 1 > 1, "samples_per_buffer invalid"); + + // A call to derived class to get sample ids wrapped inside a buffer + RETURN_IF_NOT_OK(GetNextBuffer(&db)); + // Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch + RETURN_IF_NOT_OK(db->GetTensor(&sample_ids, 0, 0)); + // check this buffer is not a ctrl buffer + CHECK_FAIL_RETURN_UNEXPECTED(db->buffer_flags() == DataBuffer::kDeBFlagNone, "ERROR ctrl buffer received"); + { + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + RETURN_IF_NOT_OK(sample_ids->GetDataAsNumpy(data)); + } catch (const std::runtime_error &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } + } + // perform error checking! 
Next buffer supposed to be EOE since last one already contains all ids for current epoch + RETURN_IF_NOT_OK(GetNextBuffer(&db)); + CHECK_FAIL_RETURN_UNEXPECTED(db->eoe(), "ERROR Non EOE received"); + // Reset Sampler since this is the end of the epoch + RETURN_IF_NOT_OK(Reset()); + return Status::OK(); +} + +Status Sampler::SetNumSamples(int64_t num_samples) { + CHECK_FAIL_RETURN_UNEXPECTED(num_samples > 0, "num_samples is negative or 0"); + num_samples_ = num_samples; + return Status::OK(); +} + +Status Sampler::SetNumRowsInDataset(int64_t num_rows) { + CHECK_FAIL_RETURN_UNEXPECTED(num_rows > 0, "num_rows is negative or 0"); + num_rows_ = num_rows; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h index 801565508b..4ea221027a 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h @@ -78,14 +78,26 @@ class Sampler : public DatasetOp { // @return - The error code return Status GetNextBuffer(std::unique_ptr *out_buffer) override = 0; + // return all ids in one epoch as a numpy array, then call reset + Status GetAllIdsThenReset(py::array *data); + // for next epoch of sampleIds // @return - The error code return Status Reset() override = 0; - // first handshake between StorageOp and Sampler. Base class init will call both GetNumRows and GetNumSamples - // @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds() + // setter function for num_rows_ + Status SetNumRowsInDataset(int64_t num_rows); + + // setter function for num_samples_ + Status SetNumSamples(int64_t num_samples); + + // first handshake between StorageOp and Sampler. 
This func will call getNumRows and getNumSamples + // @param op - StorageOp pointer, pass in so Sampler can call getNumSamples() and get ClassIds() // @return - virtual Status Init(const RandomAccessOp *op); + virtual Status HandshakeRandomAccessOp(const RandomAccessOp *op); + + // initialize sampler and perform checks on certain vars + virtual Status InitSampler() { return Status::OK(); } // Not meant to be called // @return diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc index 72131a6de1..a3c4fe2256 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc @@ -41,9 +41,7 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr *out_buffer) return Status::OK(); } -Status SequentialSampler::Init(const RandomAccessOp *op) { - RETURN_UNEXPECTED_IF_NULL(op); - RETURN_IF_NOT_OK(op->GetNumSamples(&num_samples_)); +Status SequentialSampler::InitSampler() { CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler"); samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; return Status::OK(); diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h index d119fd8d08..c38a9ed2f9 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h @@ -32,10 +32,8 @@ class SequentialSampler : public Sampler { // Destructor. ~SequentialSampler() = default; - // Initialize the sampler. 
- // @param op - // @return Status - Status Init(const RandomAccessOp *op) override; + // init sampler, called by python + Status InitSampler() override; // for next epoch of sampleIds // @return - The error code return diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc index 16603939b3..c377fddb49 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc @@ -31,9 +31,8 @@ SubsetRandomSampler::SubsetRandomSampler(const std::vector &indices, in : Sampler(samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {} // Initialized this Sampler. -Status SubsetRandomSampler::Init(const RandomAccessOp *op) { - // Calling base class init. - RETURN_IF_NOT_OK(Sampler::Init(op)); +Status SubsetRandomSampler::InitSampler() { + CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n"); // Initialize random generator with seed from config manager rand_gen_.seed(GetSeed()); diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h index 38fae6b20b..1f4c155748 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h @@ -38,9 +38,8 @@ class SubsetRandomSampler : public Sampler { ~SubsetRandomSampler() = default; // Initialize the sampler. - // @param op (Not used in this sampler) // @return Status - Status Init(const RandomAccessOp *op) override; + Status InitSampler() override; // Reset the internal variable to the initial state and reshuffle the indices. 
// @return Status diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc index f2957e74be..06afc219e6 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc @@ -29,21 +29,21 @@ namespace dataset { // Constructor. WeightedRandomSampler::WeightedRandomSampler(const std::vector &weights, int64_t num_samples, bool replacement, int64_t samples_per_buffer) - : Sampler(samples_per_buffer), weights_(weights), replacement_(replacement), sample_id_(0), buffer_id_(0) { - num_samples_ = num_samples; // this variable is defined in base class sampler -} + : Sampler(samples_per_buffer), + weights_(weights), + replacement_(replacement), + sample_id_(0), + buffer_id_(0), + user_num_samples_(num_samples) {} // Initialized this Sampler. -Status WeightedRandomSampler::Init(const RandomAccessOp *op) { - RETURN_UNEXPECTED_IF_NULL(op); - RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_)); - +Status WeightedRandomSampler::InitSampler() { + CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && user_num_samples_, "num_samples & num_rows need to be positive"); + CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0, "samples_per_buffer<=0\n"); // Initialize random generator with seed from config manager rand_gen_.seed(GetSeed()); - samples_per_buffer_ = (samples_per_buffer_ > num_samples_) ? num_samples_ : samples_per_buffer_; - - CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init WeightedRandomSampler"); + samples_per_buffer_ = (samples_per_buffer_ > user_num_samples_) ? user_num_samples_ : samples_per_buffer_; if (!replacement_) { exp_dist_ = std::make_unique>(1); @@ -65,8 +65,8 @@ void WeightedRandomSampler::InitOnePassSampling() { } // Partial sort the first `numSamples` elements. 
- std::partial_sort(val_idx.begin(), val_idx.begin() + num_samples_, val_idx.end()); - for (int64_t i = 0; i < num_samples_; i++) { + std::partial_sort(val_idx.begin(), val_idx.begin() + user_num_samples_, val_idx.end()); + for (int64_t i = 0; i < user_num_samples_; i++) { onepass_ids_.push_back(val_idx[i].second); } } @@ -91,11 +91,11 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr *out_buf "number of samples weights is more than num of rows. Might generate id out of bound OR other errors"); } - if (!replacement_ && (weights_.size() < static_cast(num_samples_))) { + if (!replacement_ && (weights_.size() < static_cast(user_num_samples_))) { RETURN_STATUS_UNEXPECTED("Without replacement, sample weights less than numSamples"); } - if (sample_id_ == num_samples_) { + if (sample_id_ == user_num_samples_) { (*out_buffer) = std::make_unique(buffer_id_++, DataBuffer::kDeBFlagEOE); } else { (*out_buffer) = std::make_unique(buffer_id_++, DataBuffer::kDeBFlagNone); @@ -103,8 +103,8 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr *out_buf int64_t last_id = sample_id_ + samples_per_buffer_; // Handling the return all samples at once, and when last draw is not a full batch. - if (last_id > num_samples_) { - last_id = num_samples_; + if (last_id > user_num_samples_) { + last_id = user_num_samples_; } // Allocate tensor. diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h index bccc9e599d..5381bb64b0 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h @@ -43,7 +43,7 @@ class WeightedRandomSampler : public Sampler { // Initialize the sampler. 
// @param op (Not used in this sampler) // @return Status - Status Init(const RandomAccessOp *op) override; + Status InitSampler() override; // Reset the internal variable to the initial state and reshuffle the indices. Status Reset() override; @@ -69,6 +69,9 @@ class WeightedRandomSampler : public Sampler { // Random engine and device std::mt19937 rand_gen_; + // num_samples from user + int64_t user_num_samples_; + // Discrete distribution for generating weighted random numbers with replacement. std::unique_ptr> discrete_dist_; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc index 71b4c47cf5..1731ed14ba 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc @@ -220,7 +220,7 @@ Status VOCOp::ParseImageIds() { } Status VOCOp::InitSampler() { - RETURN_IF_NOT_OK(sampler_->Init(this)); + RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); return Status::OK(); } diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 8e6545375b..4480bbc462 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -1748,14 +1748,70 @@ class MindDataset(SourceDataset): return num_rows -def ds_fn(dataset): - for val in dataset: - # convert output tensors to ndarrays - yield tuple([np.array(x) for x in val]) +def _iter_fn(dataset, num_samples): + """ + Generator function wrapper for iterable dataset + """ + if num_samples is not None: + ds_iter = iter(dataset) + for _ in range(num_samples): + try: + val = next(ds_iter) + except StopIteration: + return + # convert output tensors to ndarrays + yield tuple([np.array(x) for x in val]) + else: + for val in dataset: + # convert output tensors to ndarrays + yield tuple([np.array(x) for x in val]) + + +def _generator_fn(generator, num_samples): + """ + Generator function wrapper for generator function 
dataset + """ + if num_samples is not None: + gen_iter = generator() + for _ in range(num_samples): + try: + val = next(gen_iter) + except StopIteration: + return + yield val + else: + gen_iter = generator() + for val in gen_iter: + yield val -def sampler_fn(sampler, dataset): - for i in sampler: +def _py_sampler_fn(sampler, num_samples, dataset): + """ + Generator function wrapper for mappable dataset with python sampler + """ + if num_samples is not None: + sampler_iter = iter(sampler) + for _ in range(num_samples): + try: + idx = next(sampler_iter) + except StopIteration: + return + val = dataset[idx] + # convert output tensors to ndarrays + yield tuple([np.array(x) for x in val]) + else: + for i in sampler: + val = dataset[i] + # convert output tensors to ndarrays + yield tuple([np.array(x) for x in val]) + + +def _cpp_sampler_fn(sampler, dataset): + """ + Generator function wrapper for mappable dataset with cpp sampler + """ + indices = sampler.get_indices() + for i in indices: val = dataset[i] # convert output tensors to ndarrays yield tuple([np.array(x) for x in val]) @@ -1763,49 +1819,122 @@ def sampler_fn(sampler, dataset): class GeneratorDataset(SourceDataset): """ - A source dataset that generate data from calling generator function each epoch. + A source dataset that generate data from python by invoking python data source each epoch. + + This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table + below shows what input args are allowed and their expected behavior. + + .. 
list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle' + :widths: 25 25 50 + :header-rows: 1 + + * - Parameter 'sampler' + - Parameter 'shuffle' + - Expected Order Behavior + * - None + - None + - random order + * - None + - True + - random order + * - None + - False + - sequential order + * - Sampler object + - None + - order defined by sampler + * - Sampler object + - True + - not allowed + * - Sampler object + - False + - not allowed Args: - generator_function (callable): - A callable object that returns an Generator object that supports the iter() protocol. - Generator object is required to return a tuple of numpy array as a row of the dataset on next(). + source (Callable/Iterable/Random Accessible): + A generator callable object, an iterable python object or a random accessible python object. + Callable source is required to return a tuple of numpy array as a row of the dataset on source().next(). + Iterable source is required to return a tuple of numpy array as a row of the dataset on iter(source).next(). + Random accessible source is required to return a tuple of numpy array as a row of the dataset on + source[idx]. column_names (list[str]): List of column names of the dataset. column_types (list[mindspore.dtype], optional): List of column data types of the dataset (default=None). If provided, sanity check will be performed on generator output. - prefetch_size (int, optional): Prefetch number of records ahead of the user's request (default=None). - sampler (Sampler, optional): Object used to choose samples from the dataset (default=None). + schema (Schema/String, optional): Path to the json schema file or schema object (default=None). + If the schema is not provided, the meta data from column_names and column_types is considered the schema. + num_samples (int, optional): The number of samples to be included in the dataset + (default=None, all images). + shuffle (bool, optional): Whether or not to perform shuffle on the dataset. 
Random accessible input is required. + (default=None, expected order behavior shown in the table). + sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is + required. + (default=None, expected order behavior shown in the table). + num_shards (int, optional): Number of shards that the dataset should be divided into (default=None). + This argument should be specified only when 'num_samples' is "None". Random accessible input is required. + shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only + when num_shards is also specified. Random accessible input is required. Examples: - >>> import mindspore.dataset as ds - >>> # 1) generator function that generates multi-dimensional data + >>> import mindspore.dataengine as de + >>> # 1) Multidimensional generator function as callable input >>> def generator_md(): >>> for i in range(64): >>> yield (np.array([[i, i + 1], [i + 2, i + 3]]),) - >>> # create multi_dimension_generator_dataset with GeneratorMD() and column name "multi_dimensional_data" - >>> multi_dimension_generator_dataset = ds.GeneratorDataset(generator_md, ["multi_dimensional_data"]) - >>> # 2) generator function that generates multi-columns data + >>> # create multi_dimension_generator_dataset with GeneratorMD and column name "multi_dimensional_data" + >>> multi_dimension_generator_dataset = de.GeneratorDataset(generator_md, ["multi_dimensional_data"]) + >>> # 2) Multi-column generator function as callable input >>> def generator_mc(maxid = 64): >>> for i in range(maxid): >>> yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]])) - >>> # create multi_column_generator_dataset with GeneratorMC() and column names "col1" and "col2" - >>> multi_column_generator_dataset = ds.GeneratorDataset(generator_mc, ["col1, col2"]) + >>> # create multi_column_generator_dataset with GeneratorMC and column names "col1" and "col2" + >>> 
multi_column_generator_dataset = de.GeneratorDataset(generator_mc, ["col1", "col2"]) + >>> # 3) Iterable dataset as iterable input + >>> class MyIterable(): + >>> def __iter__(self): + >>> return # User implementation + >>> # create iterable_generator_dataset with MyIterable object + >>> iterable_generator_dataset = de.GeneratorDataset(MyIterable(), ["col1"]) + >>> # 4) Random accessible dataset as random accessible input + >>> class MyRA(): + >>> def __getitem__(self, index): + >>> return # User implementation + >>> # create ra_generator_dataset with MyRA object + >>> ra_generator_dataset = de.GeneratorDataset(MyRA(), ["col1"]) + >>> # List/Dict/Tuple is also random accessible + >>> list_generator = de.GeneratorDataset([(np.array(0),), (np.array(1),), (np.array(2),)], ["col1"]) + >>> # 5) Built-in Sampler + >>> my_generator = de.GeneratorDataset(my_ds, ["img", "label"], sampler=samplers.RandomSampler()) + >>> """ @check_generatordataset - def __init__(self, generator_function, column_names, column_types=None, prefetch_size=None, sampler=None): - super().__init__(1) - if sampler is not None: - self.generator_function = (lambda: sampler_fn(sampler, generator_function)) + def __init__(self, source, column_names, column_types=None, schema=None, num_samples=None, num_parallel_workers=1, + shuffle=None, sampler=None, num_shards=None, shard_id=None): + super().__init__(num_parallel_workers) + self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) + if self.sampler is not None and hasattr(source, "__getitem__"): + if isinstance(self.sampler, (samplers.SequentialSampler, samplers.DistributedSampler, + samplers.RandomSampler, samplers.SubsetRandomSampler, + samplers.WeightedRandomSampler)): + if num_samples is None: + num_samples = len(source) + sampler_instance = self.sampler.create() + sampler_instance.set_num_rows(len(source)) + sampler_instance.set_num_samples(num_samples) + sampler_instance.initialize() + self.source = (lambda:
_cpp_sampler_fn(sampler_instance, source)) + else: + self.source = (lambda: _py_sampler_fn(self.sampler, num_samples, source)) else: try: - # test to see if generator_function is iterable - iter(generator_function) + iter(source) except TypeError: - # generator_function was not iterable, assume it is a function - self.generator_function = generator_function + # Use generator function if input callable + self.source = (lambda: _generator_fn(source, num_samples)) else: - # generator_function was iterable, build a function around it - self.generator_function = (lambda: ds_fn(generator_function)) + # Use iterator function if input is iterable + # Random accessible input is also iterable + self.source = (lambda: _iter_fn(source, num_samples)) self.column_names = column_names @@ -1813,17 +1942,12 @@ class GeneratorDataset(SourceDataset): self.column_types = mstypelist_to_detypelist(column_types) else: self.column_types = column_types - self.distribution = "" - self.prefetch_size = prefetch_size - self.sampler = sampler def get_args(self): args = super().get_args() - args["generator_function"] = self.generator_function + args["source"] = self.source args["column_names"] = self.column_names args["column_types"] = self.column_types - args["prefetch_size"] = self.prefetch_size - args["sampler"] = self.sampler return args def get_dataset_size(self): diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index fd9c50e951..f9c74f151d 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -20,7 +20,6 @@ SequentialSampler, SubsetRandomSampler, WeightedRandomSampler. import mindspore._c_dataengine as cde - class DistributedSampler(): """ Sampler that access a shard of the dataset. 
diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 63d7c58270..165a160e77 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -543,28 +543,48 @@ def check_generatordataset(method): def new_method(*args, **kwargs): param_dict = make_param_dict(method, args, kwargs) - nreq_param_int = ['prefetch_size'] - nreq_param_list = ['column_names', 'column_types'] - # check generator_function; required argument - generator_function = param_dict.get('generator_function') - if generator_function is None: - raise ValueError("generator_function is not provided.") + source = param_dict.get('source') + if source is None: + raise ValueError("source is not provided.") + if not callable(source): + try: + iter(source) + except TypeError: + raise TypeError("source should be callable, iterable or random accessible") # check column_names; required argument column_names = param_dict.get('column_names') if column_names is None: raise ValueError("column_names is not provided.") - # check prefetch_size range - prefetch_size = param_dict.get('prefetch_size') - if prefetch_size is not None and (prefetch_size <= 0 or prefetch_size > 1024): - raise ValueError("prefetch_size exceeds the boundary.") - + # check optional argument + nreq_param_int = ["num_samples", "num_parallel_workers", "num_shards", "shard_id"] check_param_type(nreq_param_int, param_dict, int) - + nreq_param_list = ["column_types"] check_param_type(nreq_param_list, param_dict, list) + num_shards = param_dict.get("num_shards") + shard_id = param_dict.get("shard_id") + if (num_shards is None) != (shard_id is None): + # These two parameters appear together. 
+ raise ValueError("num_shards and shard_id need to be passed in together") + if num_shards is not None: + if shard_id >= num_shards: + raise ValueError("shard_id should be less than num_shards") + + sampler = param_dict.get("sampler") + if sampler is not None: + if isinstance(sampler, samplers.PKSampler): + raise ValueError("PKSampler is not supported by GeneratorDataset") + if not isinstance(sampler, (samplers.SequentialSampler, samplers.DistributedSampler, + samplers.RandomSampler, samplers.SubsetRandomSampler, + samplers.WeightedRandomSampler)): + try: + iter(sampler) + except TypeError: + raise TypeError("sampler should be either iterable or from dataset.samplers.py") + return method(*args, **kwargs) return new_method diff --git a/tests/ut/cpp/dataset/stand_alone_samplers_test.cc b/tests/ut/cpp/dataset/stand_alone_samplers_test.cc index 48cc811615..ea0ae78aef 100644 --- a/tests/ut/cpp/dataset/stand_alone_samplers_test.cc +++ b/tests/ut/cpp/dataset/stand_alone_samplers_test.cc @@ -75,7 +75,7 @@ TEST_F(MindDataTestStandAloneSampler, TestDistributedSampler) { std::shared_ptr tensor; for (int i = 0; i < 6; i++) { std::unique_ptr sampler = std::make_unique(3, i % 3, (i < 3 ? 
false : true)); - sampler->Init(&mock); + sampler->HandshakeRandomAccessOp(&mock); sampler->GetNextBuffer(&db); db->GetTensor(&tensor, 0, 0); MS_LOG(DEBUG) << (*tensor); @@ -95,7 +95,7 @@ TEST_F(MindDataTestStandAloneSampler, TestStandAoneSequentialSampler) { std::shared_ptr sampler = std::make_shared(3); std::unique_ptr db; std::shared_ptr tensor; - sampler->Init(&mock); + sampler->HandshakeRandomAccessOp(&mock); sampler->GetNextBuffer(&db); db->GetTensor(&tensor, 0, 0); EXPECT_TRUE((*tensor) == (*label1)); diff --git a/tests/ut/cpp/dataset/subset_random_sampler_test.cc b/tests/ut/cpp/dataset/subset_random_sampler_test.cc index 5142a6d399..bb8b3439d5 100644 --- a/tests/ut/cpp/dataset/subset_random_sampler_test.cc +++ b/tests/ut/cpp/dataset/subset_random_sampler_test.cc @@ -52,8 +52,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) { std::unordered_set in_set(in.begin(), in.end()); SubsetRandomSampler sampler(in); - DummyRandomAccessOp dummy_random_access_op(5); - sampler.Init(&dummy_random_access_op); + DummyRandomAccessOp dummyRandomAccessOp(5); + sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); std::unique_ptr db; TensorRow row; @@ -80,8 +80,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) { std::vector input(total_samples, 1); SubsetRandomSampler sampler(input, samples_per_buffer); - DummyRandomAccessOp dummy_random_access_op(total_samples); - sampler.Init(&dummy_random_access_op); + DummyRandomAccessOp dummyRandomAccessOp(total_samples); + sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); std::unique_ptr db; TensorRow row; @@ -111,8 +111,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) { std::unordered_set in_set(in.begin(), in.end()); SubsetRandomSampler sampler(in); - DummyRandomAccessOp dummy_random_access_op(5); - sampler.Init(&dummy_random_access_op); + DummyRandomAccessOp dummyRandomAccessOp(5); + sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); std::unique_ptr db; TensorRow row; diff --git 
a/tests/ut/cpp/dataset/weighted_random_sampler_test.cc b/tests/ut/cpp/dataset/weighted_random_sampler_test.cc index 1c5d73613f..51a4bc3cb3 100644 --- a/tests/ut/cpp/dataset/weighted_random_sampler_test.cc +++ b/tests/ut/cpp/dataset/weighted_random_sampler_test.cc @@ -60,8 +60,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) { // create sampler with replacement = true WeightedRandomSampler m_sampler(weights, num_samples, true); - DummyRandomAccessOp dummy_random_access_op(total_samples); - m_sampler.Init(&dummy_random_access_op); + DummyRandomAccessOp dummyRandomAccessOp(total_samples); + m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); std::unique_ptr db; TensorRow row; @@ -90,8 +90,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) { // create sampler with replacement = replacement WeightedRandomSampler m_sampler(weights, num_samples, false); - DummyRandomAccessOp dummy_random_access_op(total_samples); - m_sampler.Init(&dummy_random_access_op); + DummyRandomAccessOp dummyRandomAccessOp(total_samples); + m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); std::unique_ptr db; TensorRow row; @@ -126,8 +126,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) { // create sampler with replacement = replacement WeightedRandomSampler m_sampler(weights, num_samples, true, samples_per_buffer); - DummyRandomAccessOp dummy_random_access_op(total_samples); - m_sampler.Init(&dummy_random_access_op); + DummyRandomAccessOp dummyRandomAccessOp(total_samples); + m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); std::unique_ptr db; TensorRow row; @@ -162,8 +162,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) { // create sampler with replacement = replacement WeightedRandomSampler m_sampler(weights, num_samples, false, samples_per_buffer); - DummyRandomAccessOp dummy_random_access_op(total_samples); - m_sampler.Init(&dummy_random_access_op); + DummyRandomAccessOp 
dummyRandomAccessOp(total_samples); + m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); std::unique_ptr db; TensorRow row; @@ -203,8 +203,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) { // create sampler with replacement = true WeightedRandomSampler m_sampler(weights, num_samples, true); - DummyRandomAccessOp dummy_random_access_op(total_samples); - m_sampler.Init(&dummy_random_access_op); + DummyRandomAccessOp dummyRandomAccessOp(total_samples); + m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); std::unique_ptr db; TensorRow row; @@ -248,8 +248,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) { // create sampler with replacement = true WeightedRandomSampler m_sampler(weights, num_samples, false); - DummyRandomAccessOp dummy_random_access_op(total_samples); - m_sampler.Init(&dummy_random_access_op); + DummyRandomAccessOp dummyRandomAccessOp(total_samples); + m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); std::unique_ptr db; TensorRow row; diff --git a/tests/ut/python/dataset/test_generator.py b/tests/ut/python/dataset/test_generator.py index 07556d9c7f..c224c5a2ea 100644 --- a/tests/ut/python/dataset/test_generator.py +++ b/tests/ut/python/dataset/test_generator.py @@ -439,6 +439,74 @@ def test_case_error_4(): assert "Unexpected error. 
Result of a tensorOp doesn't match output column names" in str(info.value) +def test_sequential_sampler(): + source = [(np.array([x]),) for x in range(64)] + ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler()) + i = 0 + for data in ds1.create_dict_iterator(): # each data is a dictionary + golden = np.array([i]) + assert np.array_equal(data["data"], golden) + i = i + 1 + + +def test_random_sampler(): + source = [(np.array([x]),) for x in range(64)] + ds1 = ds.GeneratorDataset(source, ["data"], shuffle = True) + for data in ds1.create_dict_iterator(): # each data is a dictionary + pass + + +def test_distributed_sampler(): + source = [(np.array([x]),) for x in range(64)] + for sid in range(8): + ds1 = ds.GeneratorDataset(source, ["data"], shuffle = False, num_shards=8, shard_id=sid) + i = sid + for data in ds1.create_dict_iterator(): # each data is a dictionary + golden = np.array([i]) + assert np.array_equal(data["data"], golden) + i = i + 8 + + +def test_num_samples(): + source = [(np.array([x]),) for x in range(64)] + num_samples = 32 + ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(), num_samples = num_samples) + ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(32)], num_samples = num_samples) + ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples = num_samples) + + count = 0 + for _ in ds1.create_dict_iterator(): + count = count + 1 + assert count == num_samples + + count = 0 + for _ in ds2.create_dict_iterator(): + count = count + 1 + assert count == num_samples + + count = 0 + for _ in ds3.create_dict_iterator(): + count = count + 1 + assert count == num_samples + + +def test_num_samples_underflow(): + source = [(np.array([x]),) for x in range(64)] + num_samples = 256 + ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(64)], num_samples = num_samples) + ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples = num_samples) + + count = 0 + for _ in 
ds2.create_dict_iterator(): + count = count + 1 + assert count == 64 + + count = 0 + for _ in ds3.create_dict_iterator(): + count = count + 1 + assert count == 64 + + if __name__ == "__main__": test_case_0() test_case_1() @@ -458,3 +526,6 @@ if __name__ == "__main__": test_case_error_2() test_case_error_3() test_case_error_4() + test_sequential_sampler() + test_distributed_sampler() + test_random_sampler() diff --git a/tests/ut/python/dataset/test_sampler.py b/tests/ut/python/dataset/test_sampler.py index ca618311cb..7a58249f9c 100644 --- a/tests/ut/python/dataset/test_sampler.py +++ b/tests/ut/python/dataset/test_sampler.py @@ -87,7 +87,28 @@ def test_random_sampler_multi_iter(print_res=False): test_config(replacement=True, num_samples=5, num_repeats=5, validate=[0, 1, 2, 3, 4, 5]) +def test_sampler_py_api(): + sampler = ds.SequentialSampler().create() + sampler.set_num_rows(128) + sampler.set_num_samples(64) + sampler.initialize() + sampler.get_indices() + + sampler = ds.RandomSampler().create() + sampler.set_num_rows(128) + sampler.set_num_samples(64) + sampler.initialize() + sampler.get_indices() + + sampler = ds.DistributedSampler(8, 4).create() + sampler.set_num_rows(128) + sampler.set_num_samples(64) + sampler.initialize() + sampler.get_indices() + + if __name__ == '__main__': test_sequential_sampler(True) test_random_sampler(True) test_random_sampler_multi_iter(True) + test_sampler_py_api()