Browse Source

!4366 Fixing param type in c++ api and fixing python validator for Repeat Op

Merge pull request !4366 from TinaMengtingZhang/cpp-api-repeat-count
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
ed1244558c
6 changed files with 126 additions and 8 deletions
  1. +3
    -3
      mindspore/ccsrc/minddata/dataset/api/datasets.cc
  2. +2
    -2
      mindspore/ccsrc/minddata/dataset/include/datasets.h
  3. +1
    -1
      mindspore/dataset/engine/datasets.py
  4. +2
    -1
      mindspore/dataset/engine/validators.py
  5. +95
    -0
      tests/ut/cpp/dataset/c_api_dataset_ops_test.cc
  6. +23
    -1
      tests/ut/python/dataset/test_repeat.py

+ 3
- 3
mindspore/ccsrc/minddata/dataset/api/datasets.cc View File

@@ -1165,7 +1165,7 @@ std::vector<std::shared_ptr<DatasetOp>> RenameDataset::Build() {
return node_ops;
}

RepeatDataset::RepeatDataset(uint32_t count) : repeat_count_(count) {}
RepeatDataset::RepeatDataset(int32_t count) : repeat_count_(count) {}

std::vector<std::shared_ptr<DatasetOp>> RepeatDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
@@ -1176,8 +1176,8 @@ std::vector<std::shared_ptr<DatasetOp>> RepeatDataset::Build() {
}

bool RepeatDataset::ValidateParams() {
if (repeat_count_ <= 0) {
MS_LOG(ERROR) << "Repeat: Repeat count cannot be negative";
if (repeat_count_ != -1 && repeat_count_ <= 0) {
MS_LOG(ERROR) << "Repeat: Repeat count cannot be" << repeat_count_;
return false;
}



+ 2
- 2
mindspore/ccsrc/minddata/dataset/include/datasets.h View File

@@ -692,7 +692,7 @@ class RenameDataset : public Dataset {
class RepeatDataset : public Dataset {
public:
/// \brief Constructor
explicit RepeatDataset(uint32_t count);
explicit RepeatDataset(int32_t count);

/// \brief Destructor
~RepeatDataset() = default;
@@ -706,7 +706,7 @@ class RepeatDataset : public Dataset {
bool ValidateParams() override;

private:
uint32_t repeat_count_;
int32_t repeat_count_;
};

class ShuffleDataset : public Dataset {


+ 1
- 1
mindspore/dataset/engine/datasets.py View File

@@ -2123,7 +2123,7 @@ class RepeatDataset(DatasetOp):

Args:
input_dataset (Dataset): Input Dataset to be repeated.
count (int): Number of times the dataset should be repeated.
count (int): Number of times the dataset should be repeated (default=-1, repeat indefinitely).
"""

def __init__(self, input_dataset, count):


+ 2
- 1
mindspore/dataset/engine/validators.py View File

@@ -597,7 +597,8 @@ def check_repeat(method):

type_check(count, (int, type(None)), "repeat")
if isinstance(count, int):
check_value(count, (-1, INT32_MAX), "count")
if (count <= 0 and count != -1) or count > INT32_MAX:
raise ValueError("count should be either -1 or positive integer.")
return method(self, *args, **kwargs)

return new_method


+ 95
- 0
tests/ut/cpp/dataset/c_api_dataset_ops_test.cc View File

@@ -431,6 +431,101 @@ TEST_F(MindDataTestPipeline, TestRenameSuccess) {
iter->Stop();
}

// Repeat() called without an argument uses the default repeat count, so the pipeline
// keeps producing rows; the test reads a fixed number of rows and then stops on its own.
TEST_F(MindDataTestPipeline, TestRepeatDefault) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRepeatDefault.";

  // Number of rows to read before stopping manually (the dataset itself only samples 10).
  constexpr uint64_t kRowsToRead = 100;

  // Create an ImageFolder Dataset
  std::string folder_path = datasets_root_path_ + "/testPK/data/";
  std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
  // ASSERT rather than EXPECT: ds is dereferenced on the next step, so continuing
  // after a failed check would turn a test failure into a null dereference.
  ASSERT_NE(ds, nullptr);

  // Create a Repeat operation on ds
  // Default value of repeat count is -1, expected to repeat infinitely
  ds = ds->Repeat();
  ASSERT_NE(ds, nullptr);

  // Create a Batch operation on ds
  int32_t batch_size = 1;
  ds = ds->Batch(batch_size);
  ASSERT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  ASSERT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);
  uint64_t i = 0;
  // Stop manually once kRowsToRead rows have been consumed; an empty row ends the loop early.
  while (row.size() != 0 && i < kRowsToRead) {
    i++;
    auto image = row["image"];
    MS_LOG(INFO) << "Tensor image shape: " << image->shape();
    iter->GetNextRow(&row);
  }

  // Same unsigned type on both sides avoids a signed/unsigned comparison.
  EXPECT_EQ(i, kRowsToRead);
  // Manually terminate the pipeline
  iter->Stop();
}

// Repeat(1) must leave the row count unchanged: the sampler picks 10 images and,
// with a batch size of 1, exactly 10 rows are expected from the iterator.
TEST_F(MindDataTestPipeline, TestRepeatOne) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRepeatOne.";

  // Rows expected from the pipeline built below.
  constexpr uint64_t kExpectedRows = 10;

  // Create an ImageFolder Dataset
  std::string folder_path = datasets_root_path_ + "/testPK/data/";
  std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
  // ASSERT rather than EXPECT: ds is dereferenced on the next step, so continuing
  // after a failed check would turn a test failure into a null dereference.
  ASSERT_NE(ds, nullptr);

  // Create a Repeat operation on ds
  int32_t repeat_num = 1;
  ds = ds->Repeat(repeat_num);
  ASSERT_NE(ds, nullptr);

  // Create a Batch operation on ds
  int32_t batch_size = 1;
  ds = ds->Batch(batch_size);
  ASSERT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  ASSERT_NE(iter, nullptr);

  // Iterate the dataset and get each row; an empty row marks the end of the data.
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);
  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    auto image = row["image"];
    MS_LOG(INFO) << "Tensor image shape: " << image->shape();
    iter->GetNextRow(&row);
  }

  // Same unsigned type on both sides avoids a signed/unsigned comparison.
  EXPECT_EQ(i, kExpectedRows);
  // Manually terminate the pipeline
  iter->Stop();
}

// Repeat() with an invalid count must not produce a dataset node.
TEST_F(MindDataTestPipeline, TestRepeatFail) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRepeatFail.";
  // This case is expected to fail because the repeat count is invalid:
  // only -1 or a positive value is accepted, and -2 is neither.

  // Create an ImageFolder Dataset
  std::string folder_path = datasets_root_path_ + "/testPK/data/";
  std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
  // ASSERT rather than EXPECT: ds is dereferenced by the Repeat call below.
  ASSERT_NE(ds, nullptr);

  // Create a Repeat operation on ds
  int32_t repeat_num = -2;
  ds = ds->Repeat(repeat_num);
  EXPECT_EQ(ds, nullptr);
}

TEST_F(MindDataTestPipeline, TestShuffleDataset) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestShuffleDataset.";



+ 23
- 1
tests/ut/python/dataset/test_repeat.py View File

@@ -16,7 +16,7 @@
Test Repeat Op
"""
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
@@ -295,6 +295,26 @@ def test_repeat_count2():
assert data1_size == 3
assert dataset_size == num1_iter == 8

def test_repeat_count0():
    """
    Test Repeat with invalid count 0.

    repeat(0) is expected to raise a ValueError whose message mentions "count".
    """
    logger.info("Test Repeat with invalid count 0")
    # Build the dataset outside the raises-block so that only repeat() may satisfy
    # the expectation; a ValueError from the constructor must fail the test.
    data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
    with pytest.raises(ValueError) as info:
        data1.repeat(0)
    # Check the raised exception itself, not the ExceptionInfo wrapper around it.
    assert "count" in str(info.value)

def test_repeat_countneg2():
    """
    Test Repeat with invalid count -2.

    repeat(-2) is expected to raise a ValueError whose message mentions "count".
    """
    logger.info("Test Repeat with invalid count -2")
    # Build the dataset outside the raises-block so that only repeat() may satisfy
    # the expectation; a ValueError from the constructor must fail the test.
    data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
    with pytest.raises(ValueError) as info:
        data1.repeat(-2)
    # Check the raised exception itself, not the ExceptionInfo wrapper around it.
    assert "count" in str(info.value)

if __name__ == "__main__":
test_tf_repeat_01()
test_tf_repeat_02()
@@ -313,3 +333,5 @@ if __name__ == "__main__":
test_nested_repeat11()
test_repeat_count1()
test_repeat_count2()
test_repeat_count0()
test_repeat_countneg2()

Loading…
Cancel
Save