Browse Source

add num_epochs to CPP API Dataset::CreateIterator(columns,num_epochs)

tags/v1.2.0-rc1
Zirui Wu 4 years ago
parent
commit
fa553956ef
6 changed files with 55 additions and 9 deletions
  1. +2
    -2
      mindspore/ccsrc/minddata/dataset/api/datasets.cc
  2. +2
    -2
      mindspore/ccsrc/minddata/dataset/api/iterator.cc
  3. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc
  4. +3
    -1
      mindspore/ccsrc/minddata/dataset/include/datasets.h
  5. +2
    -1
      mindspore/ccsrc/minddata/dataset/include/iterator.h
  6. +44
    -1
      tests/ut/cpp/dataset/c_api_dataset_iterator_test.cc

+ 2
- 2
mindspore/ccsrc/minddata/dataset/api/datasets.cc View File

@@ -102,7 +102,7 @@ namespace mindspore {
namespace dataset { namespace dataset {


// Function to create the iterator, which will build and launch the execution tree. // Function to create the iterator, which will build and launch the execution tree.
std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> columns) {
std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> columns, int32_t num_epochs) {
std::shared_ptr<Iterator> iter; std::shared_ptr<Iterator> iter;
try { try {
auto ds = shared_from_this(); auto ds = shared_from_this();
@@ -114,7 +114,7 @@ std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> colum
} }


iter = std::make_shared<Iterator>(); iter = std::make_shared<Iterator>();
Status rc = iter->BuildAndLaunchTree(ds);
Status rc = iter->BuildAndLaunchTree(ds, num_epochs);
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << "CreateIterator failed." << rc; MS_LOG(ERROR) << "CreateIterator failed." << rc;
return nullptr; return nullptr;


+ 2
- 2
mindspore/ccsrc/minddata/dataset/api/iterator.cc View File

@@ -56,10 +56,10 @@ void Iterator::Stop() {
} }


// Function to build and launch the execution tree. // Function to build and launch the execution tree.
Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds, int32_t num_epochs) {
runtime_context_ = std::make_unique<NativeRuntimeContext>(); runtime_context_ = std::make_unique<NativeRuntimeContext>();
RETURN_IF_NOT_OK(runtime_context_->Init()); RETURN_IF_NOT_OK(runtime_context_->Init());
auto consumer = std::make_unique<IteratorConsumer>();
auto consumer = std::make_unique<IteratorConsumer>(num_epochs);
consumer_ = consumer.get(); consumer_ = consumer.get();
RETURN_IF_NOT_OK(consumer->Init(ds->IRNode())); RETURN_IF_NOT_OK(consumer->Init(ds->IRNode()));
runtime_context_->AssignConsumer(std::move(consumer)); runtime_context_->AssignConsumer(std::move(consumer));


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc View File

@@ -410,12 +410,12 @@ Status DatasetNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &siz
return Status::OK(); return Status::OK();
} }
if (children_.size() == 1) { if (children_.size() == 1) {
return children_[0]->GetDatasetSize(size_getter, estimate, dataset_size);
return children_.front()->GetDatasetSize(size_getter, estimate, dataset_size);
} else if (children_.size() > 1) { } else if (children_.size() > 1) {
// It is okay for dataset to have more than 1 child, GetDatasetSize shouldn't fail in this case. // It is okay for dataset to have more than 1 child, GetDatasetSize shouldn't fail in this case.
// This is done mostly for cache, which injects cache lookup/merge operators. Cache path will // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
// always be in front of the child_ structure, so we get the dataset size from the last child. // always be in front of the child_ structure, so we get the dataset size from the last child.
return children_[children_.size() - 1]->GetDatasetSize(size_getter, estimate, dataset_size);
return children_.back()->GetDatasetSize(size_getter, estimate, dataset_size);
} else { } else {
RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override"); RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override");
} }


+ 3
- 1
mindspore/ccsrc/minddata/dataset/include/datasets.h View File

@@ -141,8 +141,10 @@ class Dataset : public std::enable_shared_from_this<Dataset> {


/// \brief Function to create an Iterator over the Dataset pipeline /// \brief Function to create an Iterator over the Dataset pipeline
/// \param[in] columns List of columns to be used to specify the order of columns /// \param[in] columns List of columns to be used to specify the order of columns
/// \param[in] num_epochs Number of epochs to run through the pipeline, default -1 which means infinite epochs.
/// An empty row is returned at the end of each epoch
/// \return Shared pointer to the Iterator /// \return Shared pointer to the Iterator
std::shared_ptr<Iterator> CreateIterator(std::vector<std::string> columns = {});
std::shared_ptr<Iterator> CreateIterator(std::vector<std::string> columns = {}, int32_t num_epochs = -1);


#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
/// \brief Function to transfer data through a device. /// \brief Function to transfer data through a device.


+ 2
- 1
mindspore/ccsrc/minddata/dataset/include/iterator.h View File

@@ -51,8 +51,9 @@ class Iterator {


/// \brief Method for building and launching the pipeline. /// \brief Method for building and launching the pipeline.
/// \param[in] ops - a vector of DatasetOp in the data pipeline. /// \param[in] ops - a vector of DatasetOp in the data pipeline.
/// \param[in] num_epochs Number of epochs passed down to EpochCtrlNode; -1 means infinite epochs (this method declares no default value)
/// \return - a Status error code, returns OK if no error encountered. /// \return - a Status error code, returns OK if no error encountered.
Status BuildAndLaunchTree(std::shared_ptr<Dataset> ds);
Status BuildAndLaunchTree(std::shared_ptr<Dataset> ds, int32_t num_epochs);


/// \brief Function to get the next row from the data pipeline. /// \brief Function to get the next row from the data pipeline.
/// \note Type of return data is a map(with column name). /// \note Type of return data is a map(with column name).


+ 44
- 1
tests/ut/cpp/dataset/c_api_dataset_iterator_test.cc View File

@@ -76,7 +76,7 @@ TEST_F(MindDataTestPipeline, TestIteratorOneColumn) {
// Create an iterator over the result of the above dataset // Create an iterator over the result of the above dataset
// Only select "image" column and drop others // Only select "image" column and drop others
std::vector<std::string> columns = {"image"}; std::vector<std::string> columns = {"image"};
std::shared_ptr<Iterator> iter = ds->CreateIterator(columns);
std::shared_ptr<Iterator> iter = ds->CreateIterator(columns, -1);
EXPECT_NE(iter, nullptr); EXPECT_NE(iter, nullptr);


// Iterate the dataset and get each row // Iterate the dataset and get each row
@@ -195,3 +195,46 @@ TEST_F(MindDataTestPipeline, TestIteratorWrongColumn) {
std::shared_ptr<Iterator> iter = ds->CreateIterator(columns); std::shared_ptr<Iterator> iter = ds->CreateIterator(columns);
EXPECT_EQ(iter, nullptr); EXPECT_EQ(iter, nullptr);
} }

// Checks that an iterator created with a finite num_epochs hands back every row of every epoch,
// separates epochs with an empty row, and reports failure once all epochs have been consumed.
TEST_F(MindDataTestPipeline, TestIteratorNumEpoch) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorNumEpoch.";

  int32_t rows_per_epoch = 2;
  int32_t num_epochs = 3;

  // Build a RandomData pipeline with a single uint8 "image" column
  std::shared_ptr<SchemaObj> schema = Schema();
  ASSERT_OK(schema->add_column("image", mindspore::TypeId::kNumberTypeUInt8, {2}));
  std::shared_ptr<Dataset> ds = RandomData(rows_per_epoch, schema)->SetNumWorkers(1);

  // Request all columns (empty list) for exactly num_epochs epochs
  std::shared_ptr<Iterator> iter = ds->CreateIterator({}, num_epochs);
  ASSERT_NE(iter, nullptr);  // nothing below can run without a valid iterator

  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  int32_t rows_seen_total = 0;
  for (int32_t epoch = 0; epoch < num_epochs; ++epoch) {
    int32_t rows_seen_in_epoch = 0;
    // Prime the loop with the first row of this epoch
    ASSERT_TRUE(iter->GetNextRow(&row));
    // An empty row marks the end of the current epoch
    while (!row.empty()) {
      ASSERT_TRUE(iter->GetNextRow(&row));
      ++rows_seen_in_epoch;
      ++rows_seen_total;
    }
    EXPECT_EQ(rows_seen_in_epoch, rows_per_epoch);
  }
  EXPECT_EQ(rows_seen_total, rows_per_epoch * num_epochs);

  // All requested epochs are exhausted, so one more fetch is expected to fail
  EXPECT_FALSE(iter->GetNextRow(&row));

  // Manually terminate the pipeline
  iter->Stop();
}

// Checks that CreateIterator rejects unsupported num_epochs values (0 and -2) by returning nullptr.
TEST_F(MindDataTestPipeline, TestIteratorNumEpochFail) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorNumEpochFail.";

  // Build a RandomData pipeline with a single uint8 "image" column
  std::shared_ptr<SchemaObj> schema = Schema();
  ASSERT_OK(schema->add_column("image", mindspore::TypeId::kNumberTypeUInt8, {2}));
  std::shared_ptr<Dataset> ds = RandomData(3, schema)->SetNumWorkers(1);

  // Each of these num_epochs values is invalid, so no iterator should be produced
  const int32_t invalid_num_epochs[] = {0, -2};
  for (const int32_t bad_value : invalid_num_epochs) {
    EXPECT_EQ(ds->CreateIterator({}, bad_value), nullptr);
  }
}

Loading…
Cancel
Save