Browse Source

!3342 C++ API Support for Skip Dataset Op plus UTs

Merge pull request !3342 from cathwong/ckw_c_api_skip
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
859fe6bc41
5 changed files with 147 additions and 2 deletions
  1. +37
    -0
      mindspore/ccsrc/minddata/dataset/api/datasets.cc
  2. +27
    -0
      mindspore/ccsrc/minddata/dataset/include/datasets.h
  3. +2
    -2
      mindspore/dataset/engine/datasets.py
  4. +53
    -0
      tests/ut/cpp/dataset/c_api_test.cc
  5. +28
    -0
      tests/ut/python/dataset/test_skip.py

+ 37
- 0
mindspore/ccsrc/minddata/dataset/api/datasets.cc View File

@@ -27,6 +27,7 @@
#include "minddata/dataset/engine/datasetops/map_op.h" #include "minddata/dataset/engine/datasetops/map_op.h"
#include "minddata/dataset/engine/datasetops/repeat_op.h" #include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/engine/datasetops/shuffle_op.h" #include "minddata/dataset/engine/datasetops/shuffle_op.h"
#include "minddata/dataset/engine/datasetops/skip_op.h"
#include "minddata/dataset/engine/datasetops/project_op.h" #include "minddata/dataset/engine/datasetops/project_op.h"
#include "minddata/dataset/engine/datasetops/zip_op.h" #include "minddata/dataset/engine/datasetops/zip_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
@@ -173,6 +174,20 @@ std::shared_ptr<ShuffleDataset> Dataset::Shuffle(int32_t shuffle_size) {
return ds; return ds;
} }


// Function to create a SkipDataset.
std::shared_ptr<SkipDataset> Dataset::Skip(int32_t count) {
auto ds = std::make_shared<SkipDataset>(count);

// Call derived class validation method.
if (!ds->ValidateParams()) {
return nullptr;
}

ds->children.push_back(shared_from_this());

return ds;
}

// Function to create a ProjectDataset. // Function to create a ProjectDataset.
std::shared_ptr<ProjectDataset> Dataset::Project(const std::vector<std::string> &columns) { std::shared_ptr<ProjectDataset> Dataset::Project(const std::vector<std::string> &columns) {
auto ds = std::make_shared<ProjectDataset>(columns); auto ds = std::make_shared<ProjectDataset>(columns);
@@ -400,6 +415,28 @@ bool ShuffleDataset::ValidateParams() {
return true; return true;
} }


// Constructor for SkipDataset
SkipDataset::SkipDataset(int32_t count) : skip_count_(count) {}

// Function to build the SkipOp
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> SkipDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;

node_ops.push_back(std::make_shared<SkipOp>(skip_count_, connector_que_size_));
return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
}

// Function to validate the parameters for SkipDataset
bool SkipDataset::ValidateParams() {
if (skip_count_ <= -1) {
MS_LOG(ERROR) << "Skip: Invalid input, skip_count: " << skip_count_;
return false;
}

return true;
}

// Constructor for Cifar10Dataset // Constructor for Cifar10Dataset
Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr<SamplerObj> sampler) Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr<SamplerObj> sampler)
: dataset_dir_(dataset_dir), num_samples_(num_samples), sampler_(sampler) {} : dataset_dir_(dataset_dir), num_samples_(num_samples), sampler_(sampler) {}


+ 27
- 0
mindspore/ccsrc/minddata/dataset/include/datasets.h View File

@@ -46,6 +46,7 @@ class BatchDataset;
class RepeatDataset; class RepeatDataset;
class MapDataset; class MapDataset;
class ShuffleDataset; class ShuffleDataset;
class SkipDataset;
class Cifar10Dataset; class Cifar10Dataset;
class ProjectDataset; class ProjectDataset;
class ZipDataset; class ZipDataset;
@@ -160,6 +161,12 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \return Shared pointer to the current ShuffleDataset /// \return Shared pointer to the current ShuffleDataset
std::shared_ptr<ShuffleDataset> Shuffle(int32_t shuffle_size); std::shared_ptr<ShuffleDataset> Shuffle(int32_t shuffle_size);


/// \brief Function to create a SkipDataset
/// \notes Skips count elements in this dataset.
/// \param[in] count Number of elements the dataset to be skipped.
/// \return Shared pointer to the current SkipDataset
std::shared_ptr<SkipDataset> Skip(int32_t count);

/// \brief Function to create a Project Dataset /// \brief Function to create a Project Dataset
/// \notes Applies project to the dataset /// \notes Applies project to the dataset
/// \param[in] columns The name of columns to project /// \param[in] columns The name of columns to project
@@ -293,6 +300,26 @@ class ShuffleDataset : public Dataset {
bool reset_every_epoch_; bool reset_every_epoch_;
}; };


class SkipDataset : public Dataset {
public:
/// \brief Constructor
explicit SkipDataset(int32_t count);

/// \brief Destructor
~SkipDataset() = default;

/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override;

/// \brief Parameters validation
/// \return bool true if all the params are valid
bool ValidateParams() override;

private:
int32_t skip_count_;
};

class MapDataset : public Dataset { class MapDataset : public Dataset {
public: public:
/// \brief Constructor /// \brief Constructor


+ 2
- 2
mindspore/dataset/engine/datasets.py View File

@@ -2094,8 +2094,8 @@ class SkipDataset(DatasetOp):
The result of applying Skip operator to the input Dataset. The result of applying Skip operator to the input Dataset.


Args: Args:
input_dataset (tuple): A tuple of datasets to be skipped.
count (int): Number of rows the dataset should be skipped.
input_dataset (Dataset): Input dataset to have rows skipped.
count (int): Number of rows in the dataset to be skipped.
""" """


def __init__(self, input_dataset, count): def __init__(self, input_dataset, count):


+ 53
- 0
tests/ut/cpp/dataset/c_api_test.cc View File

@@ -573,6 +573,59 @@ TEST_F(MindDataTestPipeline, TestShuffleDataset) {
iter->Stop(); iter->Stop();
} }


TEST_F(MindDataTestPipeline, TestSkipDataset) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipDataset.";

// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
EXPECT_TRUE(ds != nullptr);

// Create a Skip operation on ds
int32_t count = 3;
ds = ds->Skip(count);
EXPECT_TRUE(ds != nullptr);

// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_TRUE(iter != nullptr);

// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);

uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
}
MS_LOG(INFO) << "Number of rows: " << i;

// Expect 10-3=7 rows
EXPECT_TRUE(i == 7);

// Manually terminate the pipeline
iter->Stop();
}

TEST_F(MindDataTestPipeline, TestSkipDatasetError1) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipDatasetError1.";

// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
EXPECT_TRUE(ds != nullptr);

// Create a Skip operation on ds with invalid count input
int32_t count = -1;
ds = ds->Skip(count);
// Expect nullptr for invalid input skip_count
EXPECT_TRUE(ds == nullptr);
}

TEST_F(MindDataTestPipeline, TestCifar10Dataset) { TEST_F(MindDataTestPipeline, TestCifar10Dataset) {


// Create a Cifar10 Dataset // Create a Cifar10 Dataset


+ 28
- 0
tests/ut/python/dataset/test_skip.py View File

@@ -13,9 +13,12 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
import numpy as np import numpy as np
import pytest


import mindspore.dataset as ds import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger



DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json" SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
@@ -196,6 +199,29 @@ def test_skip_filter_2():
assert buf == [5, 6, 7, 8, 9, 10] assert buf == [5, 6, 7, 8, 9, 10]




def test_skip_exception_1():
data1 = ds.GeneratorDataset(generator_md, ["data"])

try:
data1 = data1.skip(count=-1)
num_iter = 0
for _ in data1.create_dict_iterator():
num_iter += 1

except RuntimeError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Skip count must be positive integer or 0." in str(e)


def test_skip_exception_2():
ds1 = ds.GeneratorDataset(generator_md, ["data"])

with pytest.raises(ValueError) as e:
ds1 = ds1.skip(-2)
assert "Input count is not within the required interval" in str(e.value)



if __name__ == "__main__": if __name__ == "__main__":
test_tf_skip() test_tf_skip()
test_generator_skip() test_generator_skip()
@@ -208,3 +234,5 @@ if __name__ == "__main__":
test_skip_take_2() test_skip_take_2()
test_skip_filter_1() test_skip_filter_1()
test_skip_filter_2() test_skip_filter_2()
test_skip_exception_1()
test_skip_exception_2()

Loading…
Cancel
Save