Browse Source

add num_epochs to CPP API Dataset::CreateIterator(columns,num_epochs)

tags/v1.2.0-rc1
Zirui Wu 4 years ago
parent
commit
fa553956ef
6 changed files with 55 additions and 9 deletions
  1. +2
    -2
      mindspore/ccsrc/minddata/dataset/api/datasets.cc
  2. +2
    -2
      mindspore/ccsrc/minddata/dataset/api/iterator.cc
  3. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc
  4. +3
    -1
      mindspore/ccsrc/minddata/dataset/include/datasets.h
  5. +2
    -1
      mindspore/ccsrc/minddata/dataset/include/iterator.h
  6. +44
    -1
      tests/ut/cpp/dataset/c_api_dataset_iterator_test.cc

+ 2
- 2
mindspore/ccsrc/minddata/dataset/api/datasets.cc View File

@@ -102,7 +102,7 @@ namespace mindspore {
namespace dataset {

// Function to create the iterator, which will build and launch the execution tree.
std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> columns) {
std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> columns, int32_t num_epochs) {
std::shared_ptr<Iterator> iter;
try {
auto ds = shared_from_this();
@@ -114,7 +114,7 @@ std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> colum
}

iter = std::make_shared<Iterator>();
Status rc = iter->BuildAndLaunchTree(ds);
Status rc = iter->BuildAndLaunchTree(ds, num_epochs);
if (rc.IsError()) {
MS_LOG(ERROR) << "CreateIterator failed." << rc;
return nullptr;


+ 2
- 2
mindspore/ccsrc/minddata/dataset/api/iterator.cc View File

@@ -56,10 +56,10 @@ void Iterator::Stop() {
}

// Function to build and launch the execution tree.
Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds, int32_t num_epochs) {
runtime_context_ = std::make_unique<NativeRuntimeContext>();
RETURN_IF_NOT_OK(runtime_context_->Init());
auto consumer = std::make_unique<IteratorConsumer>();
auto consumer = std::make_unique<IteratorConsumer>(num_epochs);
consumer_ = consumer.get();
RETURN_IF_NOT_OK(consumer->Init(ds->IRNode()));
runtime_context_->AssignConsumer(std::move(consumer));


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc View File

@@ -410,12 +410,12 @@ Status DatasetNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &siz
return Status::OK();
}
if (children_.size() == 1) {
return children_[0]->GetDatasetSize(size_getter, estimate, dataset_size);
return children_.front()->GetDatasetSize(size_getter, estimate, dataset_size);
} else if (children_.size() > 1) {
// It is okay for dataset to have more than 1 child, GetDatasetSize shouldn't fail in this case.
// This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
// always be in front of the child_ structure, so we get the dataset size from the last child.
return children_[children_.size() - 1]->GetDatasetSize(size_getter, estimate, dataset_size);
return children_.back()->GetDatasetSize(size_getter, estimate, dataset_size);
} else {
RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override");
}


+ 3
- 1
mindspore/ccsrc/minddata/dataset/include/datasets.h View File

@@ -141,8 +141,10 @@ class Dataset : public std::enable_shared_from_this<Dataset> {

/// \brief Function to create an Iterator over the Dataset pipeline
/// \param[in] columns List of columns to be used to specify the order of columns
/// \param[in] num_epochs Number of epochs to run through the pipeline, default -1 which means infinite epochs.
/// An empty row is returned at the end of each epoch
/// \return Shared pointer to the Iterator
std::shared_ptr<Iterator> CreateIterator(std::vector<std::string> columns = {});
std::shared_ptr<Iterator> CreateIterator(std::vector<std::string> columns = {}, int32_t num_epochs = -1);

#ifndef ENABLE_ANDROID
/// \brief Function to transfer data through a device.


+ 2
- 1
mindspore/ccsrc/minddata/dataset/include/iterator.h View File

@@ -51,8 +51,9 @@ class Iterator {

/// \brief Method for building and launching the pipeline.
/// \param[in] ops - a vector of DatasetOp in the data pipeline.
/// \param[in] num_epochs Number of epochs passed down to EpochCtrlNode, default -1, infinite epochs
/// \return - a Status error code, returns OK if no error encountered.
Status BuildAndLaunchTree(std::shared_ptr<Dataset> ds);
Status BuildAndLaunchTree(std::shared_ptr<Dataset> ds, int32_t num_epochs);

/// \brief Function to get the next row from the data pipeline.
/// \note Type of return data is a map(with column name).


+ 44
- 1
tests/ut/cpp/dataset/c_api_dataset_iterator_test.cc View File

@@ -76,7 +76,7 @@ TEST_F(MindDataTestPipeline, TestIteratorOneColumn) {
// Create an iterator over the result of the above dataset
// Only select "image" column and drop others
std::vector<std::string> columns = {"image"};
std::shared_ptr<Iterator> iter = ds->CreateIterator(columns);
std::shared_ptr<Iterator> iter = ds->CreateIterator(columns, -1);
EXPECT_NE(iter, nullptr);

// Iterate the dataset and get each row
@@ -195,3 +195,46 @@ TEST_F(MindDataTestPipeline, TestIteratorWrongColumn) {
std::shared_ptr<Iterator> iter = ds->CreateIterator(columns);
EXPECT_EQ(iter, nullptr);
}

TEST_F(MindDataTestPipeline, TestIteratorNumEpoch) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorNumEpoch.";

std::shared_ptr<SchemaObj> schema = Schema();
int32_t random_data_num_row = 2;
int32_t num_epochs = 3;
ASSERT_OK(schema->add_column("image", mindspore::TypeId::kNumberTypeUInt8, {2}));
std::shared_ptr<Dataset> ds = RandomData(random_data_num_row, schema)->SetNumWorkers(1);

std::shared_ptr<Iterator> iter = ds->CreateIterator({}, num_epochs);
ASSERT_NE(iter, nullptr); // should terminate test case if iterator is null
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;

int32_t inner_row_cnt = 0;
int32_t total_row_cnt = 0;
for (int32_t i = 0; i < num_epochs; i++) {
ASSERT_TRUE(iter->GetNextRow(&row));
inner_row_cnt = 0;
while (row.size() != 0) {
ASSERT_TRUE(iter->GetNextRow(&row));
++inner_row_cnt;
++total_row_cnt;
}
EXPECT_EQ(inner_row_cnt, random_data_num_row);
}
EXPECT_EQ(total_row_cnt, random_data_num_row * num_epochs);
// this will go beyond the random_data_num_row*num_epoch limit, hence error code is expected
EXPECT_FALSE(iter->GetNextRow(&row));
// Manually terminate the pipeline
iter->Stop();
}

TEST_F(MindDataTestPipeline, TestIteratorNumEpochFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorNumEpochFail.";

std::shared_ptr<SchemaObj> schema = Schema();
ASSERT_OK(schema->add_column("image", mindspore::TypeId::kNumberTypeUInt8, {2}));
std::shared_ptr<Dataset> ds = RandomData(3, schema)->SetNumWorkers(1);
// expect nullptr due to incorrect num_epochs value.
EXPECT_EQ(ds->CreateIterator({}, 0), nullptr);
EXPECT_EQ(ds->CreateIterator({}, -2), nullptr);
}

Loading…
Cancel
Save