From ce419404cef7a76ccf282262276f0d6a3bd0aa06 Mon Sep 17 00:00:00 2001 From: luoyang Date: Tue, 1 Sep 2020 11:32:55 +0800 Subject: [PATCH] Add check for duplicate column for c-api --- .../ccsrc/minddata/dataset/api/datasets.cc | 21 ++++++++++++-- .../ccsrc/minddata/dataset/include/datasets.h | 2 +- .../ut/cpp/dataset/c_api_dataset_ops_test.cc | 28 +++++++++++++++++-- 3 files changed, 46 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index 524f6cab68..2a31f30015 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -726,7 +726,7 @@ bool ValidateDatasetSampler(const std::string &dataset_name, const std::shared_p bool ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param, const std::vector &columns) { if (columns.empty()) { - MS_LOG(ERROR) << dataset_name << ":" << column_param << " should not be empty"; + MS_LOG(ERROR) << dataset_name << ":" << column_param << " should not be empty string"; return false; } for (uint32_t i = 0; i < columns.size(); ++i) { @@ -1205,6 +1205,11 @@ bool CSVDataset::ValidateParams() { return false; } + if (find(column_defaults_.begin(), column_defaults_.end(), nullptr) != column_defaults_.end()) { + MS_LOG(ERROR) << "CSVDataset: column_default should not be null."; + return false; + } + if (!column_names_.empty()) { if (!ValidateDatasetColumnParam("CSVDataset", "column_names", column_names_)) { return false; @@ -1723,6 +1728,11 @@ bool BuildVocabDataset::ValidateParams() { << "but got [" << freq_range_.first << ", " << freq_range_.second << "]"; return false; } + if (!columns_.empty()) { + if (!ValidateDatasetColumnParam("BuildVocab", "columns", columns_)) { + return false; + } + } return true; } #endif @@ -1811,7 +1821,10 @@ ProjectDataset::ProjectDataset(const std::vector &columns) : column bool ProjectDataset::ValidateParams() { if 
(columns_.empty()) { - MS_LOG(ERROR) << "No columns are specified."; + MS_LOG(ERROR) << "ProjectDataset: No columns are specified."; + return false; + } + if (!ValidateDatasetColumnParam("ProjectDataset", "columns", columns_)) { return false; } return true; @@ -1949,6 +1962,10 @@ bool ZipDataset::ValidateParams() { MS_LOG(ERROR) << "Zip: dataset to zip are not specified."; return false; } + if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) { + MS_LOG(ERROR) << "ZipDataset: zip dataset should not be null."; + return false; + } return true; } diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h index b1dee15ec8..de507ac8ba 100644 --- a/mindspore/ccsrc/minddata/dataset/include/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h @@ -329,7 +329,7 @@ std::shared_ptr TextFile(const std::vector &datase /// \param[in] dataset_dir Path to the root directory that contains the dataset /// \param[in] task Set the task type of reading voc data, now only support "Segmentation" or "Detection" /// \param[in] mode Set the data list txt file to be readed -/// \param[in] class_indexing A str-to-int mapping from label name to index +/// \param[in] class_indexing A str-to-int mapping from label name to index, only valid in "Detection" task /// \param[in] decode Decode the images after reading /// \param[in] sampler Object used to choose samples from the dataset. 
If sampler is not given, /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) diff --git a/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc b/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc index 57806837e2..18e0fe31cd 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc @@ -270,7 +270,7 @@ TEST_F(MindDataTestPipeline, TestConcatFail1) { ds2 = ds2->Rename({"image", "label"}, {"col1", "col2"}); EXPECT_NE(ds, nullptr); - // Create a Project operation on the ds + // Create a Concat operation on the ds // Name of datasets to concat doesn't not match ds = ds->Concat({ds2}); EXPECT_NE(ds, nullptr); @@ -295,7 +295,7 @@ TEST_F(MindDataTestPipeline, TestConcatFail2) { std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); EXPECT_NE(ds, nullptr); - // Create a Project operation on the ds + // Create a Concat operation on the ds // Input dataset to concat is empty ds = ds->Concat({}); EXPECT_EQ(ds, nullptr); @@ -499,6 +499,30 @@ TEST_F(MindDataTestPipeline, TestProjectMap) { iter->Stop(); } +TEST_F(MindDataTestPipeline, TestProjectDuplicateColumn) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestProjectDuplicateColumn."; + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 3)); + EXPECT_NE(ds, nullptr); + + // Create objects for the tensor ops + std::shared_ptr random_vertical_flip_op = vision::RandomVerticalFlip(0.5); + EXPECT_NE(random_vertical_flip_op, nullptr); + + // Create a Map operation on ds + ds = ds->Map({random_vertical_flip_op}, {}, {}, {"image", "label"}); + EXPECT_NE(ds, nullptr); + + // Create a Project operation on ds + std::vector column_project = {"image", "image"}; + + // Expect failure: duplicate project column name + ds = ds->Project(column_project); + EXPECT_EQ(ds, nullptr); +} + 
TEST_F(MindDataTestPipeline, TestMapDuplicateColumn) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMapDuplicateColumn.";