diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/datasets_bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/datasets_bindings.cc
index b3288a9ee4..8a94ab4a5c 100644
--- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/datasets_bindings.cc
+++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/datasets_bindings.cc
@@ -184,21 +184,21 @@ PYBIND_REGISTER(CSVNode, 2, ([](const py::module *m) {
 PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) {
                   (void)py::class_<GeneratorNode, DatasetNode, std::shared_ptr<GeneratorNode>>(
                     *m, "GeneratorNode", "to create a GeneratorNode")
-                    .def(py::init([](py::function generator_function, const std::vector<std::string> &column_names,
-                                     const std::vector<DataType> &column_types) {
-                      auto gen = std::make_shared<GeneratorNode>(generator_function, column_names, column_types);
-                      THROW_IF_ERROR(gen->ValidateParams());
-                      return gen;
-                    }))
-                    .def(py::init([](py::function generator_function, const std::shared_ptr<SchemaObj> schema) {
-                      auto gen = std::make_shared<GeneratorNode>(generator_function, schema);
+                    .def(
+                      py::init([](py::function generator_function, const std::vector<std::string> &column_names,
+                                  const std::vector<DataType> &column_types, int64_t dataset_len, py::handle sampler) {
+                        auto gen = std::make_shared<GeneratorNode>(generator_function, column_names, column_types,
+                                                                   dataset_len, toSamplerObj(sampler));
+                        THROW_IF_ERROR(gen->ValidateParams());
+                        return gen;
+                      }))
+                    .def(py::init([](py::function generator_function, const std::shared_ptr<SchemaObj> schema,
+                                     int64_t dataset_len, py::handle sampler) {
+                      auto gen =
+                        std::make_shared<GeneratorNode>(generator_function, schema, dataset_len, toSamplerObj(sampler));
                       THROW_IF_ERROR(gen->ValidateParams());
                       return gen;
-                    }))
-                    .def("SetGeneratorDatasetSize", [](std::shared_ptr<GeneratorNode> self, int64_t sz) {
-                      self->SetGeneratorDatasetSize(sz);
-                      return self;
-                    });
+                    }));
                 }));

 PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) {
diff --git a/mindspore/ccsrc/minddata/dataset/api/python/pybind_conversion.cc b/mindspore/ccsrc/minddata/dataset/api/python/pybind_conversion.cc
index b15a0bc1f0..a979035cb3 100644
--- a/mindspore/ccsrc/minddata/dataset/api/python/pybind_conversion.cc
+++ b/mindspore/ccsrc/minddata/dataset/api/python/pybind_conversion.cc
@@ -143,6 +143,9 @@ std::vector<std::shared_ptr<DatasetNode>> toDatasetNode(std::shared_ptr<DatasetN
 std::shared_ptr<SamplerObj> toSamplerObj(py::handle py_sampler, bool isMindDataset) {
+  if (py_sampler.is_none()) {
+    return nullptr;
+  }
   if (py_sampler) {
     std::shared_ptr<SamplerObj> sampler_obj;
     if (!isMindDataset) {
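Taken together, these two binding changes mean GeneratorNode is now constructed with the source length and an optional sampler, and that passing None from Python cleanly maps to a null SamplerObj. A minimal sketch of the call shape the new binding expects, mirroring the parse() change in datasets.py at the end of this patch (`source`, `source_len`, and `sampler` are placeholders, not part of the patch):

```python
# Sketch only: `source` is any iterable or random-accessible object; `sampler`
# may be a built-in sampler or None (None becomes nullptr via toSamplerObj).
import mindspore._c_dataengine as cde

node = cde.GeneratorNode(source, ["col1"], [], source_len, sampler)
node = node.SetNumWorkers(2)
```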
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc
index e70213a03e..90b3a5939e 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc
@@ -43,25 +43,22 @@ Status GeneratorOp::Builder::SanityCheck() {
 Status GeneratorOp::Builder::Build(std::shared_ptr<GeneratorOp> *ptr) {
   RETURN_IF_NOT_OK(SanityCheck());
   *ptr = std::make_shared<GeneratorOp>(build_generator_function_, build_column_names_, build_column_types_,
-                                       build_prefetch_size_, build_buffer_size_, build_op_connector_size_);
+                                       build_prefetch_size_, build_buffer_size_, build_op_connector_size_, nullptr);
   return (*ptr)->Init();
 }

 GeneratorOp::GeneratorOp(py::function generator_function, std::vector<std::string> column_names,
                          std::vector<DataType> column_types, int32_t prefetch_size, int32_t buffer_size,
-                         int32_t connector_size, int64_t pre_counter_size)
-    : PipelineOp(connector_size),
+                         int32_t connector_size, std::shared_ptr<SamplerRT> sampler)
+    : PipelineOp(connector_size, std::move(sampler)),
       generator_function_(generator_function),
       column_names_(column_names),
       column_types_(column_types),
       prefetch_size_(prefetch_size),
       buffer_size_(buffer_size),
-      pre_counter_size_(pre_counter_size),
       buffer_id_(0),
       generator_counter_(0) {}

-GeneratorOp::~GeneratorOp() { this->Dealloc(); }
-
 void GeneratorOp::Print(std::ostream &out, bool show_all) const {
   if (!show_all) {
     // Call the super class for displaying any common 1-liner info
@@ -79,32 +76,32 @@ void GeneratorOp::Print(std::ostream &out, bool show_all) const {
     out << "\n\n";
   }
 }
-
-void GeneratorOp::Dealloc() noexcept {
-  // Setup GIL state
-  PyGILState_STATE gstate;
-  gstate = PyGILState_Ensure();
-  // GC the generator object within GIL
-  if (generator_function_.ref_count() == 1) generator_function_.dec_ref();
-  if (generator_.ref_count() == 1) (void)generator_.dec_ref();
-  // Release GIL
-  PyGILState_Release(gstate);
+// Handshake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
+Status GeneratorOp::InitSampler() {
+  if (sampler_ != nullptr) return sampler_->HandshakeRandomAccessOp(this);
+  return Status::OK();
 }

-// Reentrant init method.
-Status GeneratorOp::Init() {
-  // Reset BufferID
-  buffer_id_ = 0;
-  Status ret;
+// Invoke the generatorFunction to get generator object
+Status GeneratorOp::CreateGeneratorObject() {
+  Status ret = Status::OK();
   {
     // Acquire Python GIL
     py::gil_scoped_acquire gil_acquire;
     if (Py_IsInitialized() == 0) {
       return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
     }
-    // Invoke the generatorFunction to get generator object
     try {
-      generator_ = generator_function_();
+      py::array sample_ids;
+      if (sampler_ != nullptr) {
+        // Sampler is not null which means the source is RandomAccessible
+        // get all samples and pass it to the Generator function
+        RETURN_IF_NOT_OK(sampler_->GetAllIdsThenReset(&sample_ids));
+        // If sampler is a user-defined python sampler, sample_ids will flow from python to c++ and back to python
+        generator_ = generator_function_(sample_ids);
+      } else {
+        generator_ = generator_function_();
+      }
     } catch (const py::error_already_set &e) {
       ret = Status(StatusCode::kPyFuncException, e.what());
     }
@@ -112,6 +109,13 @@
   return ret;
 }

+// Reentrant init method.
+Status GeneratorOp::Init() {
+  buffer_id_ = 0;
+  RETURN_IF_NOT_OK(InitSampler());
+  return CreateGeneratorObject();
+}
+
 Status GeneratorOp::PyRowToTensorRow(py::object py_data, TensorRow *tensor_row) {
   if (!py::isinstance<py::tuple>(py_data)) {
     return Status(StatusCode::kPyFuncException, __LINE__, __FILE__,
@@ -191,6 +195,9 @@ Status GeneratorOp::operator()() {
   TaskManager::FindMe()->Post();
   RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks()));
   std::unique_ptr<DataBuffer> fetched_buffer;
+  RETURN_IF_NOT_OK(Init());
+
+  int64_t num_rows_sampled = sampler_ ? sampler_->CalculateNumSamples(num_rows_) : num_rows_;
   bool eof = false;
   while (!eof) {
     // Create new buffer each iteration
@@ -212,10 +219,10 @@ Status GeneratorOp::operator()() {
       if (!eoe) {
         return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, e.what());
       }
-      if (pre_counter_size_ != -1 && pre_counter_size_ != generator_counter_) {
+      if (num_rows_sampled != -1 && num_rows_sampled != generator_counter_) {
         std::stringstream ss;
         ss << "The actual amount of data read from generator " << generator_counter_
-           << " is different from generator.len " << pre_counter_size_
+           << " is different from generator.len " << num_rows_sampled
            << ", you should adjust generator.len to make them match.";
         return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, ss.str());
       }
@@ -259,7 +266,10 @@ Status GeneratorOp::operator()() {
 Status GeneratorOp::Reset() {
   // Reset Op state
   MS_LOG(DEBUG) << Name() << " performing a self-reset.";
-  RETURN_IF_NOT_OK(this->Init());
+  // Reset BufferID
+  buffer_id_ = 0;
+  // Create new generator object
+  RETURN_IF_NOT_OK(CreateGeneratorObject());
   if (this->op_total_repeats() < 0) {
     // Wake up master thread
     wp_.Set();
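With a sampler attached, CreateGeneratorObject() now invokes the Python generator function with the full array of sampled indices instead of calling it with no arguments. A hedged Python model of the contract the op now expects from the function it is given (this mirrors `_cpp_sampler_fn` in datasets.py further below; `dataset` and the helper name are illustrative):

```python
import numpy as np

def make_generator(dataset):
    """Wrap a random-accessible source into the sample_ids-taking form."""
    def gen(sample_ids):
        # sample_ids arrives as the np.ndarray built by GetAllIdsThenReset();
        # it must be a non-empty numpy array, per the checks added in datasets.py.
        for i in sample_ids:
            yield tuple(np.array(x, copy=False) for x in dataset[i])
    return gen
```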
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h
index a4d12cf9c2..1c9a7f9a95 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h
@@ -26,6 +26,7 @@
 #include "minddata/dataset/core/tensor.h"
 #include "minddata/dataset/engine/data_schema.h"
 #include "minddata/dataset/engine/datasetops/pipeline_op.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
 #include "minddata/dataset/util/wait_post.h"

 #include "pybind11/pybind11.h"
@@ -35,47 +36,47 @@ namespace mindspore {
 namespace dataset {
 #pragma GCC visibility push(hidden)

-class GeneratorOp : public PipelineOp {
+class GeneratorOp : public PipelineOp, public RandomAccessOp {
  public:
   class Builder {
    public:
-    // Builder constructor. Creates the builder object.
-    // @note No default args
-    // @return This is a constructor.
+    /// Builder constructor. Creates the builder object.
+    /// \note No default args
+    /// \return This is a constructor.
     Builder();

     ~Builder() = default;

-    // Setter method.
-    // @return Builder setter method returns reference to the builder.
+    /// Setter method.
+    /// \return Builder setter method returns reference to the builder.
     Builder &SetGeneratorFunction(py::function generator_function) {
       build_generator_function_ = generator_function;
       return *this;
     }

-    // Setter method.
-    // @return Builder setter method returns reference to the builder.
+    /// Setter method.
+    /// \return Builder setter method returns reference to the builder.
     Builder &SetColumnNames(const std::vector<std::string> &column_names) {
       build_column_names_ = column_names;
       return *this;
     }

-    // Setter method.
-    // @return Builder setter method returns reference to the builder.
+    /// Setter method.
+    /// \return Builder setter method returns reference to the builder.
     Builder &SetColumnTypes(const std::vector<DataType> &column_types) {
       build_column_types_ = column_types;
       return *this;
     }

-    // Setter method.
-    // @return Builder setter method returns reference to the builder.
+    /// Setter method.
+    /// \return Builder setter method returns reference to the builder.
     Builder &SetPrefetchSize(int32_t prefetch_size) {
       build_prefetch_size_ = prefetch_size;
       return *this;
     }

-    // The builder "build" method creates the final object.
-    // @return shared_ptr to the new GeneratorOp object
+    /// The builder "build" method creates the final object.
+    /// \return shared_ptr to the new GeneratorOp object
     Status Build(std::shared_ptr<GeneratorOp> *);

    private:
@@ -94,56 +95,53 @@ class GeneratorOp : public PipelineOp {
   GeneratorOp(py::function generator_function, std::vector<std::string> column_names,
               std::vector<DataType> column_types, int32_t prefetch_size, int32_t buffer_size, int32_t connector_size,
-              int64_t pre_counter_size = 0);
+              std::shared_ptr<SamplerRT> sampler);

-  ~GeneratorOp();
+  ~GeneratorOp() = default;

-  // A print method typically used for debugging
-  // @param out - The output stream to write output to
-  // @param show_all - A bool to control if you want to show all info or just a summary
+  /// A print method typically used for debugging
+  /// \param out - The output stream to write output to
+  /// \param show_all - A bool to control if you want to show all info or just a summary
   void Print(std::ostream &out, bool show_all) const override;

-  // << Stream output operator overload
-  // @notes This allows you to write the debug print info using stream operators
-  // @param out - reference to the output stream being overloaded
-  // @param generator_op - reference to the GeneratorOp to display
-  // @return - the output stream must be returned
+  /// << Stream output operator overload
+  /// \notes This allows you to write the debug print info using stream operators
+  /// \param out - reference to the output stream being overloaded
+  /// \param generator_op - reference to the GeneratorOp to display
+  /// \return - the output stream must be returned
   friend std::ostream &operator<<(std::ostream &out, const GeneratorOp &generator_op) {
     generator_op.Print(out, false);
     return out;
   }

-  // Class functor operator () override.
-  // All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will
-  // provide the master loop that drives the logic for performing the work.
-  // @return Status The status code returned
+  /// Class functor operator () override.
+  /// All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will
+  /// provide the master loop that drives the logic for performing the work.
+  /// \return Status The status code returned
   Status operator()() override;

-  // Overrides base class reset method. When an operator does a reset, it cleans up any state
-  // info from it's previous execution and then initializes itself so that it can be executed
-  // again.
-  // @return Status The status code returned
+  /// Overrides base class reset method. When an operator does a reset, it cleans up any state
+  /// info from its previous execution and then initializes itself so that it can be executed
+  /// again.
+  /// \return Status The status code returned
   Status Reset() override;

-  // Base-class override for NodePass visitor acceptor.
-  // @param p - Pointer to the NodePass to be accepted.
-  // @param modified - Whether this node visit modified the pipeline.
-  // @return - Status of the node visit.
+  /// Base-class override for NodePass visitor acceptor.
+  /// \param p - Pointer to the NodePass to be accepted.
+  /// \param modified - Whether this node visit modified the pipeline.
+  /// \return - Status of the node visit.
   Status Accept(NodePass *p, bool *const modified) override;

-  // Op name getter
-  // @return Name of the current Op
+  /// Op name getter
+  /// \return Name of the current Op
   std::string Name() const override { return "GeneratorOp"; }

-  Status Init();
-
  private:
   py::function generator_function_;
   std::vector<std::string> column_names_;
   std::vector<DataType> column_types_;
   int32_t prefetch_size_;
   int32_t buffer_size_;
-  int64_t pre_counter_size_;
   int64_t generator_counter_;

   py::object generator_;
@@ -151,15 +149,25 @@ class GeneratorOp : public PipelineOp {
   WaitPost wp_;

-  void Dealloc() noexcept;
-
   Status PyRowToTensorRow(py::object py_data, TensorRow *tensor_row);

   Status FillBuffer(TensorQTable *tt);

-  // Private function for computing the assignment of the column name map.
-  // @return - Status
+  /// Private function for computing the assignment of the column name map.
+  /// \return - Status
   Status ComputeColMap() override;
+
+  /// Initialize Sampler, calls sampler->Init() within
+  /// \return Status The status code returned
+  Status InitSampler();
+
+  /// Create new Generator object from the generator function
+  /// \return Status The status code returned
+  Status CreateGeneratorObject();
+
+  /// Initialize GeneratorOp
+  /// \return Status The status code returned
+  Status Init();
 };
 #pragma GCC visibility pop
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
index 59b7c7c94b..c8f478c818 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
@@ -177,8 +177,21 @@ int64_t DistributedSamplerRT::CalculateNumSamples(int64_t num_rows) {
     child_num_rows = child_[0]->CalculateNumSamples(num_rows);
   }
   int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
-  int64_t num_per_shard = std::ceil(child_num_rows * 1.0 / num_devices_);
-  return std::min(num_samples, num_per_shard);
+  int64_t remainder = (child_num_rows + offset_) % num_devices_;
+  int64_t shard_size = (child_num_rows + offset_) / num_devices_;
+  if (offset_ != -1 || !even_dist_) {
+    if (offset_ == -1) offset_ = 0;
+    if (device_id_ < remainder) shard_size++;
+    if (device_id_ < offset_) shard_size--;
+  } else {
+    offset_ = 0;
+    shard_size = (child_num_rows + num_devices_ - 1) / num_devices_;
+  }
+  // add 1 to an empty shard
+  // this logic is needed to follow the logic in initSampler that is written for ConcatDataset
+  if (shard_size == 0) shard_size++;
+
+  return std::min(num_samples, shard_size);
 }

 void DistributedSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h
index 4a9eee157e..1f43fc9df4 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h
@@ -48,6 +48,10 @@ class RandomAccessOp {
   // default destructor
   virtual ~RandomAccessOp() = default;

+  /// Set num_rows
+  /// \param num_rows
+  void SetNumRows(int64_t num_rows) { num_rows_ = num_rows; }
+
  protected:
   // The amount of rows in the dataset itself. This is the before-sampling value, the
   // total count of rows. A sampler may choose to sample less than this amount.
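The reworked CalculateNumSamples() above is the heart of the sizing change. A small Python model of that arithmetic, handy for sanity-checking the shard sizes it produces (an illustration of the C++ logic, not a library API):

```python
def shard_size(rows, devices, device_id, offset=-1, even_dist=True):
    """Python model of DistributedSamplerRT::CalculateNumSamples above."""
    remainder = (rows + offset) % devices
    size = (rows + offset) // devices
    if offset != -1 or not even_dist:
        if offset == -1:
            offset = 0
        if device_id < remainder:
            size += 1
        if device_id < offset:
            size -= 1
    else:
        size = (rows + devices - 1) // devices  # plain ceiling division
    if size == 0:
        size += 1  # pad an empty shard to one row (ConcatDataset convention)
    return size

# 23 rows over 10 devices: ceil(23 / 10) = 3 everywhere in the even case,
# and {3,3,3,2,...} once an explicit offset is in play.
assert all(shard_size(23, 10, d) == 3 for d in range(10))
assert [shard_size(23, 10, d, offset=0) for d in range(10)] == [3, 3, 3] + [2] * 7
```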
diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc
index 0209fd1ec7..318a496791 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc
@@ -95,19 +95,19 @@ Status ConcatNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size

   // calculate the size of the shard
   int64_t shard_dataset_size = 0;
-  if (sampler_ != nullptr) {
-    std::shared_ptr<DistributedSamplerRT> sampler_rt =
-      std::static_pointer_cast<DistributedSamplerRT>(sampler_->SamplerBuild());
+  std::shared_ptr<DistributedSamplerRT> sampler_rt =
+    sampler_ ? std::dynamic_pointer_cast<DistributedSamplerRT>(sampler_->SamplerBuild()) : nullptr;
+  if (sampler_rt != nullptr) {
     sampler_rt->SetNumRowsInDataset(total_dataset_size);
     sampler_rt->InitSampler();

     // (total_size % num_shards != 0) & shard_id >= (remainder) ? CalculateNumSamples()-1 : CalculateNumSamples()
     // example: 23 rows, 10 shards --> shard sizes = {3,3,3,2,2,2,2,2,2,2}
-    if ((sampler_rt->GetNumSamples() % sampler_rt->GetDeviceNum()) > 0 &&
+    if ((sampler_rt->GetNumSamples() % sampler_rt->GetDeviceNum()) >= 0 &&
         sampler_rt->GetDeviceID() >= (sampler_rt->GetNumSamples() % sampler_rt->GetDeviceNum())) {
-      shard_dataset_size = sampler_rt->CalculateNumSamples(sampler_rt->GetNumSamples()) - 1;
+      shard_dataset_size = sampler_rt->GetNumSamples() / sampler_rt->GetDeviceNum();
     } else {
-      shard_dataset_size = sampler_rt->CalculateNumSamples(sampler_rt->GetNumSamples());
+      shard_dataset_size = sampler_rt->GetNumSamples() / sampler_rt->GetDeviceNum() + 1;
    }
   } else {
     shard_dataset_size = total_dataset_size;
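The rewritten branch splits GetNumSamples() rows evenly across GetDeviceNum() shards, handing the low shard ids one extra row each until the remainder runs out. (Note that `% ... >= 0` always holds for non-negative operands, so the first clause is effectively a no-op and the shard-id comparison decides the branch.) Checking the {3,3,3,2,...} example from the comment, in plain arithmetic rather than library code:

```python
num_samples, num_devices = 23, 10
sizes = [num_samples // num_devices if d >= num_samples % num_devices
         else num_samples // num_devices + 1
         for d in range(num_devices)]
assert sizes == [3, 3, 3] + [2] * 7 and sum(sizes) == 23
```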
diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc
index 49635e7345..09b9830f99 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc
@@ -18,6 +18,7 @@

 #include <memory>
 #include <string>
+#include <utility>
 #include <vector>

 #include "minddata/dataset/engine/datasetops/repeat_op.h"
@@ -29,24 +30,27 @@ namespace mindspore {
 namespace dataset {

 GeneratorNode::GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names,
-                             const std::vector<DataType> &column_types)
+                             const std::vector<DataType> &column_types, int64_t source_len,
+                             std::shared_ptr<SamplerObj> sampler)
     : MappableSourceNode(),
       generator_function_(generator_function),
       column_names_(column_names),
       column_types_(column_types),
-      reset_ancestor_(nullptr) {}
+      reset_ancestor_(nullptr),
+      sampler_(std::move(sampler)),
+      source_len_(source_len) {}

-GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema)
+GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema,
+                             int64_t source_len, std::shared_ptr<SamplerObj> sampler)
     : MappableSourceNode(), generator_function_(generator_function), schema_(schema), reset_ancestor_(nullptr) {}

 std::shared_ptr<DatasetNode> GeneratorNode::Copy() {
   std::shared_ptr<GeneratorNode> node;
   if (schema_ == nullptr) {
-    node = std::make_shared<GeneratorNode>(generator_function_, column_names_, column_types_);
+    node = std::make_shared<GeneratorNode>(generator_function_, column_names_, column_types_, source_len_, sampler_);
   } else {
-    node = std::make_shared<GeneratorNode>(generator_function_, schema_);
+    node = std::make_shared<GeneratorNode>(generator_function_, schema_, source_len_, sampler_);
   }
-  node->SetGeneratorDatasetSize(dataset_size_);
   return node;
 }
@@ -69,17 +73,14 @@ Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_
       column_types_.push_back((col.type()));
     }
   }
+  std::shared_ptr<SamplerRT> sampler_rt = sampler_ ? sampler_->SamplerBuild() : nullptr;

   // GeneratorOp's constructor takes in a prefetch_size, which isn't being set by user nor is it being used by
   // GeneratorOp internally. Here it is given a zero which is the default in generator builder
   std::shared_ptr<GeneratorOp> op = std::make_shared<GeneratorOp>(generator_function_, column_names_, column_types_, 0,
-                                                                  rows_per_buffer_, connector_que_size_, dataset_size_);
-
-  // Init() is called in builder when generator is built. Here, since we are getting away from the builder class, init
-  // needs to be called when the op is built. The caveat is that Init needs to be made public (before it is private).
-  // This method can be privatized once we move Init() to Generator's functor. However, that is a bigger change which
-  // best be delivered when the test cases for this api is ready.
-  RETURN_IF_NOT_OK(op->Init());
+                                                                  rows_per_buffer_, connector_que_size_, sampler_rt);
+  // set the number of rows from source length
+  op->SetNumRows(source_len_);

   // Add this GeneratorOp to its RepeatOp/EpochCtrlOp ancestor's EOE list.
   // When the ancestor reaches an end-of-epoch boundary, it will send a "reset" signal to all the ops in the EOE list.
@@ -118,5 +119,26 @@ Status GeneratorNode::AcceptAfter(IRNodePass *p, bool *const modified) {
   // Downcast shared pointer then call visitor
   return p->VisitAfter(shared_from_base<GeneratorNode>(), modified);
 }
+
+Status GeneratorNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
+                                     int64_t *dataset_size) {
+  if (dataset_size_ > 0) {
+    *dataset_size = dataset_size_;
+    return Status::OK();
+  }
+  if (source_len_ == -1) {
+    RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), dataset_size));
+    dataset_size_ = *dataset_size;
+    return Status::OK();
+  } else {
+    int64_t sample_size;
+    int64_t num_rows;
+    num_rows = source_len_;
+    sample_size = sampler_ ? sampler_->SamplerBuild()->CalculateNumSamples(num_rows) : num_rows;
+    *dataset_size = sample_size;
+    dataset_size_ = *dataset_size;
+    return Status::OK();
+  }
+}
 }  // namespace dataset
 }  // namespace mindspore
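The new GetDatasetSize() is the payoff of threading source_len through: when the user's source reports a length, the dataset size is pure sampler arithmetic, and the expensive dry run is reserved for length-less sources. Schematically (a sketch of the control flow above, not the C++ API):

```python
def generator_dataset_size(source_len, calculate_num_samples, dry_run):
    # source_len == -1 means the user's source had no __len__
    if source_len == -1:
        return dry_run()                      # one full pass through the pipeline
    return calculate_num_samples(source_len)  # no dry run needed
```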
diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h
index 5321419cbf..7578c7302f 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h
@@ -34,10 +34,11 @@ class GeneratorNode : public MappableSourceNode {
  public:
   /// \brief Constructor
   GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names,
-                const std::vector<DataType> &column_types);
+                const std::vector<DataType> &column_types, int64_t source_len, std::shared_ptr<SamplerObj> sampler);

   /// \brief Constructor
-  GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema);
+  GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema, int64_t source_len,
+                std::shared_ptr<SamplerObj> sampler);

   /// \brief Destructor
   ~GeneratorNode() = default;
@@ -67,11 +68,6 @@ class GeneratorNode : public MappableSourceNode {
   /// \return Status Status::OK() if get shard id successfully
   Status GetShardId(int32_t *shard_id) override;

-  /// \brief Setter for DatasetSize in GeneratorNode
-  /// \param[in] sz dataset size to set
-  /// \return void
-  void SetGeneratorDatasetSize(int64_t sz) { dataset_size_ = sz; }
-
   bool IsSizeDefined() override { return false; }

   /// \brief Record the vector of Repeat/EpochCtrl nodes that are ancestors of this node
@@ -82,6 +78,14 @@ class GeneratorNode : public MappableSourceNode {
     reset_ancestor_ = src;
     return Status::OK();
   }

+  /// Returns the dataset size of GeneratorOp. If it is mappable (sampler is not null), the sampler is used.
+  /// Otherwise, a dry run is needed.
+  /// \param[in] size_getter TreeConsumer to be used for a dry run
+  /// \param[in] estimate
+  /// \param[out] dataset_size
+  /// \return Status of the function
+  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
+                        int64_t *dataset_size) override;

   /// \brief Getter functions
   const py::function &GeneratorFunction() const { return generator_function_; }
@@ -102,6 +106,8 @@ class GeneratorNode : public MappableSourceNode {
   std::vector<DataType> column_types_;
   std::shared_ptr<SchemaObj> schema_;
   std::shared_ptr<RepeatNode> reset_ancestor_;  // updated its immediate Repeat/EpochCtrl ancestor in GeneratorNodePass
+  std::shared_ptr<SamplerObj> sampler_;
+  int64_t source_len_;  // Length of the dataset source provided by the user, -1 means it's unknown

   /// \brief Base-class override for accepting IRNodePass visitor
   /// \param[in] p The node to visit
""" - indices = sampler.get_indices() - for i in indices: + if not isinstance(sample_ids, np.ndarray): + raise RuntimeError("Sample IDs are not in a numpy array.") + if sample_ids.size == 0: + raise RuntimeError("Sampler passed an empty sample IDs list.") + + for i in sample_ids: val = dataset[i] # convert output tensors to ndarrays yield tuple([np.array(x, copy=False) for x in val]) -def _cpp_sampler_fn_mp(sampler, sample_fn): +def _cpp_sampler_fn_mp(sample_ids, sample_fn): """ Multiprocessing generator function wrapper for mappable dataset with cpp sampler. """ - indices = sampler.get_indices() - return sample_fn.process(indices) + if not isinstance(sample_ids, np.ndarray): + raise RuntimeError("Sample IDs are not in a numpy array.") + if sample_ids.size == 0: + raise RuntimeError("Sampler passed an empty sample IDs list.") + + return sample_fn.process(sample_ids) def _py_sampler_fn_mp(sampler, num_samples, sample_fn): @@ -3811,18 +3819,9 @@ class GeneratorDataset(MappableDataset): self.schema = Schema(schema) # Move get dataset_size by len from parse to here, because self.source will # lose attribution of '__len__' after deepcopy. - self.dataset_size = None + self.source_len = -1 # unknown if hasattr(self.source, "__len__"): - if not self.num_shards: - self.dataset_size = len(self.source) - else: - self.dataset_size = math.ceil(len(self.source) / self.num_shards) - - rows_from_sampler = self._get_sampler_dataset_size() - if self.num_samples is not None and self.num_samples < rows_from_sampler: - rows_from_sampler = self.num_samples - if rows_from_sampler is not None and rows_from_sampler < self.dataset_size: - self.dataset_size = rows_from_sampler + self.source_len = len(self.source) def __deepcopy__(self, memodict): if id(self) in memodict: @@ -3839,6 +3838,7 @@ class GeneratorDataset(MappableDataset): new_op.num_samples = copy.deepcopy(self.num_samples, memodict) new_op.sampler = copy.deepcopy(self.sampler) new_op.dataset_size = self.dataset_size + new_op.source_len = self.source_len new_op.saved_output_types = self.saved_output_types new_op.saved_output_shapes = self.saved_output_shapes if hasattr(self, "__total_batch__"): @@ -3847,22 +3847,28 @@ class GeneratorDataset(MappableDataset): if isinstance(new_op.sampler, (samplers.SequentialSampler, samplers.DistributedSampler, samplers.RandomSampler, samplers.SubsetRandomSampler, samplers.WeightedRandomSampler, samplers.Sampler)): - sampler_instance = new_op.sampler.create() - sampler_instance.set_num_rows(len(self.source)) - sampler_instance.initialize() if new_op.num_parallel_workers > 1: sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing) - new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, sample_fn)) + new_op.source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn)) else: - new_op.source = (lambda: _cpp_sampler_fn(sampler_instance, self.source)) + new_op.source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source)) else: + # the sampler provided is not a built-in sampler, it is a list of sample_ids + new_op.sample_ids = new_op.sampler + # since list of sample_ids are not passed to c++, we need to find the proper len here + new_op.source_len = min(self.source_len, len(new_op.sample_ids)) if self.source_len != -1 else len( + new_op.sample_ids) + new_op.source_len = min(self.source_len, + new_op.num_samples) if new_op.num_samples is not None else new_op.source_len + new_op.sampler = None if new_op.num_parallel_workers > 1: sample_fn = SamplerFn(self.source, 
-                    new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, sample_fn))
+                    new_op.source = (lambda: _py_sampler_fn_mp(new_op.sample_ids, new_op.num_samples, sample_fn))
                 else:
-                    new_op.source = (lambda: _py_sampler_fn(new_op.sampler, new_op.num_samples, self.source))
+                    new_op.source = (lambda: _py_sampler_fn(new_op.sample_ids, new_op.num_samples, self.source))
         else:
             try:
+                new_op.sampler = None
                 iter(self.source)
             except TypeError:
                 # Use generator function if input callable
@@ -3881,16 +3887,13 @@ class GeneratorDataset(MappableDataset):
             return self.sampler.is_sharded()

     def parse(self, children=None):
-        if self.dataset_size is None:
-            self.dataset_size = -1
         if self.schema is None:
-            return cde.GeneratorNode(self.source, self.column_names, self.column_types).SetGeneratorDatasetSize(
-                self.dataset_size) \
-                .SetNumWorkers(self.num_parallel_workers)
+            return cde.GeneratorNode(self.source, self.column_names, self.column_types,
+                                     self.source_len, self.sampler).SetNumWorkers(self.num_parallel_workers)
         schema = self.schema
         if isinstance(schema, Schema):
             schema = self.schema.cpp_schema
-        return cde.GeneratorNode(self.source, schema).SetGeneratorDatasetSize(self.dataset_size).SetNumWorkers(
+        return cde.GeneratorNode(self.source, schema, self.source_len, self.sampler).SetNumWorkers(
             self.num_parallel_workers)
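For context, an end-to-end sketch of the behavior this patch enables at the Python API level (a toy source; treat the printed value as the expected outcome under this patch rather than a test taken from it):

```python
import numpy as np
import mindspore.dataset as ds

class MySource:
    """23-row random-accessible source with __len__, so no dry run is needed."""
    def __init__(self):
        self._data = [(np.array(i),) for i in range(23)]
    def __getitem__(self, index):
        return self._data[index]
    def __len__(self):
        return len(self._data)

data = ds.GeneratorDataset(MySource(), column_names=["col"],
                           num_shards=10, shard_id=0, shuffle=False)
# The size now comes from sampler arithmetic: ceil(23 / 10) = 3 for this shard.
print(data.get_dataset_size())
```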