| @@ -184,21 +184,21 @@ PYBIND_REGISTER(CSVNode, 2, ([](const py::module *m) { | |||||
| PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) { | PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) { | ||||
| (void)py::class_<GeneratorNode, DatasetNode, std::shared_ptr<GeneratorNode>>( | (void)py::class_<GeneratorNode, DatasetNode, std::shared_ptr<GeneratorNode>>( | ||||
| *m, "GeneratorNode", "to create a GeneratorNode") | *m, "GeneratorNode", "to create a GeneratorNode") | ||||
| .def(py::init([](py::function generator_function, const std::vector<std::string> &column_names, | |||||
| const std::vector<DataType> &column_types) { | |||||
| auto gen = std::make_shared<GeneratorNode>(generator_function, column_names, column_types); | |||||
| THROW_IF_ERROR(gen->ValidateParams()); | |||||
| return gen; | |||||
| })) | |||||
| .def(py::init([](py::function generator_function, const std::shared_ptr<SchemaObj> schema) { | |||||
| auto gen = std::make_shared<GeneratorNode>(generator_function, schema); | |||||
| .def( | |||||
| py::init([](py::function generator_function, const std::vector<std::string> &column_names, | |||||
| const std::vector<DataType> &column_types, int64_t dataset_len, py::handle sampler) { | |||||
| auto gen = std::make_shared<GeneratorNode>(generator_function, column_names, column_types, | |||||
| dataset_len, toSamplerObj(sampler)); | |||||
| THROW_IF_ERROR(gen->ValidateParams()); | |||||
| return gen; | |||||
| })) | |||||
| .def(py::init([](py::function generator_function, const std::shared_ptr<SchemaObj> schema, | |||||
| int64_t dataset_len, py::handle sampler) { | |||||
| auto gen = | |||||
| std::make_shared<GeneratorNode>(generator_function, schema, dataset_len, toSamplerObj(sampler)); | |||||
| THROW_IF_ERROR(gen->ValidateParams()); | THROW_IF_ERROR(gen->ValidateParams()); | ||||
| return gen; | return gen; | ||||
| })) | |||||
| .def("SetGeneratorDatasetSize", [](std::shared_ptr<GeneratorNode> self, int64_t sz) { | |||||
| self->SetGeneratorDatasetSize(sz); | |||||
| return self; | |||||
| }); | |||||
| })); | |||||
| })); | })); | ||||
| PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) { | PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) { | ||||
| @@ -143,6 +143,9 @@ std::vector<std::shared_ptr<DatasetNode>> toDatasetNode(std::shared_ptr<DatasetN | |||||
| } | } | ||||
| std::shared_ptr<SamplerObj> toSamplerObj(py::handle py_sampler, bool isMindDataset) { | std::shared_ptr<SamplerObj> toSamplerObj(py::handle py_sampler, bool isMindDataset) { | ||||
| if (py_sampler.is_none()) { | |||||
| return nullptr; | |||||
| } | |||||
| if (py_sampler) { | if (py_sampler) { | ||||
| std::shared_ptr<SamplerObj> sampler_obj; | std::shared_ptr<SamplerObj> sampler_obj; | ||||
| if (!isMindDataset) { | if (!isMindDataset) { | ||||
| @@ -43,25 +43,22 @@ Status GeneratorOp::Builder::SanityCheck() { | |||||
| Status GeneratorOp::Builder::Build(std::shared_ptr<GeneratorOp> *ptr) { | Status GeneratorOp::Builder::Build(std::shared_ptr<GeneratorOp> *ptr) { | ||||
| RETURN_IF_NOT_OK(SanityCheck()); | RETURN_IF_NOT_OK(SanityCheck()); | ||||
| *ptr = std::make_shared<GeneratorOp>(build_generator_function_, build_column_names_, build_column_types_, | *ptr = std::make_shared<GeneratorOp>(build_generator_function_, build_column_names_, build_column_types_, | ||||
| build_prefetch_size_, build_buffer_size_, build_op_connector_size_); | |||||
| build_prefetch_size_, build_buffer_size_, build_op_connector_size_, nullptr); | |||||
| return (*ptr)->Init(); | return (*ptr)->Init(); | ||||
| } | } | ||||
| GeneratorOp::GeneratorOp(py::function generator_function, std::vector<std::string> column_names, | GeneratorOp::GeneratorOp(py::function generator_function, std::vector<std::string> column_names, | ||||
| std::vector<DataType> column_types, int32_t prefetch_size, int32_t buffer_size, | std::vector<DataType> column_types, int32_t prefetch_size, int32_t buffer_size, | ||||
| int32_t connector_size, int64_t pre_counter_size) | |||||
| : PipelineOp(connector_size), | |||||
| int32_t connector_size, std::shared_ptr<SamplerRT> sampler) | |||||
| : PipelineOp(connector_size, std::move(sampler)), | |||||
| generator_function_(generator_function), | generator_function_(generator_function), | ||||
| column_names_(column_names), | column_names_(column_names), | ||||
| column_types_(column_types), | column_types_(column_types), | ||||
| prefetch_size_(prefetch_size), | prefetch_size_(prefetch_size), | ||||
| buffer_size_(buffer_size), | buffer_size_(buffer_size), | ||||
| pre_counter_size_(pre_counter_size), | |||||
| buffer_id_(0), | buffer_id_(0), | ||||
| generator_counter_(0) {} | generator_counter_(0) {} | ||||
| GeneratorOp::~GeneratorOp() { this->Dealloc(); } | |||||
| void GeneratorOp::Print(std::ostream &out, bool show_all) const { | void GeneratorOp::Print(std::ostream &out, bool show_all) const { | ||||
| if (!show_all) { | if (!show_all) { | ||||
| // Call the super class for displaying any common 1-liner info | // Call the super class for displaying any common 1-liner info | ||||
| @@ -79,32 +76,32 @@ void GeneratorOp::Print(std::ostream &out, bool show_all) const { | |||||
| out << "\n\n"; | out << "\n\n"; | ||||
| } | } | ||||
| } | } | ||||
| void GeneratorOp::Dealloc() noexcept { | |||||
| // Setup GIL state | |||||
| PyGILState_STATE gstate; | |||||
| gstate = PyGILState_Ensure(); | |||||
| // GC the generator object within GIL | |||||
| if (generator_function_.ref_count() == 1) generator_function_.dec_ref(); | |||||
| if (generator_.ref_count() == 1) (void)generator_.dec_ref(); | |||||
| // Release GIL | |||||
| PyGILState_Release(gstate); | |||||
| // Handshake with the sampler so that the sampler can call RandomAccessOp's functions to get NumRows | |||||
| Status GeneratorOp::InitSampler() { | |||||
| if (sampler_ != nullptr) return sampler_->HandshakeRandomAccessOp(this); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| // Reentrant init method. | |||||
| Status GeneratorOp::Init() { | |||||
| // Reset BufferID | |||||
| buffer_id_ = 0; | |||||
| Status ret; | |||||
| // Invoke the generatorFunction to get generator object | |||||
| Status GeneratorOp::CreateGeneratorObject() { | |||||
| Status ret = Status::OK(); | |||||
| { | { | ||||
| // Acquire Python GIL | // Acquire Python GIL | ||||
| py::gil_scoped_acquire gil_acquire; | py::gil_scoped_acquire gil_acquire; | ||||
| if (Py_IsInitialized() == 0) { | if (Py_IsInitialized() == 0) { | ||||
| return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); | return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); | ||||
| } | } | ||||
| // Invoke the generatorFunction to get generator object | |||||
| try { | try { | ||||
| generator_ = generator_function_(); | |||||
| py::array sample_ids; | |||||
| if (sampler_ != nullptr) { | |||||
| // Sampler is not null which means the source is RandomAccessible | |||||
| // get all samples and pass it to the Generator function | |||||
| RETURN_IF_NOT_OK(sampler_->GetAllIdsThenReset(&sample_ids)); | |||||
| // If sampler is a user-defined python sampler, sample_ids will flow from python to c++ and back to python | |||||
| generator_ = generator_function_(sample_ids); | |||||
| } else { | |||||
| generator_ = generator_function_(); | |||||
| } | |||||
| } catch (const py::error_already_set &e) { | } catch (const py::error_already_set &e) { | ||||
| ret = Status(StatusCode::kPyFuncException, e.what()); | ret = Status(StatusCode::kPyFuncException, e.what()); | ||||
| } | } | ||||
| @@ -112,6 +109,13 @@ Status GeneratorOp::Init() { | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| // Reentrant init method. | |||||
| Status GeneratorOp::Init() { | |||||
| buffer_id_ = 0; | |||||
| RETURN_IF_NOT_OK(InitSampler()); | |||||
| return CreateGeneratorObject(); | |||||
| } | |||||
| Status GeneratorOp::PyRowToTensorRow(py::object py_data, TensorRow *tensor_row) { | Status GeneratorOp::PyRowToTensorRow(py::object py_data, TensorRow *tensor_row) { | ||||
| if (!py::isinstance<py::tuple>(py_data)) { | if (!py::isinstance<py::tuple>(py_data)) { | ||||
| return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, | return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, | ||||
| @@ -191,6 +195,9 @@ Status GeneratorOp::operator()() { | |||||
| TaskManager::FindMe()->Post(); | TaskManager::FindMe()->Post(); | ||||
| RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); | RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); | ||||
| std::unique_ptr<DataBuffer> fetched_buffer; | std::unique_ptr<DataBuffer> fetched_buffer; | ||||
| RETURN_IF_NOT_OK(Init()); | |||||
| int64_t num_rows_sampled = sampler_ ? sampler_->CalculateNumSamples(num_rows_) : num_rows_; | |||||
| bool eof = false; | bool eof = false; | ||||
| while (!eof) { | while (!eof) { | ||||
| // Create new buffer each iteration | // Create new buffer each iteration | ||||
| @@ -212,10 +219,10 @@ Status GeneratorOp::operator()() { | |||||
| if (!eoe) { | if (!eoe) { | ||||
| return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, e.what()); | return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, e.what()); | ||||
| } | } | ||||
| if (pre_counter_size_ != -1 && pre_counter_size_ != generator_counter_) { | |||||
| if (num_rows_sampled != -1 && num_rows_sampled != generator_counter_) { | |||||
| std::stringstream ss; | std::stringstream ss; | ||||
| ss << "The actual amount of data read from generator " << generator_counter_ | ss << "The actual amount of data read from generator " << generator_counter_ | ||||
| << " is different from generator.len " << pre_counter_size_ | |||||
| << " is different from generator.len " << num_rows_sampled | |||||
| << ", you should adjust generator.len to make them match."; | << ", you should adjust generator.len to make them match."; | ||||
| return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, ss.str()); | return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, ss.str()); | ||||
| } | } | ||||
| @@ -259,7 +266,10 @@ Status GeneratorOp::operator()() { | |||||
| Status GeneratorOp::Reset() { | Status GeneratorOp::Reset() { | ||||
| // Reset Op state | // Reset Op state | ||||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | MS_LOG(DEBUG) << Name() << " performing a self-reset."; | ||||
| RETURN_IF_NOT_OK(this->Init()); | |||||
| // Reset BufferID | |||||
| buffer_id_ = 0; | |||||
| // Create new generator object | |||||
| RETURN_IF_NOT_OK(CreateGeneratorObject()); | |||||
| if (this->op_total_repeats() < 0) { | if (this->op_total_repeats() < 0) { | ||||
| // Wake up master thread | // Wake up master thread | ||||
| wp_.Set(); | wp_.Set(); | ||||
| @@ -26,6 +26,7 @@ | |||||
| #include "minddata/dataset/core/tensor.h" | #include "minddata/dataset/core/tensor.h" | ||||
| #include "minddata/dataset/engine/data_schema.h" | #include "minddata/dataset/engine/data_schema.h" | ||||
| #include "minddata/dataset/engine/datasetops/pipeline_op.h" | #include "minddata/dataset/engine/datasetops/pipeline_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||||
| #include "minddata/dataset/util/wait_post.h" | #include "minddata/dataset/util/wait_post.h" | ||||
| #include "pybind11/pybind11.h" | #include "pybind11/pybind11.h" | ||||
| @@ -35,47 +36,47 @@ namespace mindspore { | |||||
| namespace dataset { | namespace dataset { | ||||
| #pragma GCC visibility push(hidden) | #pragma GCC visibility push(hidden) | ||||
| class GeneratorOp : public PipelineOp { | |||||
| class GeneratorOp : public PipelineOp, public RandomAccessOp { | |||||
| public: | public: | ||||
| class Builder { | class Builder { | ||||
| public: | public: | ||||
| // Builder constructor. Creates the builder object. | |||||
| // @note No default args | |||||
| // @return This is a constructor. | |||||
| /// Builder constructor. Creates the builder object. | |||||
| /// \note No default args | |||||
| /// \return This is a constructor. | |||||
| Builder(); | Builder(); | ||||
| ~Builder() = default; | ~Builder() = default; | ||||
| // Setter method. | |||||
| // @return Builder setter method returns reference to the builder. | |||||
| /// Setter method. | |||||
| /// \return Builder setter method returns reference to the builder. | |||||
| Builder &SetGeneratorFunction(py::function generator_function) { | Builder &SetGeneratorFunction(py::function generator_function) { | ||||
| build_generator_function_ = generator_function; | build_generator_function_ = generator_function; | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| // Setter method. | |||||
| // @return Builder setter method returns reference to the builder. | |||||
| /// Setter method. | |||||
| /// \return Builder setter method returns reference to the builder. | |||||
| Builder &SetColumnNames(const std::vector<std::string> &column_names) { | Builder &SetColumnNames(const std::vector<std::string> &column_names) { | ||||
| build_column_names_ = column_names; | build_column_names_ = column_names; | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| // Setter method. | |||||
| // @return Builder setter method returns reference to the builder. | |||||
| /// Setter method. | |||||
| /// \return Builder setter method returns reference to the builder. | |||||
| Builder &SetColumnTypes(const std::vector<DataType> &column_types) { | Builder &SetColumnTypes(const std::vector<DataType> &column_types) { | ||||
| build_column_types_ = column_types; | build_column_types_ = column_types; | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| // Setter method. | |||||
| // @return Builder setter method returns reference to the builder. | |||||
| /// Setter method. | |||||
| /// \return Builder setter method returns reference to the builder. | |||||
| Builder &SetPrefetchSize(int32_t prefetch_size) { | Builder &SetPrefetchSize(int32_t prefetch_size) { | ||||
| build_prefetch_size_ = prefetch_size; | build_prefetch_size_ = prefetch_size; | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| // The builder "build" method creates the final object. | |||||
| // @return shared_ptr to the new GeneratorOp object | |||||
| /// The builder "build" method creates the final object. | |||||
| /// \return shared_ptr to the new GeneratorOp object | |||||
| Status Build(std::shared_ptr<GeneratorOp> *); | Status Build(std::shared_ptr<GeneratorOp> *); | ||||
| private: | private: | ||||
| @@ -94,56 +95,53 @@ class GeneratorOp : public PipelineOp { | |||||
| GeneratorOp(py::function generator_function, std::vector<std::string> column_names, | GeneratorOp(py::function generator_function, std::vector<std::string> column_names, | ||||
| std::vector<DataType> column_types, int32_t prefetch_size, int32_t buffer_size, int32_t connector_size, | std::vector<DataType> column_types, int32_t prefetch_size, int32_t buffer_size, int32_t connector_size, | ||||
| int64_t pre_counter_size = 0); | |||||
| std::shared_ptr<SamplerRT> sampler); | |||||
| ~GeneratorOp(); | |||||
| ~GeneratorOp() = default; | |||||
| // A print method typically used for debugging | |||||
| // @param out - The output stream to write output to | |||||
| // @param show_all - A bool to control if you want to show all info or just a summary | |||||
| /// A print method typically used for debugging | |||||
| /// \param out - The output stream to write output to | |||||
| /// \param show_all - A bool to control if you want to show all info or just a summary | |||||
| void Print(std::ostream &out, bool show_all) const override; | void Print(std::ostream &out, bool show_all) const override; | ||||
| // << Stream output operator overload | |||||
| // @notes This allows you to write the debug print info using stream operators | |||||
| // @param out - reference to the output stream being overloaded | |||||
| // @param generator_op - reference to the GeneratorOp to display | |||||
| // @return - the output stream must be returned | |||||
| /// << Stream output operator overload | |||||
| /// \note This allows you to write the debug print info using stream operators | |||||
| /// \param out - reference to the output stream being overloaded | |||||
| /// \param generator_op - reference to the GeneratorOp to display | |||||
| /// \return - the output stream must be returned | |||||
| friend std::ostream &operator<<(std::ostream &out, const GeneratorOp &generator_op) { | friend std::ostream &operator<<(std::ostream &out, const GeneratorOp &generator_op) { | ||||
| generator_op.Print(out, false); | generator_op.Print(out, false); | ||||
| return out; | return out; | ||||
| } | } | ||||
| // Class functor operator () override. | |||||
| // All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will | |||||
| // provide the master loop that drives the logic for performing the work. | |||||
| // @return Status The status code returned | |||||
| /// Class functor operator () override. | |||||
| /// All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will | |||||
| /// provide the master loop that drives the logic for performing the work. | |||||
| /// \return Status The status code returned | |||||
| Status operator()() override; | Status operator()() override; | ||||
| // Overrides base class reset method. When an operator does a reset, it cleans up any state | |||||
| // info from it's previous execution and then initializes itself so that it can be executed | |||||
| // again. | |||||
| // @return Status The status code returned | |||||
| /// Overrides base class reset method. When an operator does a reset, it cleans up any state | |||||
| /// info from its previous execution and then initializes itself so that it can be executed | |||||
| /// again. | |||||
| /// \return Status The status code returned | |||||
| Status Reset() override; | Status Reset() override; | ||||
| // Base-class override for NodePass visitor acceptor. | |||||
| // @param p - Pointer to the NodePass to be accepted. | |||||
| // @param modified - Whether this node visit modified the pipeline. | |||||
| // @return - Status of the node visit. | |||||
| /// Base-class override for NodePass visitor acceptor. | |||||
| /// \param p - Pointer to the NodePass to be accepted. | |||||
| /// \param modified - Whether this node visit modified the pipeline. | |||||
| /// \return - Status of the node visit. | |||||
| Status Accept(NodePass *p, bool *const modified) override; | Status Accept(NodePass *p, bool *const modified) override; | ||||
| // Op name getter | |||||
| // @return Name of the current Op | |||||
| /// Op name getter | |||||
| /// \return Name of the current Op | |||||
| std::string Name() const override { return "GeneratorOp"; } | std::string Name() const override { return "GeneratorOp"; } | ||||
| Status Init(); | |||||
| private: | private: | ||||
| py::function generator_function_; | py::function generator_function_; | ||||
| std::vector<std::string> column_names_; | std::vector<std::string> column_names_; | ||||
| std::vector<DataType> column_types_; | std::vector<DataType> column_types_; | ||||
| int32_t prefetch_size_; | int32_t prefetch_size_; | ||||
| int32_t buffer_size_; | int32_t buffer_size_; | ||||
| int64_t pre_counter_size_; | |||||
| int64_t generator_counter_; | int64_t generator_counter_; | ||||
| py::object generator_; | py::object generator_; | ||||
| @@ -151,15 +149,25 @@ class GeneratorOp : public PipelineOp { | |||||
| WaitPost wp_; | WaitPost wp_; | ||||
| void Dealloc() noexcept; | |||||
| Status PyRowToTensorRow(py::object py_data, TensorRow *tensor_row); | Status PyRowToTensorRow(py::object py_data, TensorRow *tensor_row); | ||||
| Status FillBuffer(TensorQTable *tt); | Status FillBuffer(TensorQTable *tt); | ||||
| // Private function for computing the assignment of the column name map. | |||||
| // @return - Status | |||||
| /// Private function for computing the assignment of the column name map. | |||||
| /// \return - Status | |||||
| Status ComputeColMap() override; | Status ComputeColMap() override; | ||||
| /// Handshake with the sampler via sampler_->HandshakeRandomAccessOp(this); does nothing when no sampler is set | |||||
| /// \return Status The status code returned | |||||
| Status InitSampler(); | |||||
| /// Create new Generator object from the generator function | |||||
| /// \return Status The status code returned | |||||
| Status CreateGeneratorObject(); | |||||
| /// Initialize GeneratorOp | |||||
| /// \return Status The status code returned | |||||
| Status Init(); | |||||
| }; | }; | ||||
| #pragma GCC visibility pop | #pragma GCC visibility pop | ||||
| @@ -177,8 +177,21 @@ int64_t DistributedSamplerRT::CalculateNumSamples(int64_t num_rows) { | |||||
| child_num_rows = child_[0]->CalculateNumSamples(num_rows); | child_num_rows = child_[0]->CalculateNumSamples(num_rows); | ||||
| } | } | ||||
| int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows; | int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows; | ||||
| int64_t num_per_shard = std::ceil(child_num_rows * 1.0 / num_devices_); | |||||
| return std::min(num_samples, num_per_shard); | |||||
| int64_t remainder = (child_num_rows + offset_) % num_devices_; | |||||
| int64_t shard_size = (child_num_rows + offset_) / num_devices_; | |||||
| if (offset_ != -1 || !even_dist_) { | |||||
| if (offset_ == -1) offset_ = 0; | |||||
| if (device_id_ < remainder) shard_size++; | |||||
| if (device_id_ < offset_) shard_size--; | |||||
| } else { | |||||
| offset_ = 0; | |||||
| shard_size = (child_num_rows + num_devices_ - 1) / num_devices_; | |||||
| } | |||||
| // add 1 to an empty shard | |||||
| // this logic is needed to follow the logic in initSampler that is written for ConcatDataset | |||||
| if (shard_size == 0) shard_size++; | |||||
| return std::min(num_samples, shard_size); | |||||
| } | } | ||||
| void DistributedSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const { | void DistributedSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const { | ||||
| @@ -48,6 +48,10 @@ class RandomAccessOp { | |||||
| // default destructor | // default destructor | ||||
| virtual ~RandomAccessOp() = default; | virtual ~RandomAccessOp() = default; | ||||
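| /// Set the number of rows in the dataset (the total, before-sampling row count) | |||||
| /// \param num_rows The row count to store in num_rows_ | |||||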
| void SetNumRows(int64_t num_rows) { num_rows_ = num_rows; } | |||||
| protected: | protected: | ||||
| // The amount of rows in the dataset itself. This is the before-sampling value, the | // The amount of rows in the dataset itself. This is the before-sampling value, the | ||||
| // total count of rows. A sampler may choose to sample less than this amount. | // total count of rows. A sampler may choose to sample less than this amount. | ||||
| @@ -95,19 +95,19 @@ Status ConcatNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size | |||||
| // calculate the size of the shard | // calculate the size of the shard | ||||
| int64_t shard_dataset_size = 0; | int64_t shard_dataset_size = 0; | ||||
| if (sampler_ != nullptr) { | |||||
| std::shared_ptr<DistributedSamplerRT> sampler_rt = | |||||
| std::static_pointer_cast<DistributedSamplerRT>(sampler_->SamplerBuild()); | |||||
| std::shared_ptr<DistributedSamplerRT> sampler_rt = | |||||
| sampler_ ? std::dynamic_pointer_cast<DistributedSamplerRT>(sampler_->SamplerBuild()) : nullptr; | |||||
| if (sampler_rt != nullptr) { | |||||
| sampler_rt->SetNumRowsInDataset(total_dataset_size); | sampler_rt->SetNumRowsInDataset(total_dataset_size); | ||||
| sampler_rt->InitSampler(); | sampler_rt->InitSampler(); | ||||
| // (total_size % num_shards != 0) & shard_id >= (remainder) ? CalculateNumSamples()-1 : CalculateNumSamples() | // (total_size % num_shards != 0) & shard_id >= (remainder) ? CalculateNumSamples()-1 : CalculateNumSamples() | ||||
| // example: 23 rows, 10 shards --> shard sizes = {3,3,3,2,2,2,2,2,2,2} | // example: 23 rows, 10 shards --> shard sizes = {3,3,3,2,2,2,2,2,2,2} | ||||
| if ((sampler_rt->GetNumSamples() % sampler_rt->GetDeviceNum()) > 0 && | |||||
| if ((sampler_rt->GetNumSamples() % sampler_rt->GetDeviceNum()) >= 0 && | |||||
| sampler_rt->GetDeviceID() >= (sampler_rt->GetNumSamples() % sampler_rt->GetDeviceNum())) { | sampler_rt->GetDeviceID() >= (sampler_rt->GetNumSamples() % sampler_rt->GetDeviceNum())) { | ||||
| shard_dataset_size = sampler_rt->CalculateNumSamples(sampler_rt->GetNumSamples()) - 1; | |||||
| shard_dataset_size = sampler_rt->GetNumSamples() / sampler_rt->GetDeviceNum(); | |||||
| } else { | } else { | ||||
| shard_dataset_size = sampler_rt->CalculateNumSamples(sampler_rt->GetNumSamples()); | |||||
| shard_dataset_size = sampler_rt->GetNumSamples() / sampler_rt->GetDeviceNum() + 1; | |||||
| } | } | ||||
| } else { | } else { | ||||
| shard_dataset_size = total_dataset_size; | shard_dataset_size = total_dataset_size; | ||||
| @@ -18,6 +18,7 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <utility> | |||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/engine/datasetops/repeat_op.h" | #include "minddata/dataset/engine/datasetops/repeat_op.h" | ||||
| @@ -29,24 +30,27 @@ namespace mindspore { | |||||
| namespace dataset { | namespace dataset { | ||||
| GeneratorNode::GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names, | GeneratorNode::GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names, | ||||
| const std::vector<DataType> &column_types) | |||||
| const std::vector<DataType> &column_types, int64_t source_len, | |||||
| std::shared_ptr<SamplerObj> sampler) | |||||
| : MappableSourceNode(), | : MappableSourceNode(), | ||||
| generator_function_(generator_function), | generator_function_(generator_function), | ||||
| column_names_(column_names), | column_names_(column_names), | ||||
| column_types_(column_types), | column_types_(column_types), | ||||
| reset_ancestor_(nullptr) {} | |||||
| reset_ancestor_(nullptr), | |||||
| sampler_(std::move(sampler)), | |||||
| source_len_(source_len) {} | |||||
| GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema) | |||||
| GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema, | |||||
| int64_t source_len, std::shared_ptr<SamplerObj> sampler) | |||||
| : MappableSourceNode(), generator_function_(generator_function), schema_(schema), reset_ancestor_(nullptr) {} | : MappableSourceNode(), generator_function_(generator_function), schema_(schema), reset_ancestor_(nullptr) {} | ||||
| std::shared_ptr<DatasetNode> GeneratorNode::Copy() { | std::shared_ptr<DatasetNode> GeneratorNode::Copy() { | ||||
| std::shared_ptr<GeneratorNode> node; | std::shared_ptr<GeneratorNode> node; | ||||
| if (schema_ == nullptr) { | if (schema_ == nullptr) { | ||||
| node = std::make_shared<GeneratorNode>(generator_function_, column_names_, column_types_); | |||||
| node = std::make_shared<GeneratorNode>(generator_function_, column_names_, column_types_, source_len_, sampler_); | |||||
| } else { | } else { | ||||
| node = std::make_shared<GeneratorNode>(generator_function_, schema_); | |||||
| node = std::make_shared<GeneratorNode>(generator_function_, schema_, source_len_, sampler_); | |||||
| } | } | ||||
| node->SetGeneratorDatasetSize(dataset_size_); | |||||
| return node; | return node; | ||||
| } | } | ||||
| @@ -69,17 +73,14 @@ Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ | |||||
| column_types_.push_back((col.type())); | column_types_.push_back((col.type())); | ||||
| } | } | ||||
| } | } | ||||
| std::shared_ptr<SamplerRT> sampler_rt = sampler_ ? sampler_->SamplerBuild() : nullptr; | |||||
| // GeneratorOp's constructor takes in a prefetch_size, which isn't being set by user nor is it being used by | // GeneratorOp's constructor takes in a prefetch_size, which isn't being set by user nor is it being used by | ||||
| // GeneratorOp internally. Here it is given a zero which is the default in generator builder | // GeneratorOp internally. Here it is given a zero which is the default in generator builder | ||||
| std::shared_ptr<GeneratorOp> op = std::make_shared<GeneratorOp>(generator_function_, column_names_, column_types_, 0, | std::shared_ptr<GeneratorOp> op = std::make_shared<GeneratorOp>(generator_function_, column_names_, column_types_, 0, | ||||
| rows_per_buffer_, connector_que_size_, dataset_size_); | |||||
| // Init() is called in builder when generator is built. Here, since we are getting away from the builder class, init | |||||
| // needs to be called when the op is built. The caveat is that Init needs to be made public (before it is private). | |||||
| // This method can be privatized once we move Init() to Generator's functor. However, that is a bigger change which | |||||
| // best be delivered when the test cases for this api is ready. | |||||
| RETURN_IF_NOT_OK(op->Init()); | |||||
| rows_per_buffer_, connector_que_size_, sampler_rt); | |||||
| // set the number of rows from source length | |||||
| op->SetNumRows(source_len_); | |||||
| // Add this GeneratorOp to its RepeatOp/EpochCtrlOp ancestor's EOE list. | // Add this GeneratorOp to its RepeatOp/EpochCtrlOp ancestor's EOE list. | ||||
| // When the ancestor reaches an end-of-epoch boundary, it will send a "reset" signal to all the ops in the EOE list. | // When the ancestor reaches an end-of-epoch boundary, it will send a "reset" signal to all the ops in the EOE list. | ||||
| @@ -118,5 +119,26 @@ Status GeneratorNode::AcceptAfter(IRNodePass *p, bool *const modified) { | |||||
| // Downcast shared pointer then call visitor | // Downcast shared pointer then call visitor | ||||
| return p->VisitAfter(shared_from_base<GeneratorNode>(), modified); | return p->VisitAfter(shared_from_base<GeneratorNode>(), modified); | ||||
| } | } | ||||
| Status GeneratorNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||||
| int64_t *dataset_size) { | |||||
| if (dataset_size_ > 0) { | |||||
| *dataset_size = dataset_size_; | |||||
| return Status::OK(); | |||||
| } | |||||
| if (source_len_ == -1) { | |||||
| RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), dataset_size)); | |||||
| dataset_size_ = *dataset_size; | |||||
| return Status::OK(); | |||||
| } else { | |||||
| int64_t sample_size; | |||||
| int64_t num_rows; | |||||
| num_rows = source_len_; | |||||
| sample_size = sampler_ ? sampler_->SamplerBuild()->CalculateNumSamples(num_rows) : num_rows; | |||||
| *dataset_size = sample_size; | |||||
| dataset_size_ = *dataset_size; | |||||
| return Status::OK(); | |||||
| } | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -34,10 +34,11 @@ class GeneratorNode : public MappableSourceNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names, | GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names, | ||||
| const std::vector<DataType> &column_types); | |||||
| const std::vector<DataType> &column_types, int64_t source_len, std::shared_ptr<SamplerObj> sampler); | |||||
| /// \brief Constructor | /// \brief Constructor | ||||
| GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema); | |||||
| GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema, int64_t source_len, | |||||
| std::shared_ptr<SamplerObj> sampler); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~GeneratorNode() = default; | ~GeneratorNode() = default; | ||||
| @@ -67,11 +68,6 @@ class GeneratorNode : public MappableSourceNode { | |||||
| /// \return Status Status::OK() if get shard id successfully | /// \return Status Status::OK() if get shard id successfully | ||||
| Status GetShardId(int32_t *shard_id) override; | Status GetShardId(int32_t *shard_id) override; | ||||
| /// \brief Setter for DatasetSize in GeneratorNode | |||||
| /// \param[in] sz dataset size to set | |||||
| /// \return void | |||||
| void SetGeneratorDatasetSize(int64_t sz) { dataset_size_ = sz; } | |||||
| bool IsSizeDefined() override { return false; } | bool IsSizeDefined() override { return false; } | ||||
| /// \brief Record the vector of Repeat/EpochCtrl nodes that are ancestors of this node | /// \brief Record the vector of Repeat/EpochCtrl nodes that are ancestors of this node | ||||
| @@ -82,6 +78,14 @@ class GeneratorNode : public MappableSourceNode { | |||||
| reset_ancestor_ = src; | reset_ancestor_ = src; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| /// Returns the dataset size of GeneratorOp. If it is mappable (sampler is not null), the sampler is used. | |||||
| /// Otherwise, a dry run is needed. | |||||
| /// \param[in] size_getter TreeConsumer to be used for a dryrun | |||||
| /// \param[in] estimate | |||||
| /// \param[out] dataset_size | |||||
| /// \return Status of the function | |||||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||||
| int64_t *dataset_size) override; | |||||
| /// \brief Getter functions | /// \brief Getter functions | ||||
| const py::function &GeneratorFunction() const { return generator_function_; } | const py::function &GeneratorFunction() const { return generator_function_; } | ||||
| @@ -102,6 +106,8 @@ class GeneratorNode : public MappableSourceNode { | |||||
| std::vector<DataType> column_types_; | std::vector<DataType> column_types_; | ||||
| std::shared_ptr<SchemaObj> schema_; | std::shared_ptr<SchemaObj> schema_; | ||||
| std::shared_ptr<RepeatNode> reset_ancestor_; // updated its immediate Repeat/EpochCtrl ancestor in GeneratorNodePass | std::shared_ptr<RepeatNode> reset_ancestor_; // updated its immediate Repeat/EpochCtrl ancestor in GeneratorNodePass | ||||
| std::shared_ptr<SamplerObj> sampler_; | |||||
| int64_t source_len_; // Length of the dataset source provided by the user, -1 means it's unknown | |||||
| /// \brief Base-class override for accepting IRNodePass visitor | /// \brief Base-class override for accepting IRNodePass visitor | ||||
| /// \param[in] p The node to visit | /// \param[in] p The node to visit | ||||
| @@ -3427,23 +3427,31 @@ def _py_sampler_fn(sampler, num_samples, dataset): | |||||
| yield tuple([np.array(x, copy=False) for x in val]) | yield tuple([np.array(x, copy=False) for x in val]) | ||||
| def _cpp_sampler_fn(sampler, dataset): | |||||
| def _cpp_sampler_fn(sample_ids, dataset): | |||||
| """ | """ | ||||
| Generator function wrapper for mappable dataset with cpp sampler. | Generator function wrapper for mappable dataset with cpp sampler. | ||||
| """ | """ | ||||
| indices = sampler.get_indices() | |||||
| for i in indices: | |||||
| if not isinstance(sample_ids, np.ndarray): | |||||
| raise RuntimeError("Sample IDs are not in a numpy array.") | |||||
| if sample_ids.size == 0: | |||||
| raise RuntimeError("Sampler passed an empty sample IDs list.") | |||||
| for i in sample_ids: | |||||
| val = dataset[i] | val = dataset[i] | ||||
| # convert output tensors to ndarrays | # convert output tensors to ndarrays | ||||
| yield tuple([np.array(x, copy=False) for x in val]) | yield tuple([np.array(x, copy=False) for x in val]) | ||||
| def _cpp_sampler_fn_mp(sampler, sample_fn): | |||||
| def _cpp_sampler_fn_mp(sample_ids, sample_fn): | |||||
| """ | """ | ||||
| Multiprocessing generator function wrapper for mappable dataset with cpp sampler. | Multiprocessing generator function wrapper for mappable dataset with cpp sampler. | ||||
| """ | """ | ||||
| indices = sampler.get_indices() | |||||
| return sample_fn.process(indices) | |||||
| if not isinstance(sample_ids, np.ndarray): | |||||
| raise RuntimeError("Sample IDs are not in a numpy array.") | |||||
| if sample_ids.size == 0: | |||||
| raise RuntimeError("Sampler passed an empty sample IDs list.") | |||||
| return sample_fn.process(sample_ids) | |||||
| def _py_sampler_fn_mp(sampler, num_samples, sample_fn): | def _py_sampler_fn_mp(sampler, num_samples, sample_fn): | ||||
| @@ -3811,18 +3819,9 @@ class GeneratorDataset(MappableDataset): | |||||
| self.schema = Schema(schema) | self.schema = Schema(schema) | ||||
| # Move get dataset_size by len from parse to here, because self.source will | # Move get dataset_size by len from parse to here, because self.source will | ||||
| # lose attribution of '__len__' after deepcopy. | # lose attribution of '__len__' after deepcopy. | ||||
| self.dataset_size = None | |||||
| self.source_len = -1 # unknown | |||||
| if hasattr(self.source, "__len__"): | if hasattr(self.source, "__len__"): | ||||
| if not self.num_shards: | |||||
| self.dataset_size = len(self.source) | |||||
| else: | |||||
| self.dataset_size = math.ceil(len(self.source) / self.num_shards) | |||||
| rows_from_sampler = self._get_sampler_dataset_size() | |||||
| if self.num_samples is not None and self.num_samples < rows_from_sampler: | |||||
| rows_from_sampler = self.num_samples | |||||
| if rows_from_sampler is not None and rows_from_sampler < self.dataset_size: | |||||
| self.dataset_size = rows_from_sampler | |||||
| self.source_len = len(self.source) | |||||
| def __deepcopy__(self, memodict): | def __deepcopy__(self, memodict): | ||||
| if id(self) in memodict: | if id(self) in memodict: | ||||
| @@ -3839,6 +3838,7 @@ class GeneratorDataset(MappableDataset): | |||||
| new_op.num_samples = copy.deepcopy(self.num_samples, memodict) | new_op.num_samples = copy.deepcopy(self.num_samples, memodict) | ||||
| new_op.sampler = copy.deepcopy(self.sampler) | new_op.sampler = copy.deepcopy(self.sampler) | ||||
| new_op.dataset_size = self.dataset_size | new_op.dataset_size = self.dataset_size | ||||
| new_op.source_len = self.source_len | |||||
| new_op.saved_output_types = self.saved_output_types | new_op.saved_output_types = self.saved_output_types | ||||
| new_op.saved_output_shapes = self.saved_output_shapes | new_op.saved_output_shapes = self.saved_output_shapes | ||||
| if hasattr(self, "__total_batch__"): | if hasattr(self, "__total_batch__"): | ||||
| @@ -3847,22 +3847,28 @@ class GeneratorDataset(MappableDataset): | |||||
| if isinstance(new_op.sampler, (samplers.SequentialSampler, samplers.DistributedSampler, | if isinstance(new_op.sampler, (samplers.SequentialSampler, samplers.DistributedSampler, | ||||
| samplers.RandomSampler, samplers.SubsetRandomSampler, | samplers.RandomSampler, samplers.SubsetRandomSampler, | ||||
| samplers.WeightedRandomSampler, samplers.Sampler)): | samplers.WeightedRandomSampler, samplers.Sampler)): | ||||
| sampler_instance = new_op.sampler.create() | |||||
| sampler_instance.set_num_rows(len(self.source)) | |||||
| sampler_instance.initialize() | |||||
| if new_op.num_parallel_workers > 1: | if new_op.num_parallel_workers > 1: | ||||
| sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing) | sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing) | ||||
| new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, sample_fn)) | |||||
| new_op.source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn)) | |||||
| else: | else: | ||||
| new_op.source = (lambda: _cpp_sampler_fn(sampler_instance, self.source)) | |||||
| new_op.source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source)) | |||||
| else: | else: | ||||
| # the sampler provided is not a built-in sampler, it is a list of sample_ids | |||||
| new_op.sample_ids = new_op.sampler | |||||
| # since list of sample_ids are not passed to c++, we need to find the proper len here | |||||
| new_op.source_len = min(self.source_len, len(new_op.sample_ids)) if self.source_len != -1 else len( | |||||
| new_op.sample_ids) | |||||
| new_op.source_len = min(self.source_len, | |||||
| new_op.num_samples) if new_op.num_samples is not None else new_op.source_len | |||||
| new_op.sampler = None | |||||
| if new_op.num_parallel_workers > 1: | if new_op.num_parallel_workers > 1: | ||||
| sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing) | sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing) | ||||
| new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, sample_fn)) | |||||
| new_op.source = (lambda: _py_sampler_fn_mp(new_op.sample_ids, new_op.num_samples, sample_fn)) | |||||
| else: | else: | ||||
| new_op.source = (lambda: _py_sampler_fn(new_op.sampler, new_op.num_samples, self.source)) | |||||
| new_op.source = (lambda: _py_sampler_fn(new_op.sample_ids, new_op.num_samples, self.source)) | |||||
| else: | else: | ||||
| try: | try: | ||||
| new_op.sampler = None | |||||
| iter(self.source) | iter(self.source) | ||||
| except TypeError: | except TypeError: | ||||
| # Use generator function if input callable | # Use generator function if input callable | ||||
| @@ -3881,16 +3887,13 @@ class GeneratorDataset(MappableDataset): | |||||
| return self.sampler.is_sharded() | return self.sampler.is_sharded() | ||||
def parse(self, children=None):
    """
    Build the cde.GeneratorNode for this dataset.

    The node is created from `self.source` together with either the column names/types (when no schema
    is set) or the schema, plus `self.source_len` and `self.sampler`; the number of parallel workers is
    then applied with SetNumWorkers.

    Args:
        children: Not used by this method.

    Returns:
        The result of `SetNumWorkers(self.num_parallel_workers)` on the new cde.GeneratorNode.
    """
    if self.schema is None:
        return cde.GeneratorNode(self.source, self.column_names, self.column_types,
                                 self.source_len, self.sampler).SetNumWorkers(self.num_parallel_workers)
    schema = self.schema
    # A Python-side Schema wrapper is replaced by its cpp_schema attribute before being passed on.
    if isinstance(schema, Schema):
        schema = self.schema.cpp_schema
    return cde.GeneratorNode(self.source, schema, self.source_len, self.sampler).SetNumWorkers(
        self.num_parallel_workers)