From fa553956ef93705b80103009ce34d3ea98a1669f Mon Sep 17 00:00:00 2001
From: Zirui Wu
Date: Wed, 13 Jan 2021 11:21:10 -0500
Subject: [PATCH] add num_epochs to CPP API Dataset::CreateIterator(columns,num_epochs)

---
 .../ccsrc/minddata/dataset/api/datasets.cc    |  4 +-
 .../ccsrc/minddata/dataset/api/iterator.cc    |  4 +-
 .../engine/ir/datasetops/dataset_node.cc      |  4 +-
 .../ccsrc/minddata/dataset/include/datasets.h |  4 +-
 .../ccsrc/minddata/dataset/include/iterator.h |  3 +-
 .../dataset/c_api_dataset_iterator_test.cc    | 45 ++++++++++++++++++-
 6 files changed, 55 insertions(+), 9 deletions(-)

diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc
index 15300fb2a5..9149365b30 100644
--- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc
+++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc
@@ -102,7 +102,7 @@ namespace mindspore {
 namespace dataset {
 
 // Function to create the iterator, which will build and launch the execution tree.
-std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> columns) {
+std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> columns, int32_t num_epochs) {
   std::shared_ptr<Iterator> iter;
   try {
     auto ds = shared_from_this();
@@ -114,7 +114,7 @@ std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> colum
     }
 
     iter = std::make_shared<Iterator>();
-    Status rc = iter->BuildAndLaunchTree(ds);
+    Status rc = iter->BuildAndLaunchTree(ds, num_epochs);
     if (rc.IsError()) {
       MS_LOG(ERROR) << "CreateIterator failed." << rc;
       return nullptr;
diff --git a/mindspore/ccsrc/minddata/dataset/api/iterator.cc b/mindspore/ccsrc/minddata/dataset/api/iterator.cc
index 9daa7403ad..93abd7e195 100644
--- a/mindspore/ccsrc/minddata/dataset/api/iterator.cc
+++ b/mindspore/ccsrc/minddata/dataset/api/iterator.cc
@@ -56,10 +56,10 @@ void Iterator::Stop() {
 }
 
 // Function to build and launch the execution tree.
-Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
+Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds, int32_t num_epochs) {
   runtime_context_ = std::make_unique<RuntimeContext>();
   RETURN_IF_NOT_OK(runtime_context_->Init());
-  auto consumer = std::make_unique<IteratorConsumer>();
+  auto consumer = std::make_unique<IteratorConsumer>(num_epochs);
   consumer_ = consumer.get();
   RETURN_IF_NOT_OK(consumer->Init(ds->IRNode()));
   runtime_context_->AssignConsumer(std::move(consumer));
diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc
index 95d17a7686..62320394f7 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc
@@ -410,12 +410,12 @@ Status DatasetNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &siz
     return Status::OK();
   }
   if (children_.size() == 1) {
-    return children_[0]->GetDatasetSize(size_getter, estimate, dataset_size);
+    return children_.front()->GetDatasetSize(size_getter, estimate, dataset_size);
   } else if (children_.size() > 1) {
     // It is okay for dataset to have more than 1 child, GetDatasetSize shouldn't fail in this case.
     // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
     // always be in front of the child_ structure, so we get the dataset size from the last child.
-    return children_[children_.size() - 1]->GetDatasetSize(size_getter, estimate, dataset_size);
+    return children_.back()->GetDatasetSize(size_getter, estimate, dataset_size);
   } else {
     RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override");
   }
diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h
index b5ad1ec77d..49887e5606 100644
--- a/mindspore/ccsrc/minddata/dataset/include/datasets.h
+++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h
@@ -141,8 +141,10 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
 
   /// \brief Function to create an Iterator over the Dataset pipeline
   /// \param[in] columns List of columns to be used to specify the order of columns
+  /// \param[in] num_epochs Number of epochs to run through the pipeline, default -1 which means infinite epochs.
+  ///     An empty row is returned at the end of each epoch
   /// \return Shared pointer to the Iterator
-  std::shared_ptr<Iterator> CreateIterator(std::vector<std::string> columns = {});
+  std::shared_ptr<Iterator> CreateIterator(std::vector<std::string> columns = {}, int32_t num_epochs = -1);
 
 #ifndef ENABLE_ANDROID
   /// \brief Function to transfer data through a device.
diff --git a/mindspore/ccsrc/minddata/dataset/include/iterator.h b/mindspore/ccsrc/minddata/dataset/include/iterator.h
index f4a07c36cd..dbd6ddd0fa 100644
--- a/mindspore/ccsrc/minddata/dataset/include/iterator.h
+++ b/mindspore/ccsrc/minddata/dataset/include/iterator.h
@@ -51,8 +51,9 @@ class Iterator {
 
   /// \brief Method for building and launching the pipeline.
   /// \param[in] ops - a vector of DatasetOp in the data pipeline.
+  /// \param[in] num_epochs Number of epochs passed down to EpochCtrlNode, default -1, infinite epochs
   /// \return - a Status error code, returns OK if no error encountered.
-  Status BuildAndLaunchTree(std::shared_ptr<Dataset> ds);
+  Status BuildAndLaunchTree(std::shared_ptr<Dataset> ds, int32_t num_epochs);
 
   /// \brief Function to get the next row from the data pipeline.
   /// \note Type of return data is a map(with column name).
diff --git a/tests/ut/cpp/dataset/c_api_dataset_iterator_test.cc b/tests/ut/cpp/dataset/c_api_dataset_iterator_test.cc
index 0250d39da0..6bcc185525 100644
--- a/tests/ut/cpp/dataset/c_api_dataset_iterator_test.cc
+++ b/tests/ut/cpp/dataset/c_api_dataset_iterator_test.cc
@@ -76,7 +76,7 @@ TEST_F(MindDataTestPipeline, TestIteratorOneColumn) {
   // Create an iterator over the result of the above dataset
   // Only select "image" column and drop others
   std::vector<std::string> columns = {"image"};
-  std::shared_ptr<Iterator> iter = ds->CreateIterator(columns);
+  std::shared_ptr<Iterator> iter = ds->CreateIterator(columns, -1);
   EXPECT_NE(iter, nullptr);
 
   // Iterate the dataset and get each row
@@ -195,3 +195,46 @@ TEST_F(MindDataTestPipeline, TestIteratorWrongColumn) {
   std::shared_ptr<Iterator> iter = ds->CreateIterator(columns);
   EXPECT_EQ(iter, nullptr);
 }
+
+TEST_F(MindDataTestPipeline, TestIteratorNumEpoch) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorNumEpoch.";
+
+  std::shared_ptr<SchemaObj> schema = Schema();
+  int32_t random_data_num_row = 2;
+  int32_t num_epochs = 3;
+  ASSERT_OK(schema->add_column("image", mindspore::TypeId::kNumberTypeUInt8, {2}));
+  std::shared_ptr<Dataset> ds = RandomData(random_data_num_row, schema)->SetNumWorkers(1);
+
+  std::shared_ptr<Iterator> iter = ds->CreateIterator({}, num_epochs);
+  ASSERT_NE(iter, nullptr);  // should terminate test case if iterator is null
+  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
+
+  int32_t inner_row_cnt = 0;
+  int32_t total_row_cnt = 0;
+  for (int32_t i = 0; i < num_epochs; i++) {
+    ASSERT_TRUE(iter->GetNextRow(&row));
+    inner_row_cnt = 0;
+    while (row.size() != 0) {
+      ASSERT_TRUE(iter->GetNextRow(&row));
+      ++inner_row_cnt;
+      ++total_row_cnt;
+    }
+    EXPECT_EQ(inner_row_cnt, random_data_num_row);
+  }
+  EXPECT_EQ(total_row_cnt, random_data_num_row * num_epochs);
+  // this will go beyond the random_data_num_row*num_epoch limit, hence error code is expected
+  EXPECT_FALSE(iter->GetNextRow(&row));
+  // Manually terminate the pipeline
+  iter->Stop();
+}
+
+TEST_F(MindDataTestPipeline, TestIteratorNumEpochFail) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorNumEpochFail.";
+
+  std::shared_ptr<SchemaObj> schema = Schema();
+  ASSERT_OK(schema->add_column("image", mindspore::TypeId::kNumberTypeUInt8, {2}));
+  std::shared_ptr<Dataset> ds = RandomData(3, schema)->SetNumWorkers(1);
+  // expect nullptr due to incorrect num_epochs value.
+  EXPECT_EQ(ds->CreateIterator({}, 0), nullptr);
+  EXPECT_EQ(ds->CreateIterator({}, -2), nullptr);
+}
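
Usage note (illustrative sketch, not part of the patch): the snippet below shows how caller code is expected to drive the new num_epochs argument, following the pattern of the tests above. It assumes the same RandomData/Schema factories and the bool-returning GetNextRow used in this revision of the C++ API; the end of each epoch is signalled by the empty row documented in datasets.h.

// Minimal sketch: run a fixed number of epochs through a small pipeline.
std::shared_ptr<SchemaObj> schema = Schema();
schema->add_column("image", mindspore::TypeId::kNumberTypeUInt8, {2});
std::shared_ptr<Dataset> ds = RandomData(4, schema);

int32_t num_epochs = 2;
std::shared_ptr<Iterator> iter = ds->CreateIterator({}, num_epochs);

std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
for (int32_t epoch = 0; epoch < num_epochs; ++epoch) {
  iter->GetNextRow(&row);
  while (!row.empty()) {  // the empty row marks the end of the current epoch
    // ... consume row["image"] here ...
    iter->GetNextRow(&row);
  }
}
iter->Stop();  // manually terminate the pipeline once all epochs are consumed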