| @@ -102,7 +102,7 @@ namespace mindspore { | |||||
| namespace dataset { | namespace dataset { | ||||
| // Function to create the iterator, which will build and launch the execution tree. | // Function to create the iterator, which will build and launch the execution tree. | ||||
| std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> columns) { | |||||
| std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> columns, int32_t num_epochs) { | |||||
| std::shared_ptr<Iterator> iter; | std::shared_ptr<Iterator> iter; | ||||
| try { | try { | ||||
| auto ds = shared_from_this(); | auto ds = shared_from_this(); | ||||
| @@ -114,7 +114,7 @@ std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> colum | |||||
| } | } | ||||
| iter = std::make_shared<Iterator>(); | iter = std::make_shared<Iterator>(); | ||||
| Status rc = iter->BuildAndLaunchTree(ds); | |||||
| Status rc = iter->BuildAndLaunchTree(ds, num_epochs); | |||||
| if (rc.IsError()) { | if (rc.IsError()) { | ||||
| MS_LOG(ERROR) << "CreateIterator failed." << rc; | MS_LOG(ERROR) << "CreateIterator failed." << rc; | ||||
| return nullptr; | return nullptr; | ||||
| @@ -56,10 +56,10 @@ void Iterator::Stop() { | |||||
| } | } | ||||
| // Function to build and launch the execution tree. | // Function to build and launch the execution tree. | ||||
| Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) { | |||||
| Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds, int32_t num_epochs) { | |||||
| runtime_context_ = std::make_unique<NativeRuntimeContext>(); | runtime_context_ = std::make_unique<NativeRuntimeContext>(); | ||||
| RETURN_IF_NOT_OK(runtime_context_->Init()); | RETURN_IF_NOT_OK(runtime_context_->Init()); | ||||
| auto consumer = std::make_unique<IteratorConsumer>(); | |||||
| auto consumer = std::make_unique<IteratorConsumer>(num_epochs); | |||||
| consumer_ = consumer.get(); | consumer_ = consumer.get(); | ||||
| RETURN_IF_NOT_OK(consumer->Init(ds->IRNode())); | RETURN_IF_NOT_OK(consumer->Init(ds->IRNode())); | ||||
| runtime_context_->AssignConsumer(std::move(consumer)); | runtime_context_->AssignConsumer(std::move(consumer)); | ||||
| @@ -410,12 +410,12 @@ Status DatasetNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &siz | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| if (children_.size() == 1) { | if (children_.size() == 1) { | ||||
| return children_[0]->GetDatasetSize(size_getter, estimate, dataset_size); | |||||
| return children_.front()->GetDatasetSize(size_getter, estimate, dataset_size); | |||||
| } else if (children_.size() > 1) { | } else if (children_.size() > 1) { | ||||
| // It is okay for dataset to have more than 1 child, GetDatasetSize shouldn't fail in this case. | // It is okay for dataset to have more than 1 child, GetDatasetSize shouldn't fail in this case. | ||||
| // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will | // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will | ||||
| // always be in front of the child_ structure, so we get the dataset size from the last child. | // always be in front of the child_ structure, so we get the dataset size from the last child. | ||||
| return children_[children_.size() - 1]->GetDatasetSize(size_getter, estimate, dataset_size); | |||||
| return children_.back()->GetDatasetSize(size_getter, estimate, dataset_size); | |||||
| } else { | } else { | ||||
| RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override"); | RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override"); | ||||
| } | } | ||||
| @@ -141,8 +141,10 @@ class Dataset : public std::enable_shared_from_this<Dataset> { | |||||
| /// \brief Function to create an Iterator over the Dataset pipeline | /// \brief Function to create an Iterator over the Dataset pipeline | ||||
| /// \param[in] columns List of columns to be used to specify the order of columns | /// \param[in] columns List of columns to be used to specify the order of columns | ||||
| /// \param[in] num_epochs Number of epochs to run through the pipeline, default -1 which means infinite epochs. | |||||
| /// An empty row is returned at the end of each epoch | |||||
| /// \return Shared pointer to the Iterator | /// \return Shared pointer to the Iterator | ||||
| std::shared_ptr<Iterator> CreateIterator(std::vector<std::string> columns = {}); | |||||
| std::shared_ptr<Iterator> CreateIterator(std::vector<std::string> columns = {}, int32_t num_epochs = -1); | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| /// \brief Function to transfer data through a device. | /// \brief Function to transfer data through a device. | ||||
| @@ -51,8 +51,9 @@ class Iterator { | |||||
| /// \brief Method for building and launching the pipeline. | /// \brief Method for building and launching the pipeline. | ||||
| /// \param[in] ds - shared pointer to the Dataset whose pipeline is built and launched. | /// \param[in] ds - shared pointer to the Dataset whose pipeline is built and launched. | ||||
| /// \param[in] num_epochs Number of epochs passed down to EpochCtrlNode; -1 means infinite epochs | |||||
| /// \return - a Status error code, returns OK if no error encountered. | /// \return - a Status error code, returns OK if no error encountered. | ||||
| Status BuildAndLaunchTree(std::shared_ptr<Dataset> ds); | |||||
| Status BuildAndLaunchTree(std::shared_ptr<Dataset> ds, int32_t num_epochs); | |||||
| /// \brief Function to get the next row from the data pipeline. | /// \brief Function to get the next row from the data pipeline. | ||||
| /// \note Type of return data is a map(with column name). | /// \note Type of return data is a map(with column name). | ||||
| @@ -76,7 +76,7 @@ TEST_F(MindDataTestPipeline, TestIteratorOneColumn) { | |||||
| // Create an iterator over the result of the above dataset | // Create an iterator over the result of the above dataset | ||||
| // Only select "image" column and drop others | // Only select "image" column and drop others | ||||
| std::vector<std::string> columns = {"image"}; | std::vector<std::string> columns = {"image"}; | ||||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(columns); | |||||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(columns, -1); | |||||
| EXPECT_NE(iter, nullptr); | EXPECT_NE(iter, nullptr); | ||||
| // Iterate the dataset and get each row | // Iterate the dataset and get each row | ||||
| @@ -195,3 +195,46 @@ TEST_F(MindDataTestPipeline, TestIteratorWrongColumn) { | |||||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(columns); | std::shared_ptr<Iterator> iter = ds->CreateIterator(columns); | ||||
| EXPECT_EQ(iter, nullptr); | EXPECT_EQ(iter, nullptr); | ||||
| } | } | ||||
| TEST_F(MindDataTestPipeline, TestIteratorNumEpoch) { | |||||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorNumEpoch."; | |||||
| std::shared_ptr<SchemaObj> schema = Schema(); | |||||
| int32_t random_data_num_row = 2; | |||||
| int32_t num_epochs = 3; | |||||
| ASSERT_OK(schema->add_column("image", mindspore::TypeId::kNumberTypeUInt8, {2})); | |||||
| std::shared_ptr<Dataset> ds = RandomData(random_data_num_row, schema)->SetNumWorkers(1); | |||||
| std::shared_ptr<Iterator> iter = ds->CreateIterator({}, num_epochs); | |||||
| ASSERT_NE(iter, nullptr); // should terminate test case if iterator is null | |||||
| std::unordered_map<std::string, std::shared_ptr<Tensor>> row; | |||||
| int32_t inner_row_cnt = 0; | |||||
| int32_t total_row_cnt = 0; | |||||
| for (int32_t i = 0; i < num_epochs; i++) { | |||||
| ASSERT_TRUE(iter->GetNextRow(&row)); | |||||
| inner_row_cnt = 0; | |||||
| while (row.size() != 0) { | |||||
| ASSERT_TRUE(iter->GetNextRow(&row)); | |||||
| ++inner_row_cnt; | |||||
| ++total_row_cnt; | |||||
| } | |||||
| EXPECT_EQ(inner_row_cnt, random_data_num_row); | |||||
| } | |||||
| EXPECT_EQ(total_row_cnt, random_data_num_row * num_epochs); | |||||
| // this will go beyond the random_data_num_row*num_epoch limit, hence error code is expected | |||||
| EXPECT_FALSE(iter->GetNextRow(&row)); | |||||
| // Manually terminate the pipeline | |||||
| iter->Stop(); | |||||
| } | |||||
| TEST_F(MindDataTestPipeline, TestIteratorNumEpochFail) { | |||||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorNumEpochFail."; | |||||
| std::shared_ptr<SchemaObj> schema = Schema(); | |||||
| ASSERT_OK(schema->add_column("image", mindspore::TypeId::kNumberTypeUInt8, {2})); | |||||
| std::shared_ptr<Dataset> ds = RandomData(3, schema)->SetNumWorkers(1); | |||||
| // expect nullptr due to incorrect num_epochs value. | |||||
| EXPECT_EQ(ds->CreateIterator({}, 0), nullptr); | |||||
| EXPECT_EQ(ds->CreateIterator({}, -2), nullptr); | |||||
| } | |||||