
Move GeneratorOp sampler to C++

tags/v1.2.0-rc1
hesham 4 years ago
commit c3278c983d
10 changed files with 211 additions and 142 deletions
  1. mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/datasets_bindings.cc (+13 -13)
  2. mindspore/ccsrc/minddata/dataset/api/python/pybind_conversion.cc (+3 -0)
  3. mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc (+36 -26)
  4. mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h (+53 -45)
  5. mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc (+15 -2)
  6. mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h (+4 -0)
  7. mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc (+6 -6)
  8. mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc (+35 -13)
  9. mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h (+13 -7)
  10. mindspore/dataset/engine/datasets.py (+33 -30)

mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/datasets_bindings.cc (+13 -13)

@@ -184,21 +184,21 @@ PYBIND_REGISTER(CSVNode, 2, ([](const py::module *m) {
 PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) {
   (void)py::class_<GeneratorNode, DatasetNode, std::shared_ptr<GeneratorNode>>(
     *m, "GeneratorNode", "to create a GeneratorNode")
-    .def(py::init([](py::function generator_function, const std::vector<std::string> &column_names,
-                     const std::vector<DataType> &column_types) {
-      auto gen = std::make_shared<GeneratorNode>(generator_function, column_names, column_types);
-      THROW_IF_ERROR(gen->ValidateParams());
-      return gen;
-    }))
-    .def(py::init([](py::function generator_function, const std::shared_ptr<SchemaObj> schema) {
-      auto gen = std::make_shared<GeneratorNode>(generator_function, schema);
+    .def(
+      py::init([](py::function generator_function, const std::vector<std::string> &column_names,
+                  const std::vector<DataType> &column_types, int64_t dataset_len, py::handle sampler) {
+        auto gen = std::make_shared<GeneratorNode>(generator_function, column_names, column_types,
+                                                   dataset_len, toSamplerObj(sampler));
+        THROW_IF_ERROR(gen->ValidateParams());
+        return gen;
+      }))
+    .def(py::init([](py::function generator_function, const std::shared_ptr<SchemaObj> schema,
+                     int64_t dataset_len, py::handle sampler) {
+      auto gen =
+        std::make_shared<GeneratorNode>(generator_function, schema, dataset_len, toSamplerObj(sampler));
       THROW_IF_ERROR(gen->ValidateParams());
       return gen;
-    }))
-    .def("SetGeneratorDatasetSize", [](std::shared_ptr<GeneratorNode> self, int64_t sz) {
-      self->SetGeneratorDatasetSize(sz);
-      return self;
-    });
+    }));
 }));

 PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) {
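For context, the re-registered constructor above is what GeneratorDataset.parse() (see the datasets.py hunk at the end of this commit) invokes. A minimal sketch of the new call shape, with placeholder variable names; the real arguments are prepared inside datasets.py, and a Python sampler of None now arrives in C++ as nullptr:

    # Sketch only: mirrors the new five-argument binding; variable names are placeholders.
    # cde is mindspore._c_dataengine, the module these bindings are registered into.
    node = cde.GeneratorNode(generator_function,  # Python callable
                             column_names,        # list of str
                             column_types,        # list of cde.DataType
                             source_len,          # int64, -1 when unknown
                             sampler)             # built-in sampler handle or None
    node = node.SetNumWorkers(num_parallel_workers)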


mindspore/ccsrc/minddata/dataset/api/python/pybind_conversion.cc (+3 -0)

@@ -143,6 +143,9 @@ std::vector<std::shared_ptr<DatasetNode>> toDatasetNode(std::shared_ptr<DatasetN
 }

 std::shared_ptr<SamplerObj> toSamplerObj(py::handle py_sampler, bool isMindDataset) {
+  if (py_sampler.is_none()) {
+    return nullptr;
+  }
   if (py_sampler) {
     std::shared_ptr<SamplerObj> sampler_obj;
     if (!isMindDataset) {


mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc (+36 -26)

@@ -43,25 +43,22 @@ Status GeneratorOp::Builder::SanityCheck() {
 Status GeneratorOp::Builder::Build(std::shared_ptr<GeneratorOp> *ptr) {
   RETURN_IF_NOT_OK(SanityCheck());
   *ptr = std::make_shared<GeneratorOp>(build_generator_function_, build_column_names_, build_column_types_,
-                                       build_prefetch_size_, build_buffer_size_, build_op_connector_size_);
+                                       build_prefetch_size_, build_buffer_size_, build_op_connector_size_, nullptr);
   return (*ptr)->Init();
 }

 GeneratorOp::GeneratorOp(py::function generator_function, std::vector<std::string> column_names,
                          std::vector<DataType> column_types, int32_t prefetch_size, int32_t buffer_size,
-                         int32_t connector_size, int64_t pre_counter_size)
-    : PipelineOp(connector_size),
+                         int32_t connector_size, std::shared_ptr<SamplerRT> sampler)
+    : PipelineOp(connector_size, std::move(sampler)),
       generator_function_(generator_function),
       column_names_(column_names),
       column_types_(column_types),
       prefetch_size_(prefetch_size),
       buffer_size_(buffer_size),
-      pre_counter_size_(pre_counter_size),
       buffer_id_(0),
       generator_counter_(0) {}

-GeneratorOp::~GeneratorOp() { this->Dealloc(); }
-
 void GeneratorOp::Print(std::ostream &out, bool show_all) const {
   if (!show_all) {
     // Call the super class for displaying any common 1-liner info
@@ -79,32 +76,32 @@ void GeneratorOp::Print(std::ostream &out, bool show_all) const {
     out << "\n\n";
   }
 }

-void GeneratorOp::Dealloc() noexcept {
-  // Setup GIL state
-  PyGILState_STATE gstate;
-  gstate = PyGILState_Ensure();
-  // GC the generator object within GIL
-  if (generator_function_.ref_count() == 1) generator_function_.dec_ref();
-  if (generator_.ref_count() == 1) (void)generator_.dec_ref();
-  // Release GIL
-  PyGILState_Release(gstate);
+// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
+Status GeneratorOp::InitSampler() {
+  if (sampler_ != nullptr) return sampler_->HandshakeRandomAccessOp(this);
+  return Status::OK();
 }

-// Reentrant init method.
-Status GeneratorOp::Init() {
-  // Reset BufferID
-  buffer_id_ = 0;
-  Status ret;
-  // Invoke the generatorFunction to get generator object
+Status GeneratorOp::CreateGeneratorObject() {
+  Status ret = Status::OK();
   {
     // Acquire Python GIL
     py::gil_scoped_acquire gil_acquire;
     if (Py_IsInitialized() == 0) {
       return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
     }
+    // Invoke the generatorFunction to get generator object
     try {
-      generator_ = generator_function_();
+      py::array sample_ids;
+      if (sampler_ != nullptr) {
+        // Sampler is not null which means the source is RandomAccessible
+        // get all samples and pass it to the Generator function
+        RETURN_IF_NOT_OK(sampler_->GetAllIdsThenReset(&sample_ids));
+        // If sampler is a user-defined python sampler, sample_ids will flow from python to c++ and back to python
+        generator_ = generator_function_(sample_ids);
+      } else {
+        generator_ = generator_function_();
+      }
     } catch (const py::error_already_set &e) {
       ret = Status(StatusCode::kPyFuncException, e.what());
     }
@@ -112,6 +109,13 @@ Status GeneratorOp::Init() {
   return ret;
 }

+// Reentrant init method.
+Status GeneratorOp::Init() {
+  buffer_id_ = 0;
+  RETURN_IF_NOT_OK(InitSampler());
+  return CreateGeneratorObject();
+}
+
 Status GeneratorOp::PyRowToTensorRow(py::object py_data, TensorRow *tensor_row) {
   if (!py::isinstance<py::tuple>(py_data)) {
     return Status(StatusCode::kPyFuncException, __LINE__, __FILE__,
@@ -191,6 +195,9 @@ Status GeneratorOp::operator()() {
   TaskManager::FindMe()->Post();
   RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks()));
   std::unique_ptr<DataBuffer> fetched_buffer;
+  RETURN_IF_NOT_OK(Init());
+
+  int64_t num_rows_sampled = sampler_ ? sampler_->CalculateNumSamples(num_rows_) : num_rows_;
   bool eof = false;
   while (!eof) {
     // Create new buffer each iteration
@@ -212,10 +219,10 @@ Status GeneratorOp::operator()() {
       if (!eoe) {
         return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, e.what());
       }
-      if (pre_counter_size_ != -1 && pre_counter_size_ != generator_counter_) {
+      if (num_rows_sampled != -1 && num_rows_sampled != generator_counter_) {
         std::stringstream ss;
         ss << "The actual amount of data read from generator " << generator_counter_
-           << " is different from generator.len " << pre_counter_size_
+           << " is different from generator.len " << num_rows_sampled
            << ", you should adjust generator.len to make them match.";
         return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, ss.str());
       }
@@ -259,7 +266,10 @@ Status GeneratorOp::operator()() {
 Status GeneratorOp::Reset() {
   // Reset Op state
   MS_LOG(DEBUG) << Name() << " performing a self-reset.";
-  RETURN_IF_NOT_OK(this->Init());
+  // Reset BufferID
+  buffer_id_ = 0;
+  // Create new generator object
+  RETURN_IF_NOT_OK(CreateGeneratorObject());
   if (this->op_total_repeats() < 0) {
     // Wake up master thread
     wp_.Set();
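The practical effect of CreateGeneratorObject() is that a random-accessible source is now driven by a generator function receiving the whole batch of sample IDs from the C++ sampler, instead of pulling indices one by one in Python. A self-contained sketch of that contract (toy stand-ins, imitating _cpp_sampler_fn from the datasets.py hunk below):

    import numpy as np

    toy_source = [("a",), ("b",), ("c",), ("d",)]  # stand-in for a __getitem__ source

    def generator_fn(sample_ids):
        # sample_ids arrives as a single numpy array, produced on the C++ side
        # by SamplerRT::GetAllIdsThenReset before generator_function_(sample_ids).
        if not isinstance(sample_ids, np.ndarray):
            raise RuntimeError("Sample IDs are not in a numpy array.")
        for i in sample_ids:
            yield tuple(np.array(x, copy=False) for x in toy_source[i])

    rows = list(generator_fn(np.array([2, 0, 3])))  # rows for IDs 2, 0, 3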


mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h (+53 -45)

@@ -26,6 +26,7 @@
 #include "minddata/dataset/core/tensor.h"
 #include "minddata/dataset/engine/data_schema.h"
 #include "minddata/dataset/engine/datasetops/pipeline_op.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
 #include "minddata/dataset/util/wait_post.h"
 #include "pybind11/pybind11.h"

@@ -35,47 +36,47 @@ namespace mindspore {
 namespace dataset {
 #pragma GCC visibility push(hidden)

-class GeneratorOp : public PipelineOp {
+class GeneratorOp : public PipelineOp, public RandomAccessOp {
  public:
   class Builder {
    public:
-    // Builder constructor. Creates the builder object.
-    // @note No default args
-    // @return This is a constructor.
+    /// Builder constructor. Creates the builder object.
+    /// \note No default args
+    /// \return This is a constructor.
     Builder();

     ~Builder() = default;

-    // Setter method.
-    // @return Builder setter method returns reference to the builder.
+    /// Setter method.
+    /// \return Builder setter method returns reference to the builder.
     Builder &SetGeneratorFunction(py::function generator_function) {
       build_generator_function_ = generator_function;
       return *this;
     }

-    // Setter method.
-    // @return Builder setter method returns reference to the builder.
+    /// Setter method.
+    /// \return Builder setter method returns reference to the builder.
     Builder &SetColumnNames(const std::vector<std::string> &column_names) {
       build_column_names_ = column_names;
       return *this;
     }

-    // Setter method.
-    // @return Builder setter method returns reference to the builder.
+    /// Setter method.
+    /// \return Builder setter method returns reference to the builder.
     Builder &SetColumnTypes(const std::vector<DataType> &column_types) {
       build_column_types_ = column_types;
       return *this;
     }

-    // Setter method.
-    // @return Builder setter method returns reference to the builder.
+    /// Setter method.
+    /// \return Builder setter method returns reference to the builder.
     Builder &SetPrefetchSize(int32_t prefetch_size) {
       build_prefetch_size_ = prefetch_size;
       return *this;
     }

-    // The builder "build" method creates the final object.
-    // @return shared_ptr to the new GeneratorOp object
+    /// The builder "build" method creates the final object.
+    /// \return shared_ptr to the new GeneratorOp object
     Status Build(std::shared_ptr<GeneratorOp> *);

    private:
@@ -94,56 +95,53 @@ class GeneratorOp : public PipelineOp {

   GeneratorOp(py::function generator_function, std::vector<std::string> column_names,
               std::vector<DataType> column_types, int32_t prefetch_size, int32_t buffer_size, int32_t connector_size,
-              int64_t pre_counter_size = 0);
+              std::shared_ptr<SamplerRT> sampler);

-  ~GeneratorOp();
+  ~GeneratorOp() = default;

-  // A print method typically used for debugging
-  // @param out - The output stream to write output to
-  // @param show_all - A bool to control if you want to show all info or just a summary
+  /// A print method typically used for debugging
+  /// \param out - The output stream to write output to
+  /// \param show_all - A bool to control if you want to show all info or just a summary
   void Print(std::ostream &out, bool show_all) const override;

-  // << Stream output operator overload
-  // @notes This allows you to write the debug print info using stream operators
-  // @param out - reference to the output stream being overloaded
-  // @param generator_op - reference to the GeneratorOp to display
-  // @return - the output stream must be returned
+  /// << Stream output operator overload
+  /// \notes This allows you to write the debug print info using stream operators
+  /// \param out - reference to the output stream being overloaded
+  /// \param generator_op - reference to the GeneratorOp to display
+  /// \return - the output stream must be returned
   friend std::ostream &operator<<(std::ostream &out, const GeneratorOp &generator_op) {
     generator_op.Print(out, false);
     return out;
   }

-  // Class functor operator () override.
-  // All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will
-  // provide the master loop that drives the logic for performing the work.
-  // @return Status The status code returned
+  /// Class functor operator () override.
+  /// All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will
+  /// provide the master loop that drives the logic for performing the work.
+  /// \return Status The status code returned
   Status operator()() override;

-  // Overrides base class reset method. When an operator does a reset, it cleans up any state
-  // info from it's previous execution and then initializes itself so that it can be executed
-  // again.
-  // @return Status The status code returned
+  /// Overrides base class reset method. When an operator does a reset, it cleans up any state
+  /// info from it's previous execution and then initializes itself so that it can be executed
+  /// again.
+  /// \return Status The status code returned
   Status Reset() override;

-  // Base-class override for NodePass visitor acceptor.
-  // @param p - Pointer to the NodePass to be accepted.
-  // @param modified - Whether this node visit modified the pipeline.
-  // @return - Status of the node visit.
+  /// Base-class override for NodePass visitor acceptor.
+  /// \param p - Pointer to the NodePass to be accepted.
+  /// \param modified - Whether this node visit modified the pipeline.
+  /// \return - Status of the node visit.
   Status Accept(NodePass *p, bool *const modified) override;

-  // Op name getter
-  // @return Name of the current Op
+  /// Op name getter
+  /// \return Name of the current Op
   std::string Name() const override { return "GeneratorOp"; }

-  Status Init();
-
  private:
   py::function generator_function_;
   std::vector<std::string> column_names_;
   std::vector<DataType> column_types_;
   int32_t prefetch_size_;
   int32_t buffer_size_;
-  int64_t pre_counter_size_;
   int64_t generator_counter_;

   py::object generator_;
@@ -151,15 +149,25 @@ class GeneratorOp : public PipelineOp {

   WaitPost wp_;

-  void Dealloc() noexcept;
-
   Status PyRowToTensorRow(py::object py_data, TensorRow *tensor_row);

   Status FillBuffer(TensorQTable *tt);

-  // Private function for computing the assignment of the column name map.
-  // @return - Status
+  /// Private function for computing the assignment of the column name map.
+  /// \return - Status
   Status ComputeColMap() override;
+
+  /// Initialize Sampler, calls sampler->Init() within
+  /// \return Status The status code returned
+  Status InitSampler();
+
+  /// Create new Generator object from the generator function
+  /// \return Status The status code returned
+  Status CreateGeneratorObject();
+
+  /// Initialize GeneratorOp
+  /// \return Status The status code returned
+  Status Init();
 };

 #pragma GCC visibility pop


mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc (+15 -2)

@@ -177,8 +177,21 @@ int64_t DistributedSamplerRT::CalculateNumSamples(int64_t num_rows) {
     child_num_rows = child_[0]->CalculateNumSamples(num_rows);
   }
   int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
-  int64_t num_per_shard = std::ceil(child_num_rows * 1.0 / num_devices_);
-  return std::min(num_samples, num_per_shard);
+  int64_t remainder = (child_num_rows + offset_) % num_devices_;
+  int64_t shard_size = (child_num_rows + offset_) / num_devices_;
+  if (offset_ != -1 || !even_dist_) {
+    if (offset_ == -1) offset_ = 0;
+    if (device_id_ < remainder) shard_size++;
+    if (device_id_ < offset_) shard_size--;
+  } else {
+    offset_ = 0;
+    shard_size = (child_num_rows + num_devices_ - 1) / num_devices_;
+  }
+  // add 1 to an empty shard
+  // this logic is needed to follow the logic in initSampler that is written for ConcatDataset
+  if (shard_size == 0) shard_size++;
+
+  return std::min(num_samples, shard_size);
 }

 void DistributedSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {
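A pure-Python mirror of the new shard-size arithmetic, handy for checking the edge cases; this assumes the member semantics implied above (offset_ of -1 meaning "unset", even_dist_ padding every shard up to the same size, and child_num_rows + offset never negative):

    import math

    def calculate_num_samples(child_num_rows, num_samples, num_devices, device_id,
                              offset=-1, even_dist=True):
        cap = min(child_num_rows, num_samples) if num_samples > 0 else child_num_rows
        remainder = (child_num_rows + offset) % num_devices
        shard_size = (child_num_rows + offset) // num_devices
        if offset != -1 or not even_dist:
            if offset == -1:
                offset = 0
            if device_id < remainder:
                shard_size += 1
            if device_id < offset:
                shard_size -= 1
        else:
            shard_size = math.ceil(child_num_rows / num_devices)  # pad to equal shards
        if shard_size == 0:
            shard_size = 1  # an empty shard still reports one row (ConcatDataset convention)
        return min(cap, shard_size)

    # 23 rows, 10 devices, explicit offset 0: shard sizes {3,3,3,2,2,2,2,2,2,2}
    assert [calculate_num_samples(23, 0, 10, d, offset=0) for d in range(10)] == [3] * 3 + [2] * 7
    # defaults (offset -1, even_dist True): every shard is padded to ceil(23 / 10) == 3
    assert calculate_num_samples(23, 0, 10, device_id=9) == 3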


mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h (+4 -0)

@@ -48,6 +48,10 @@ class RandomAccessOp {
   // default destructor
   virtual ~RandomAccessOp() = default;

+  /// Set num_rows
+  /// \param num_rows
+  void SetNumRows(int64_t num_rows) { num_rows_ = num_rows; }
+
  protected:
   // The amount of rows in the dataset itself. This is the before-sampling value, the
   // total count of rows. A sampler may choose to sample less than this amount.


mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc (+6 -6)

@@ -95,19 +95,19 @@ Status ConcatNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size

   // calculate the size of the shard
   int64_t shard_dataset_size = 0;
-  if (sampler_ != nullptr) {
-    std::shared_ptr<DistributedSamplerRT> sampler_rt =
-      std::static_pointer_cast<DistributedSamplerRT>(sampler_->SamplerBuild());
+  std::shared_ptr<DistributedSamplerRT> sampler_rt =
+    sampler_ ? std::dynamic_pointer_cast<DistributedSamplerRT>(sampler_->SamplerBuild()) : nullptr;
+  if (sampler_rt != nullptr) {
     sampler_rt->SetNumRowsInDataset(total_dataset_size);
     sampler_rt->InitSampler();

     // (total_size % num_shards != 0) & shard_id >= (remainder) ? CalculateNumSamples()-1 : CalculateNumSamples()
     // example: 23 rows, 10 shards --> shard sizes = {3,3,3,2,2,2,2,2,2,2}
-    if ((sampler_rt->GetNumSamples() % sampler_rt->GetDeviceNum()) > 0 &&
+    if ((sampler_rt->GetNumSamples() % sampler_rt->GetDeviceNum()) >= 0 &&
         sampler_rt->GetDeviceID() >= (sampler_rt->GetNumSamples() % sampler_rt->GetDeviceNum())) {
-      shard_dataset_size = sampler_rt->CalculateNumSamples(sampler_rt->GetNumSamples()) - 1;
+      shard_dataset_size = sampler_rt->GetNumSamples() / sampler_rt->GetDeviceNum();
     } else {
-      shard_dataset_size = sampler_rt->CalculateNumSamples(sampler_rt->GetNumSamples());
+      shard_dataset_size = sampler_rt->GetNumSamples() / sampler_rt->GetDeviceNum() + 1;
     }
   } else {
     shard_dataset_size = total_dataset_size;
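The rewritten branch swaps the CalculateNumSamples() round trips for direct shard arithmetic; the comment's example is easy to verify. A quick check, assuming GetNumSamples() is the concatenated row count:

    # 23 rows across 10 shards: shard_id >= (23 % 10) gets 23 // 10 rows,
    # the first (23 % 10) shards get one extra row -> {3,3,3,2,2,2,2,2,2,2}
    num_samples, num_shards = 23, 10
    sizes = [num_samples // num_shards if shard_id >= num_samples % num_shards
             else num_samples // num_shards + 1
             for shard_id in range(num_shards)]
    assert sizes == [3, 3, 3, 2, 2, 2, 2, 2, 2, 2] and sum(sizes) == num_samples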


mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc (+35 -13)

@@ -18,6 +18,7 @@

 #include <memory>
 #include <string>
+#include <utility>
 #include <vector>

 #include "minddata/dataset/engine/datasetops/repeat_op.h"
@@ -29,24 +30,27 @@ namespace mindspore {
 namespace dataset {

 GeneratorNode::GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names,
-                             const std::vector<DataType> &column_types)
+                             const std::vector<DataType> &column_types, int64_t source_len,
+                             std::shared_ptr<SamplerObj> sampler)
     : MappableSourceNode(),
       generator_function_(generator_function),
       column_names_(column_names),
       column_types_(column_types),
-      reset_ancestor_(nullptr) {}
+      reset_ancestor_(nullptr),
+      sampler_(std::move(sampler)),
+      source_len_(source_len) {}

-GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema)
+GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema,
+                             int64_t source_len, std::shared_ptr<SamplerObj> sampler)
     : MappableSourceNode(), generator_function_(generator_function), schema_(schema), reset_ancestor_(nullptr) {}

 std::shared_ptr<DatasetNode> GeneratorNode::Copy() {
   std::shared_ptr<GeneratorNode> node;
   if (schema_ == nullptr) {
-    node = std::make_shared<GeneratorNode>(generator_function_, column_names_, column_types_);
+    node = std::make_shared<GeneratorNode>(generator_function_, column_names_, column_types_, source_len_, sampler_);
   } else {
-    node = std::make_shared<GeneratorNode>(generator_function_, schema_);
+    node = std::make_shared<GeneratorNode>(generator_function_, schema_, source_len_, sampler_);
   }
-  node->SetGeneratorDatasetSize(dataset_size_);
   return node;
 }

@@ -69,17 +73,14 @@ Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_
       column_types_.push_back((col.type()));
     }
   }
+  std::shared_ptr<SamplerRT> sampler_rt = sampler_ ? sampler_->SamplerBuild() : nullptr;

   // GeneratorOp's constructor takes in a prefetch_size, which isn't being set by user nor is it being used by
   // GeneratorOp internally. Here it is given a zero which is the default in generator builder
   std::shared_ptr<GeneratorOp> op = std::make_shared<GeneratorOp>(generator_function_, column_names_, column_types_, 0,
-                                                                  rows_per_buffer_, connector_que_size_, dataset_size_);
-
-  // Init() is called in builder when generator is built. Here, since we are getting away from the builder class, init
-  // needs to be called when the op is built. The caveat is that Init needs to be made public (before it is private).
-  // This method can be privatized once we move Init() to Generator's functor. However, that is a bigger change which
-  // best be delivered when the test cases for this api is ready.
-  RETURN_IF_NOT_OK(op->Init());
+                                                                  rows_per_buffer_, connector_que_size_, sampler_rt);
+  // set the number of rows from source length
+  op->SetNumRows(source_len_);

   // Add this GeneratorOp to its RepeatOp/EpochCtrlOp ancestor's EOE list.
   // When the ancestor reaches an end-of-epoch boundary, it will send a "reset" signal to all the ops in the EOE list.
@@ -118,5 +119,26 @@ Status GeneratorNode::AcceptAfter(IRNodePass *p, bool *const modified) {
   // Downcast shared pointer then call visitor
   return p->VisitAfter(shared_from_base<GeneratorNode>(), modified);
 }
+
+Status GeneratorNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
+                                     int64_t *dataset_size) {
+  if (dataset_size_ > 0) {
+    *dataset_size = dataset_size_;
+    return Status::OK();
+  }
+  if (source_len_ == -1) {
+    RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), dataset_size));
+    dataset_size_ = *dataset_size;
+    return Status::OK();
+  } else {
+    int64_t sample_size;
+    int64_t num_rows;
+    num_rows = source_len_;
+    sample_size = sampler_ ? sampler_->SamplerBuild()->CalculateNumSamples(num_rows) : num_rows;
+    *dataset_size = sample_size;
+    dataset_size_ = *dataset_size;
+    return Status::OK();
+  }
+}
 }  // namespace dataset
 }  // namespace mindspore
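The new GetDatasetSize() reduces to a three-way decision: reuse the cached size, dry-run the pipeline when the source length is unknown, or let the sampler shrink a known length. A compact restatement (dry_run and calculate_num_samples are stand-ins for DatasetSizeGetter::DryRun and SamplerRT::CalculateNumSamples):

    def get_dataset_size(cached_size, source_len, sampler, dry_run):
        if cached_size > 0:
            return cached_size          # already computed once
        if source_len == -1:
            return dry_run()            # unknown length: must execute the generator
        if sampler is not None:
            return sampler.calculate_num_samples(source_len)  # known length, sampled
        return source_len               # known length, no sampler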

mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h (+13 -7)

@@ -34,10 +34,11 @@ class GeneratorNode : public MappableSourceNode {
  public:
   /// \brief Constructor
   GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names,
-                const std::vector<DataType> &column_types);
+                const std::vector<DataType> &column_types, int64_t source_len, std::shared_ptr<SamplerObj> sampler);

   /// \brief Constructor
-  GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema);
+  GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema, int64_t source_len,
+                std::shared_ptr<SamplerObj> sampler);

   /// \brief Destructor
   ~GeneratorNode() = default;
@@ -67,11 +68,6 @@ class GeneratorNode : public MappableSourceNode {
   /// \return Status Status::OK() if get shard id successfully
   Status GetShardId(int32_t *shard_id) override;

-  /// \brief Setter for DatasetSize in GeneratorNode
-  /// \param[in] sz dataset size to set
-  /// \return void
-  void SetGeneratorDatasetSize(int64_t sz) { dataset_size_ = sz; }
-
   bool IsSizeDefined() override { return false; }

   /// \brief Record the vector of Repeat/EpochCtrl nodes that are ancestors of this node
@@ -82,6 +78,14 @@ class GeneratorNode : public MappableSourceNode {
     reset_ancestor_ = src;
     return Status::OK();
   }
+  /// Returns the dataset size of GeneratorOp. If is mappable (sampler isn not null), the sampler is used.
+  /// Otherwise, a dry run is needed.
+  /// \param[in] size_getter TreeConsumer to be used for a dryrun
+  /// \param[in] estimate
+  /// \param[out] dataset_size
+  /// \return Status of the function
+  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
+                        int64_t *dataset_size) override;

   /// \brief Getter functions
   const py::function &GeneratorFunction() const { return generator_function_; }
@@ -102,6 +106,8 @@ class GeneratorNode : public MappableSourceNode {
   std::vector<DataType> column_types_;
   std::shared_ptr<SchemaObj> schema_;
   std::shared_ptr<RepeatNode> reset_ancestor_;  // updated its immediate Repeat/EpochCtrl ancestor in GeneratorNodePass
+  std::shared_ptr<SamplerObj> sampler_;
+  int64_t source_len_;  // Length of the dataset source provided by the user, -1 means it's unknown

   /// \brief Base-class override for accepting IRNodePass visitor
   /// \param[in] p The node to visit


mindspore/dataset/engine/datasets.py (+33 -30)

@@ -3427,23 +3427,31 @@ def _py_sampler_fn(sampler, num_samples, dataset):
         yield tuple([np.array(x, copy=False) for x in val])


-def _cpp_sampler_fn(sampler, dataset):
+def _cpp_sampler_fn(sample_ids, dataset):
     """
     Generator function wrapper for mappable dataset with cpp sampler.
     """
-    indices = sampler.get_indices()
-    for i in indices:
+    if not isinstance(sample_ids, np.ndarray):
+        raise RuntimeError("Sample IDs are not in a numpy array.")
+    if sample_ids.size == 0:
+        raise RuntimeError("Sampler passed an empty sample IDs list.")
+
+    for i in sample_ids:
         val = dataset[i]
         # convert output tensors to ndarrays
         yield tuple([np.array(x, copy=False) for x in val])


-def _cpp_sampler_fn_mp(sampler, sample_fn):
+def _cpp_sampler_fn_mp(sample_ids, sample_fn):
     """
     Multiprocessing generator function wrapper for mappable dataset with cpp sampler.
     """
-    indices = sampler.get_indices()
-    return sample_fn.process(indices)
+    if not isinstance(sample_ids, np.ndarray):
+        raise RuntimeError("Sample IDs are not in a numpy array.")
+    if sample_ids.size == 0:
+        raise RuntimeError("Sampler passed an empty sample IDs list.")
+
+    return sample_fn.process(sample_ids)


 def _py_sampler_fn_mp(sampler, num_samples, sample_fn):
@@ -3811,18 +3819,9 @@ class GeneratorDataset(MappableDataset):
             self.schema = Schema(schema)
         # Move get dataset_size by len from parse to here, because self.source will
         # lose attribution of '__len__' after deepcopy.
-        self.dataset_size = None
+        self.source_len = -1  # unknown
         if hasattr(self.source, "__len__"):
-            if not self.num_shards:
-                self.dataset_size = len(self.source)
-            else:
-                self.dataset_size = math.ceil(len(self.source) / self.num_shards)
-
-            rows_from_sampler = self._get_sampler_dataset_size()
-            if self.num_samples is not None and self.num_samples < rows_from_sampler:
-                rows_from_sampler = self.num_samples
-            if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
-                self.dataset_size = rows_from_sampler
+            self.source_len = len(self.source)

     def __deepcopy__(self, memodict):
         if id(self) in memodict:
@@ -3839,6 +3838,7 @@ class GeneratorDataset(MappableDataset):
         new_op.num_samples = copy.deepcopy(self.num_samples, memodict)
         new_op.sampler = copy.deepcopy(self.sampler)
         new_op.dataset_size = self.dataset_size
+        new_op.source_len = self.source_len
         new_op.saved_output_types = self.saved_output_types
         new_op.saved_output_shapes = self.saved_output_shapes
         if hasattr(self, "__total_batch__"):
@@ -3847,22 +3847,28 @@ class GeneratorDataset(MappableDataset):
             if isinstance(new_op.sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
                                            samplers.RandomSampler, samplers.SubsetRandomSampler,
                                            samplers.WeightedRandomSampler, samplers.Sampler)):
-                sampler_instance = new_op.sampler.create()
-                sampler_instance.set_num_rows(len(self.source))
-                sampler_instance.initialize()
                 if new_op.num_parallel_workers > 1:
                     sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
-                    new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, sample_fn))
+                    new_op.source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn))
                 else:
-                    new_op.source = (lambda: _cpp_sampler_fn(sampler_instance, self.source))
+                    new_op.source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source))
             else:
+                # the sampler provided is not a built-in sampler, it is a list of sample_ids
+                new_op.sample_ids = new_op.sampler
+                # since list of sample_ids are not passed to c++, we need to find the proper len here
+                new_op.source_len = min(self.source_len, len(new_op.sample_ids)) if self.source_len != -1 else len(
+                    new_op.sample_ids)
+                new_op.source_len = min(self.source_len,
+                                        new_op.num_samples) if new_op.num_samples is not None else new_op.source_len
+                new_op.sampler = None
                 if new_op.num_parallel_workers > 1:
                     sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
-                    new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, sample_fn))
+                    new_op.source = (lambda: _py_sampler_fn_mp(new_op.sample_ids, new_op.num_samples, sample_fn))
                 else:
-                    new_op.source = (lambda: _py_sampler_fn(new_op.sampler, new_op.num_samples, self.source))
+                    new_op.source = (lambda: _py_sampler_fn(new_op.sample_ids, new_op.num_samples, self.source))
         else:
             try:
+                new_op.sampler = None
                 iter(self.source)
             except TypeError:
                 # Use generator function if input callable
@@ -3881,16 +3887,13 @@ class GeneratorDataset(MappableDataset):
         return self.sampler.is_sharded()

     def parse(self, children=None):
-        if self.dataset_size is None:
-            self.dataset_size = -1
         if self.schema is None:
-            return cde.GeneratorNode(self.source, self.column_names, self.column_types).SetGeneratorDatasetSize(
-                self.dataset_size) \
-                .SetNumWorkers(self.num_parallel_workers)
+            return cde.GeneratorNode(self.source, self.column_names, self.column_types,
+                                     self.source_len, self.sampler).SetNumWorkers(self.num_parallel_workers)
         schema = self.schema
         if isinstance(schema, Schema):
             schema = self.schema.cpp_schema
-        return cde.GeneratorNode(self.source, schema).SetGeneratorDatasetSize(self.dataset_size).SetNumWorkers(
+        return cde.GeneratorNode(self.source, schema, self.source_len, self.sampler).SetNumWorkers(
             self.num_parallel_workers)
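At the user level the API is unchanged; what moves is where the sample IDs are computed. With this commit, a mappable source plus a built-in sampler has its index list produced by the C++ sampler and handed to the generator wrapper in a single call. A hedged usage sketch (requires mindspore; numbers follow the 23-rows/10-shards example above):

    import numpy as np
    import mindspore.dataset as ds

    class MySource:
        """__getitem__ and __len__ make the source random-accessible (mappable)."""
        def __init__(self):
            self.data = [(np.array([i]),) for i in range(23)]
        def __getitem__(self, index):
            return self.data[index]
        def __len__(self):
            return len(self.data)

    # source_len (23) and the built sampler now travel into cde.GeneratorNode;
    # DistributedSamplerRT computes this shard's sample IDs in C++.
    sampler = ds.DistributedSampler(num_shards=10, shard_id=0)
    dataset = ds.GeneratorDataset(MySource(), column_names=["data"], sampler=sampler)
    assert dataset.get_dataset_size() == 3  # even_dist pads every shard to ceil(23/10)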





