Browse Source

[MD] Skip returning IR when meet nullptr in ConcatDataset

tags/v1.1.0
luoyang 5 years ago
parent
commit
dcb758bc77
2 changed files with 44 additions and 3 deletions
  1. +4
    -3
      mindspore/ccsrc/minddata/dataset/api/datasets.cc
  2. +40
    -0
      tests/ut/cpp/dataset/c_api_dataset_ops_test.cc

+ 4
- 3
mindspore/ccsrc/minddata/dataset/api/datasets.cc View File

@@ -478,9 +478,10 @@ BucketBatchByLengthDataset::BucketBatchByLengthDataset(


ConcatDataset::ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) { ConcatDataset::ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
std::vector<std::shared_ptr<DatasetNode>> all_datasets; std::vector<std::shared_ptr<DatasetNode>> all_datasets;
(void)std::transform(
datasets.begin(), datasets.end(), std::back_inserter(all_datasets),
[](std::shared_ptr<Dataset> dataset) -> std::shared_ptr<DatasetNode> { return dataset->IRNode(); });
(void)std::transform(datasets.begin(), datasets.end(), std::back_inserter(all_datasets),
[](std::shared_ptr<Dataset> dataset) -> std::shared_ptr<DatasetNode> {
return (dataset != nullptr) ? dataset->IRNode() : nullptr;
});


auto ds = std::make_shared<ConcatNode>(all_datasets); auto ds = std::make_shared<ConcatNode>(all_datasets);




+ 40
- 0
tests/ut/cpp/dataset/c_api_dataset_ops_test.cc View File

@@ -340,6 +340,46 @@ TEST_F(MindDataTestPipeline, TestConcatFail2) {
EXPECT_EQ(iter, nullptr); EXPECT_EQ(iter, nullptr);
} }


TEST_F(MindDataTestPipeline, TestConcatFail3) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestConcatFail3.";
// This case is expected to fail because the input dataset is nullptr.

// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);

// Create a Concat operation on the ds
// Input dataset to concat is null
ds = ds->Concat({nullptr});
EXPECT_NE(ds, nullptr);

// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: invalid Op input
EXPECT_EQ(iter, nullptr);
}

TEST_F(MindDataTestPipeline, TestConcatFail4) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestConcatFail4.";
// This case is expected to fail because the input dataset is nullptr.

// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);

// Create a Concat operation on the ds
// Input dataset to concat is null
ds = ds + nullptr;
EXPECT_NE(ds, nullptr);

// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: invalid Op input
EXPECT_EQ(iter, nullptr);
}

TEST_F(MindDataTestPipeline, TestConcatSuccess) { TEST_F(MindDataTestPipeline, TestConcatSuccess) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestConcatSuccess."; MS_LOG(INFO) << "Doing MindDataTestPipeline-TestConcatSuccess.";




Loading…
Cancel
Save