Merge pull request !4366 from TinaMengtingZhang/cpp-api-repeat-counttags/v0.7.0-beta
| @@ -1165,7 +1165,7 @@ std::vector<std::shared_ptr<DatasetOp>> RenameDataset::Build() { | |||
| return node_ops; | |||
| } | |||
| RepeatDataset::RepeatDataset(uint32_t count) : repeat_count_(count) {} | |||
| RepeatDataset::RepeatDataset(int32_t count) : repeat_count_(count) {} | |||
| std::vector<std::shared_ptr<DatasetOp>> RepeatDataset::Build() { | |||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||
| @@ -1176,8 +1176,8 @@ std::vector<std::shared_ptr<DatasetOp>> RepeatDataset::Build() { | |||
| } | |||
| bool RepeatDataset::ValidateParams() { | |||
| if (repeat_count_ <= 0) { | |||
| MS_LOG(ERROR) << "Repeat: Repeat count cannot be negative"; | |||
| if (repeat_count_ != -1 && repeat_count_ <= 0) { | |||
| MS_LOG(ERROR) << "Repeat: Repeat count cannot be" << repeat_count_; | |||
| return false; | |||
| } | |||
| @@ -692,7 +692,7 @@ class RenameDataset : public Dataset { | |||
| class RepeatDataset : public Dataset { | |||
| public: | |||
| /// \brief Constructor | |||
| explicit RepeatDataset(uint32_t count); | |||
| explicit RepeatDataset(int32_t count); | |||
| /// \brief Destructor | |||
| ~RepeatDataset() = default; | |||
| @@ -706,7 +706,7 @@ class RepeatDataset : public Dataset { | |||
| bool ValidateParams() override; | |||
| private: | |||
| uint32_t repeat_count_; | |||
| int32_t repeat_count_; | |||
| }; | |||
| class ShuffleDataset : public Dataset { | |||
| @@ -2123,7 +2123,7 @@ class RepeatDataset(DatasetOp): | |||
| Args: | |||
| input_dataset (Dataset): Input Dataset to be repeated. | |||
| count (int): Number of times the dataset should be repeated. | |||
| count (int): Number of times the dataset should be repeated (default=-1, repeat indefinitely). | |||
| """ | |||
| def __init__(self, input_dataset, count): | |||
| @@ -597,7 +597,8 @@ def check_repeat(method): | |||
| type_check(count, (int, type(None)), "repeat") | |||
| if isinstance(count, int): | |||
| check_value(count, (-1, INT32_MAX), "count") | |||
| if (count <= 0 and count != -1) or count > INT32_MAX: | |||
| raise ValueError("count should be either -1 or positive integer.") | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -431,6 +431,101 @@ TEST_F(MindDataTestPipeline, TestRenameSuccess) { | |||
| iter->Stop(); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestRepeatDefault) { | |||
| MS_LOG(INFO)<< "Doing MindDataTestPipeline-TestRepeatDefault."; | |||
| // Create an ImageFolder Dataset | |||
| std::string folder_path = datasets_root_path_ + "/testPK/data/"; | |||
| std::shared_ptr <Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); | |||
| EXPECT_NE(ds,nullptr); | |||
| // Create a Repeat operation on ds | |||
| // Default value of repeat count is -1, expected to repeat infinitely | |||
| ds = ds->Repeat(); | |||
| EXPECT_NE(ds,nullptr); | |||
| // Create a Batch operation on ds | |||
| int32_t batch_size = 1; | |||
| ds = ds->Batch(batch_size); | |||
| EXPECT_NE(ds,nullptr); | |||
| // Create an iterator over the result of the above dataset | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr <Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter,nullptr); | |||
| // Iterate the dataset and get each row | |||
| std::unordered_map <std::string, std::shared_ptr<Tensor>> row; | |||
| iter->GetNextRow(&row); | |||
| uint64_t i = 0; | |||
| while (row.size()!= 0) { | |||
| // manually stop | |||
| if(i==100){break;} | |||
| i++; | |||
| auto image = row["image"]; | |||
| MS_LOG(INFO)<< "Tensor image shape: " << image->shape(); | |||
| iter->GetNextRow(&row); | |||
| } | |||
| EXPECT_EQ(i,100); | |||
| // Manually terminate the pipeline | |||
| iter->Stop(); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestRepeatOne) { | |||
| MS_LOG(INFO)<< "Doing MindDataTestPipeline-TestRepeatOne."; | |||
| // Create an ImageFolder Dataset | |||
| std::string folder_path = datasets_root_path_ + "/testPK/data/"; | |||
| std::shared_ptr <Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); | |||
| EXPECT_NE(ds,nullptr); | |||
| // Create a Repeat operation on ds | |||
| int32_t repeat_num = 1; | |||
| ds = ds->Repeat(repeat_num); | |||
| EXPECT_NE(ds,nullptr); | |||
| // Create a Batch operation on ds | |||
| int32_t batch_size = 1; | |||
| ds = ds->Batch(batch_size); | |||
| EXPECT_NE(ds,nullptr); | |||
| // Create an iterator over the result of the above dataset | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr <Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter,nullptr); | |||
| // Iterate the dataset and get each row | |||
| std::unordered_map <std::string, std::shared_ptr<Tensor>> row; | |||
| iter->GetNextRow(&row); | |||
| uint64_t i = 0; | |||
| while (row.size()!= 0) { | |||
| i++; | |||
| auto image = row["image"]; | |||
| MS_LOG(INFO)<< "Tensor image shape: " << image->shape(); | |||
| iter->GetNextRow(&row); | |||
| } | |||
| EXPECT_EQ(i,10); | |||
| // Manually terminate the pipeline | |||
| iter->Stop(); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestRepeatFail) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRepeatFail."; | |||
| // This case is expected to fail because the repeat count is invalid (<-1 && !=0). | |||
| // Create an ImageFolder Dataset | |||
| std::string folder_path = datasets_root_path_ + "/testPK/data/"; | |||
| std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create a Repeat operation on ds | |||
| int32_t repeat_num = -2; | |||
| ds = ds->Repeat(repeat_num); | |||
| EXPECT_EQ(ds, nullptr); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestShuffleDataset) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestShuffleDataset."; | |||
| @@ -16,7 +16,7 @@ | |||
| Test Repeat Op | |||
| """ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as vision | |||
| from mindspore import log as logger | |||
| @@ -295,6 +295,26 @@ def test_repeat_count2(): | |||
| assert data1_size == 3 | |||
| assert dataset_size == num1_iter == 8 | |||
def test_repeat_count0():
    """
    Test Repeat with invalid count 0.

    repeat(0) is expected to raise a ValueError whose message mentions "count".
    """
    logger.info("Test Repeat with invalid count 0")
    # Build the dataset outside pytest.raises so that only repeat() itself can
    # satisfy the expected ValueError.
    data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
    with pytest.raises(ValueError) as info:
        data1.repeat(0)
    # info is an ExceptionInfo wrapper; the raised exception is info.value.
    assert "count" in str(info.value)
def test_repeat_countneg2():
    """
    Test Repeat with invalid count -2.

    repeat(-2) is expected to raise a ValueError whose message mentions "count".
    """
    logger.info("Test Repeat with invalid count -2")
    # Build the dataset outside pytest.raises so that only repeat() itself can
    # satisfy the expected ValueError.
    data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
    with pytest.raises(ValueError) as info:
        data1.repeat(-2)
    # info is an ExceptionInfo wrapper; the raised exception is info.value.
    assert "count" in str(info.value)
| if __name__ == "__main__": | |||
| test_tf_repeat_01() | |||
| test_tf_repeat_02() | |||
| @@ -313,3 +333,5 @@ if __name__ == "__main__": | |||
| test_nested_repeat11() | |||
| test_repeat_count1() | |||
| test_repeat_count2() | |||
| test_repeat_count0() | |||
| test_repeat_countneg2() | |||