Merge pull request !3342 from cathwong/ckw_c_api_skiptags/v0.7.0-beta
| @@ -27,6 +27,7 @@ | |||||
| #include "minddata/dataset/engine/datasetops/map_op.h" | #include "minddata/dataset/engine/datasetops/map_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/repeat_op.h" | #include "minddata/dataset/engine/datasetops/repeat_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/shuffle_op.h" | #include "minddata/dataset/engine/datasetops/shuffle_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/skip_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/project_op.h" | #include "minddata/dataset/engine/datasetops/project_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/zip_op.h" | #include "minddata/dataset/engine/datasetops/zip_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | ||||
| @@ -173,6 +174,20 @@ std::shared_ptr<ShuffleDataset> Dataset::Shuffle(int32_t shuffle_size) { | |||||
| return ds; | return ds; | ||||
| } | } | ||||
| // Function to create a SkipDataset. | |||||
| std::shared_ptr<SkipDataset> Dataset::Skip(int32_t count) { | |||||
| auto ds = std::make_shared<SkipDataset>(count); | |||||
| // Call derived class validation method. | |||||
| if (!ds->ValidateParams()) { | |||||
| return nullptr; | |||||
| } | |||||
| ds->children.push_back(shared_from_this()); | |||||
| return ds; | |||||
| } | |||||
| // Function to create a ProjectDataset. | // Function to create a ProjectDataset. | ||||
| std::shared_ptr<ProjectDataset> Dataset::Project(const std::vector<std::string> &columns) { | std::shared_ptr<ProjectDataset> Dataset::Project(const std::vector<std::string> &columns) { | ||||
| auto ds = std::make_shared<ProjectDataset>(columns); | auto ds = std::make_shared<ProjectDataset>(columns); | ||||
| @@ -400,6 +415,28 @@ bool ShuffleDataset::ValidateParams() { | |||||
| return true; | return true; | ||||
| } | } | ||||
| // Constructor for SkipDataset | |||||
| SkipDataset::SkipDataset(int32_t count) : skip_count_(count) {} | |||||
| // Function to build the SkipOp | |||||
| std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> SkipDataset::Build() { | |||||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||||
| node_ops.push_back(std::make_shared<SkipOp>(skip_count_, connector_que_size_)); | |||||
| return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops); | |||||
| } | |||||
| // Function to validate the parameters for SkipDataset | |||||
| bool SkipDataset::ValidateParams() { | |||||
| if (skip_count_ <= -1) { | |||||
| MS_LOG(ERROR) << "Skip: Invalid input, skip_count: " << skip_count_; | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| // Constructor for Cifar10Dataset | // Constructor for Cifar10Dataset | ||||
| Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr<SamplerObj> sampler) | Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr<SamplerObj> sampler) | ||||
| : dataset_dir_(dataset_dir), num_samples_(num_samples), sampler_(sampler) {} | : dataset_dir_(dataset_dir), num_samples_(num_samples), sampler_(sampler) {} | ||||
| @@ -46,6 +46,7 @@ class BatchDataset; | |||||
| class RepeatDataset; | class RepeatDataset; | ||||
| class MapDataset; | class MapDataset; | ||||
| class ShuffleDataset; | class ShuffleDataset; | ||||
| class SkipDataset; | |||||
| class Cifar10Dataset; | class Cifar10Dataset; | ||||
| class ProjectDataset; | class ProjectDataset; | ||||
| class ZipDataset; | class ZipDataset; | ||||
| @@ -160,6 +161,12 @@ class Dataset : public std::enable_shared_from_this<Dataset> { | |||||
| /// \return Shared pointer to the current ShuffleDataset | /// \return Shared pointer to the current ShuffleDataset | ||||
| std::shared_ptr<ShuffleDataset> Shuffle(int32_t shuffle_size); | std::shared_ptr<ShuffleDataset> Shuffle(int32_t shuffle_size); | ||||
| /// \brief Function to create a SkipDataset | |||||
| /// \notes Skips count elements in this dataset. | |||||
| /// \param[in] count Number of elements in the dataset to be skipped. | |||||
| /// \return Shared pointer to the current SkipDataset | |||||
| std::shared_ptr<SkipDataset> Skip(int32_t count); | |||||
| /// \brief Function to create a Project Dataset | /// \brief Function to create a Project Dataset | ||||
| /// \notes Applies project to the dataset | /// \notes Applies project to the dataset | ||||
| /// \param[in] columns The name of columns to project | /// \param[in] columns The name of columns to project | ||||
| @@ -293,6 +300,26 @@ class ShuffleDataset : public Dataset { | |||||
| bool reset_every_epoch_; | bool reset_every_epoch_; | ||||
| }; | }; | ||||
| class SkipDataset : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| explicit SkipDataset(int32_t count); | |||||
| /// \brief Destructor | |||||
| ~SkipDataset() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return shared pointer to the list of newly created DatasetOps | |||||
| std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return bool true if all the params are valid | |||||
| bool ValidateParams() override; | |||||
| private: | |||||
| int32_t skip_count_; | |||||
| }; | |||||
| class MapDataset : public Dataset { | class MapDataset : public Dataset { | ||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| @@ -2094,8 +2094,8 @@ class SkipDataset(DatasetOp): | |||||
| The result of applying Skip operator to the input Dataset. | The result of applying Skip operator to the input Dataset. | ||||
| Args: | Args: | ||||
| input_dataset (tuple): A tuple of datasets to be skipped. | |||||
| count (int): Number of rows the dataset should be skipped. | |||||
| input_dataset (Dataset): Input dataset to have rows skipped. | |||||
| count (int): Number of rows in the dataset to be skipped. | |||||
| """ | """ | ||||
| def __init__(self, input_dataset, count): | def __init__(self, input_dataset, count): | ||||
| @@ -573,6 +573,59 @@ TEST_F(MindDataTestPipeline, TestShuffleDataset) { | |||||
| iter->Stop(); | iter->Stop(); | ||||
| } | } | ||||
| TEST_F(MindDataTestPipeline, TestSkipDataset) { | |||||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipDataset."; | |||||
| // Create an ImageFolder Dataset | |||||
| std::string folder_path = datasets_root_path_ + "/testPK/data/"; | |||||
| std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); | |||||
| EXPECT_TRUE(ds != nullptr); | |||||
| // Create a Skip operation on ds | |||||
| int32_t count = 3; | |||||
| ds = ds->Skip(count); | |||||
| EXPECT_TRUE(ds != nullptr); | |||||
| // Create an iterator over the result of the above dataset | |||||
| // This will trigger the creation of the Execution Tree and launch it. | |||||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||||
| EXPECT_TRUE(iter != nullptr); | |||||
| // Iterate the dataset and get each row | |||||
| std::unordered_map<std::string, std::shared_ptr<Tensor>> row; | |||||
| iter->GetNextRow(&row); | |||||
| uint64_t i = 0; | |||||
| while (row.size() != 0) { | |||||
| i++; | |||||
| auto image = row["image"]; | |||||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | |||||
| iter->GetNextRow(&row); | |||||
| } | |||||
| MS_LOG(INFO) << "Number of rows: " << i; | |||||
| // Expect 10-3=7 rows | |||||
| EXPECT_TRUE(i == 7); | |||||
| // Manually terminate the pipeline | |||||
| iter->Stop(); | |||||
| } | |||||
| TEST_F(MindDataTestPipeline, TestSkipDatasetError1) { | |||||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipDatasetError1."; | |||||
| // Create an ImageFolder Dataset | |||||
| std::string folder_path = datasets_root_path_ + "/testPK/data/"; | |||||
| std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); | |||||
| EXPECT_TRUE(ds != nullptr); | |||||
| // Create a Skip operation on ds with invalid count input | |||||
| int32_t count = -1; | |||||
| ds = ds->Skip(count); | |||||
| // Expect nullptr for invalid input skip_count | |||||
| EXPECT_TRUE(ds == nullptr); | |||||
| } | |||||
| TEST_F(MindDataTestPipeline, TestCifar10Dataset) { | TEST_F(MindDataTestPipeline, TestCifar10Dataset) { | ||||
| // Create a Cifar10 Dataset | // Create a Cifar10 Dataset | ||||
| @@ -13,9 +13,12 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| import numpy as np | import numpy as np | ||||
| import pytest | |||||
| import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
| import mindspore.dataset.transforms.vision.c_transforms as vision | import mindspore.dataset.transforms.vision.c_transforms as vision | ||||
| from mindspore import log as logger | |||||
| DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | ||||
| SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | ||||
| @@ -196,6 +199,29 @@ def test_skip_filter_2(): | |||||
| assert buf == [5, 6, 7, 8, 9, 10] | assert buf == [5, 6, 7, 8, 9, 10] | ||||
def test_skip_exception_1():
    """skip(count=-1) must fail with a RuntimeError once the pipeline is iterated.

    pytest.raises makes the test fail when no exception is raised at all; a bare
    try/except would let that case pass without checking anything.
    """
    data1 = ds.GeneratorDataset(generator_md, ["data"])

    with pytest.raises(RuntimeError) as e:
        data1 = data1.skip(count=-1)
        # Drain the iterator so the pipeline is actually built and executed.
        for _ in data1.create_dict_iterator():
            pass

    logger.info("Got an exception in DE: {}".format(str(e.value)))
    assert "Skip count must be positive integer or 0." in str(e.value)
def test_skip_exception_2():
    """skip(-2) is expected to raise a ValueError about the count interval."""
    dataset = ds.GeneratorDataset(generator_md, ["data"])

    with pytest.raises(ValueError) as exc_info:
        dataset = dataset.skip(-2)

    error_text = str(exc_info.value)
    assert "Input count is not within the required interval" in error_text
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| test_tf_skip() | test_tf_skip() | ||||
| test_generator_skip() | test_generator_skip() | ||||
| @@ -208,3 +234,5 @@ if __name__ == "__main__": | |||||
| test_skip_take_2() | test_skip_take_2() | ||||
| test_skip_filter_1() | test_skip_filter_1() | ||||
| test_skip_filter_2() | test_skip_filter_2() | ||||
| test_skip_exception_1() | |||||
| test_skip_exception_2() | |||||