Merge pull request !3342 from cathwong/ckw_c_api_skiptags/v0.7.0-beta
| @@ -27,6 +27,7 @@ | |||
| #include "minddata/dataset/engine/datasetops/map_op.h" | |||
| #include "minddata/dataset/engine/datasetops/repeat_op.h" | |||
| #include "minddata/dataset/engine/datasetops/shuffle_op.h" | |||
| #include "minddata/dataset/engine/datasetops/skip_op.h" | |||
| #include "minddata/dataset/engine/datasetops/project_op.h" | |||
| #include "minddata/dataset/engine/datasetops/zip_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| @@ -173,6 +174,20 @@ std::shared_ptr<ShuffleDataset> Dataset::Shuffle(int32_t shuffle_size) { | |||
| return ds; | |||
| } | |||
| // Function to create a SkipDataset. | |||
| std::shared_ptr<SkipDataset> Dataset::Skip(int32_t count) { | |||
| auto ds = std::make_shared<SkipDataset>(count); | |||
| // Call derived class validation method. | |||
| if (!ds->ValidateParams()) { | |||
| return nullptr; | |||
| } | |||
| ds->children.push_back(shared_from_this()); | |||
| return ds; | |||
| } | |||
| // Function to create a ProjectDataset. | |||
| std::shared_ptr<ProjectDataset> Dataset::Project(const std::vector<std::string> &columns) { | |||
| auto ds = std::make_shared<ProjectDataset>(columns); | |||
| @@ -400,6 +415,28 @@ bool ShuffleDataset::ValidateParams() { | |||
| return true; | |||
| } | |||
| // Constructor for SkipDataset | |||
| SkipDataset::SkipDataset(int32_t count) : skip_count_(count) {} | |||
| // Function to build the SkipOp | |||
| std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> SkipDataset::Build() { | |||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||
| node_ops.push_back(std::make_shared<SkipOp>(skip_count_, connector_que_size_)); | |||
| return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops); | |||
| } | |||
| // Function to validate the parameters for SkipDataset | |||
| bool SkipDataset::ValidateParams() { | |||
| if (skip_count_ <= -1) { | |||
| MS_LOG(ERROR) << "Skip: Invalid input, skip_count: " << skip_count_; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| // Constructor for Cifar10Dataset | |||
| Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr<SamplerObj> sampler) | |||
| : dataset_dir_(dataset_dir), num_samples_(num_samples), sampler_(sampler) {} | |||
| @@ -46,6 +46,7 @@ class BatchDataset; | |||
| class RepeatDataset; | |||
| class MapDataset; | |||
| class ShuffleDataset; | |||
| class SkipDataset; | |||
| class Cifar10Dataset; | |||
| class ProjectDataset; | |||
| class ZipDataset; | |||
| @@ -160,6 +161,12 @@ class Dataset : public std::enable_shared_from_this<Dataset> { | |||
| /// \return Shared pointer to the current ShuffleDataset | |||
| std::shared_ptr<ShuffleDataset> Shuffle(int32_t shuffle_size); | |||
| /// \brief Function to create a SkipDataset | |||
| /// \notes Skips count elements in this dataset. | |||
| /// \param[in] count Number of elements the dataset to be skipped. | |||
| /// \return Shared pointer to the current SkipDataset | |||
| std::shared_ptr<SkipDataset> Skip(int32_t count); | |||
| /// \brief Function to create a Project Dataset | |||
| /// \notes Applies project to the dataset | |||
| /// \param[in] columns The name of columns to project | |||
| @@ -293,6 +300,26 @@ class ShuffleDataset : public Dataset { | |||
| bool reset_every_epoch_; | |||
| }; | |||
| class SkipDataset : public Dataset { | |||
| public: | |||
| /// \brief Constructor | |||
| explicit SkipDataset(int32_t count); | |||
| /// \brief Destructor | |||
| ~SkipDataset() = default; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return shared pointer to the list of newly created DatasetOps | |||
| std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override; | |||
| /// \brief Parameters validation | |||
| /// \return bool true if all the params are valid | |||
| bool ValidateParams() override; | |||
| private: | |||
| int32_t skip_count_; | |||
| }; | |||
| class MapDataset : public Dataset { | |||
| public: | |||
| /// \brief Constructor | |||
| @@ -2094,8 +2094,8 @@ class SkipDataset(DatasetOp): | |||
| The result of applying Skip operator to the input Dataset. | |||
| Args: | |||
| input_dataset (tuple): A tuple of datasets to be skipped. | |||
| count (int): Number of rows the dataset should be skipped. | |||
| input_dataset (Dataset): Input dataset to have rows skipped. | |||
| count (int): Number of rows in the dataset to be skipped. | |||
| """ | |||
| def __init__(self, input_dataset, count): | |||
| @@ -573,6 +573,59 @@ TEST_F(MindDataTestPipeline, TestShuffleDataset) { | |||
| iter->Stop(); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestSkipDataset) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipDataset."; | |||
| // Create an ImageFolder Dataset | |||
| std::string folder_path = datasets_root_path_ + "/testPK/data/"; | |||
| std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); | |||
| EXPECT_TRUE(ds != nullptr); | |||
| // Create a Skip operation on ds | |||
| int32_t count = 3; | |||
| ds = ds->Skip(count); | |||
| EXPECT_TRUE(ds != nullptr); | |||
| // Create an iterator over the result of the above dataset | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_TRUE(iter != nullptr); | |||
| // Iterate the dataset and get each row | |||
| std::unordered_map<std::string, std::shared_ptr<Tensor>> row; | |||
| iter->GetNextRow(&row); | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| i++; | |||
| auto image = row["image"]; | |||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | |||
| iter->GetNextRow(&row); | |||
| } | |||
| MS_LOG(INFO) << "Number of rows: " << i; | |||
| // Expect 10-3=7 rows | |||
| EXPECT_TRUE(i == 7); | |||
| // Manually terminate the pipeline | |||
| iter->Stop(); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestSkipDatasetError1) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipDatasetError1."; | |||
| // Create an ImageFolder Dataset | |||
| std::string folder_path = datasets_root_path_ + "/testPK/data/"; | |||
| std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); | |||
| EXPECT_TRUE(ds != nullptr); | |||
| // Create a Skip operation on ds with invalid count input | |||
| int32_t count = -1; | |||
| ds = ds->Skip(count); | |||
| // Expect nullptr for invalid input skip_count | |||
| EXPECT_TRUE(ds == nullptr); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestCifar10Dataset) { | |||
| // Create a Cifar10 Dataset | |||
| @@ -13,9 +13,12 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as vision | |||
| from mindspore import log as logger | |||
| DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | |||
| SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | |||
| @@ -196,6 +199,29 @@ def test_skip_filter_2(): | |||
| assert buf == [5, 6, 7, 8, 9, 10] | |||
| def test_skip_exception_1(): | |||
| data1 = ds.GeneratorDataset(generator_md, ["data"]) | |||
| try: | |||
| data1 = data1.skip(count=-1) | |||
| num_iter = 0 | |||
| for _ in data1.create_dict_iterator(): | |||
| num_iter += 1 | |||
| except RuntimeError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Skip count must be positive integer or 0." in str(e) | |||
| def test_skip_exception_2(): | |||
| ds1 = ds.GeneratorDataset(generator_md, ["data"]) | |||
| with pytest.raises(ValueError) as e: | |||
| ds1 = ds1.skip(-2) | |||
| assert "Input count is not within the required interval" in str(e.value) | |||
| if __name__ == "__main__": | |||
| test_tf_skip() | |||
| test_generator_skip() | |||
| @@ -208,3 +234,5 @@ if __name__ == "__main__": | |||
| test_skip_take_2() | |||
| test_skip_filter_1() | |||
| test_skip_filter_2() | |||
| test_skip_exception_1() | |||
| test_skip_exception_2() | |||