Browse Source

change api to use std_vector

try to display image
tags/v0.7.0-beta
nhussain 5 years ago
parent
commit
738ae2c78d
13 changed files with 144 additions and 69 deletions
  1. +1
    -1
      mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/image/bindings.cc
  2. +10
    -7
      mindspore/ccsrc/minddata/dataset/api/transforms.cc
  3. +4
    -6
      mindspore/ccsrc/minddata/dataset/include/transforms.h
  4. +7
    -1
      mindspore/ccsrc/minddata/dataset/kernels/image/random_solarize_op.cc
  5. +3
    -6
      mindspore/ccsrc/minddata/dataset/kernels/image/random_solarize_op.h
  6. +2
    -0
      mindspore/ccsrc/minddata/dataset/kernels/image/solarize_op.cc
  7. +3
    -4
      mindspore/ccsrc/minddata/dataset/kernels/image/solarize_op.h
  8. +1
    -1
      mindspore/dataset/transforms/vision/c_transforms.py
  9. +57
    -12
      tests/ut/cpp/dataset/c_api_transforms_test.cc
  10. +3
    -1
      tests/ut/cpp/dataset/random_solarize_op_test.cc
  11. +8
    -6
      tests/ut/cpp/dataset/solarize_op_test.cc
  12. BIN
      tests/ut/data/dataset/golden/random_solarize_02_result.npz
  13. +45
    -24
      tests/ut/python/dataset/test_random_solarize_op.py

+ 1
- 1
mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/image/bindings.cc View File

@@ -422,7 +422,7 @@ PYBIND_REGISTER(
PYBIND_REGISTER(RandomSolarizeOp, 1, ([](const py::module *m) {
(void)py::class_<RandomSolarizeOp, TensorOp, std::shared_ptr<RandomSolarizeOp>>(*m,
"RandomSolarizeOp")
.def(py::init<uint8_t, uint8_t>());
.def(py::init<std::vector<uint8_t>>());
}));

} // namespace dataset


+ 10
- 7
mindspore/ccsrc/minddata/dataset/api/transforms.cc View File

@@ -241,8 +241,8 @@ std::shared_ptr<RandomRotationOperation> RandomRotation(std::vector<float> degre
}

// Function to create RandomSolarizeOperation.
std::shared_ptr<RandomSolarizeOperation> RandomSolarize(uint8_t threshold_min, uint8_t threshold_max) {
auto op = std::make_shared<RandomSolarizeOperation>(threshold_min, threshold_max);
std::shared_ptr<RandomSolarizeOperation> RandomSolarize(std::vector<uint8_t> threshold) {
auto op = std::make_shared<RandomSolarizeOperation>(threshold);
// Input validation
if (!op->ValidateParams()) {
return nullptr;
@@ -811,19 +811,22 @@ std::shared_ptr<TensorOp> RandomSharpnessOperation::Build() {
}

// RandomSolarizeOperation.
RandomSolarizeOperation::RandomSolarizeOperation(uint8_t threshold_min, uint8_t threshold_max)
: threshold_min_(threshold_min), threshold_max_(threshold_max) {}
RandomSolarizeOperation::RandomSolarizeOperation(std::vector<uint8_t> threshold) : threshold_(threshold) {}

bool RandomSolarizeOperation::ValidateParams() {
if (threshold_max_ < threshold_min_) {
MS_LOG(ERROR) << "RandomSolarize: threshold_max must be greater or equal to threshold_min";
if (threshold_.size() != 2) {
MS_LOG(ERROR) << "RandomSolarize: threshold vector has incorrect size: " << threshold_.size();
return false;
}
if (threshold_.at(0) > threshold_.at(1)) {
MS_LOG(ERROR) << "RandomSolarize: threshold must be passed in a min, max format";
return false;
}
return true;
}

std::shared_ptr<TensorOp> RandomSolarizeOperation::Build() {
std::shared_ptr<RandomSolarizeOp> tensor_op = std::make_shared<RandomSolarizeOp>(threshold_min_, threshold_max_);
std::shared_ptr<RandomSolarizeOp> tensor_op = std::make_shared<RandomSolarizeOp>(threshold_);
return tensor_op;
}



+ 4
- 6
mindspore/ccsrc/minddata/dataset/include/transforms.h View File

@@ -249,10 +249,9 @@ std::shared_ptr<RandomSharpnessOperation> RandomSharpness(std::vector<float> deg

/// \brief Function to create a RandomSolarize TensorOperation.
/// \notes Invert pixels within specified range. If min=max, then it inverts all pixel above that threshold
/// \param[in] threshold_min - lower limit
/// \param[in] threshold_max - upper limit
/// \param[in] threshold - a vector with two elements specifying the pixel range to invert.
/// \return Shared pointer to the current TensorOperation.
std::shared_ptr<RandomSolarizeOperation> RandomSolarize(uint8_t threshold_min = 0, uint8_t threshold_max = 255);
std::shared_ptr<RandomSolarizeOperation> RandomSolarize(std::vector<uint8_t> threshold = {0, 255});

/// \brief Function to create a RandomVerticalFlip TensorOperation.
/// \notes Tensor operation to perform random vertical flip.
@@ -657,7 +656,7 @@ class SwapRedBlueOperation : public TensorOperation {

class RandomSolarizeOperation : public TensorOperation {
public:
explicit RandomSolarizeOperation(uint8_t threshold_min, uint8_t threshold_max);
explicit RandomSolarizeOperation(std::vector<uint8_t> threshold);

~RandomSolarizeOperation() = default;

@@ -666,8 +665,7 @@ class RandomSolarizeOperation : public TensorOperation {
bool ValidateParams() override;

private:
uint8_t threshold_min_;
uint8_t threshold_max_;
std::vector<uint8_t> threshold_;
};
} // namespace vision
} // namespace api


+ 7
- 1
mindspore/ccsrc/minddata/dataset/kernels/image/random_solarize_op.cc View File

@@ -13,6 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>

#include "minddata/dataset/kernels/image/random_solarize_op.h"
#include "minddata/dataset/kernels/image/solarize_op.h"
#include "minddata/dataset/kernels/image/image_utils.h"
@@ -24,6 +26,9 @@ namespace dataset {

Status RandomSolarizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);

uint8_t threshold_min_ = threshold_[0], threshold_max_ = threshold_[1];

CHECK_FAIL_RETURN_UNEXPECTED(threshold_min_ <= threshold_max_,
"threshold_min must be smaller or equal to threshold_max.");

@@ -35,7 +40,8 @@ Status RandomSolarizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shar
threshold_min = threshold_max;
threshold_max = temp;
}
std::unique_ptr<SolarizeOp> op(new SolarizeOp(threshold_min, threshold_max));
std::vector<uint8_t> inputs = {threshold_min, threshold_max};
std::unique_ptr<SolarizeOp> op(new SolarizeOp(inputs));
return op->Compute(input, output);
}
} // namespace dataset


+ 3
- 6
mindspore/ccsrc/minddata/dataset/kernels/image/random_solarize_op.h View File

@@ -19,6 +19,7 @@

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
@@ -31,10 +32,7 @@ namespace dataset {
class RandomSolarizeOp : public SolarizeOp {
public:
// Pick a random threshold value to solarize the image with
explicit RandomSolarizeOp(uint8_t threshold_min = 0, uint8_t threshold_max = 255)
: threshold_min_(threshold_min), threshold_max_(threshold_max) {
rnd_.seed(GetSeed());
}
explicit RandomSolarizeOp(std::vector<uint8_t> threshold = {0, 255}) : threshold_(threshold) { rnd_.seed(GetSeed()); }

~RandomSolarizeOp() = default;

@@ -43,8 +41,7 @@ class RandomSolarizeOp : public SolarizeOp {
std::string Name() const override { return kRandomSolarizeOp; }

private:
uint8_t threshold_min_;
uint8_t threshold_max_;
std::vector<uint8_t> threshold_;
std::mt19937 rnd_;
};
} // namespace dataset


+ 2
- 0
mindspore/ccsrc/minddata/dataset/kernels/image/solarize_op.cc View File

@@ -27,6 +27,8 @@ const uint8_t kPixelValue = 255;
Status SolarizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);

uint8_t threshold_min_ = threshold_[0], threshold_max_ = threshold_[1];

CHECK_FAIL_RETURN_UNEXPECTED(threshold_min_ <= threshold_max_,
"threshold_min must be smaller or equal to threshold_max.");



+ 3
- 4
mindspore/ccsrc/minddata/dataset/kernels/image/solarize_op.h View File

@@ -19,6 +19,7 @@

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
@@ -28,8 +29,7 @@ namespace mindspore {
namespace dataset {
class SolarizeOp : public TensorOp {
public:
explicit SolarizeOp(uint8_t threshold_min = 0, uint8_t threshold_max = 255)
: threshold_min_(threshold_min), threshold_max_(threshold_max) {}
explicit SolarizeOp(std::vector<uint8_t> threshold = {0, 255}) : threshold_(threshold) {}

~SolarizeOp() = default;

@@ -38,8 +38,7 @@ class SolarizeOp : public TensorOp {
std::string Name() const override { return kSolarizeOp; }

private:
uint8_t threshold_min_;
uint8_t threshold_max_;
std::vector<uint8_t> threshold_;
};
} // namespace dataset
} // namespace mindspore


+ 1
- 1
mindspore/dataset/transforms/vision/c_transforms.py View File

@@ -1045,4 +1045,4 @@ class RandomSolarize(cde.RandomSolarizeOp):

@check_random_solarize
def __init__(self, threshold=(0, 255)):
super().__init__(*threshold)
super().__init__(threshold)

+ 57
- 12
tests/ut/cpp/dataset/c_api_transforms_test.cc View File

@@ -1109,29 +1109,59 @@ TEST_F(MindDataTestPipeline, TestUniformAugWithOps) {
iter->Stop();
}

TEST_F(MindDataTestPipeline, TestRandomSolarize) {
TEST_F(MindDataTestPipeline, TestRandomSolarizeSucess1) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomSolarize.";

// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);

// Create a Repeat operation on ds
int32_t repeat_num = 2;
ds = ds->Repeat(repeat_num);
EXPECT_NE(ds, nullptr);

// Create objects for the tensor ops
std::shared_ptr<TensorOperation> random_solarize =
mindspore::dataset::api::vision::RandomSolarize(23, 23); // vision::RandomSolarize();
std::vector<uint8_t> threshold = {10, 100};
std::shared_ptr<TensorOperation> random_solarize = mindspore::dataset::api::vision::RandomSolarize(threshold);
EXPECT_NE(random_solarize, nullptr);

// Create a Map operation on ds
ds = ds->Map({random_solarize});
EXPECT_NE(ds, nullptr);

// Create a Batch operation on ds
int32_t batch_size = 1;
ds = ds->Batch(batch_size);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);

// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);

uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
}

EXPECT_EQ(i, 10);

// Manually terminate the pipeline
iter->Stop();
}

TEST_F(MindDataTestPipeline, TestRandomSolarizeSucess2) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomSolarize with default params.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);

// Create objects for the tensor ops
std::shared_ptr<TensorOperation> random_solarize = mindspore::dataset::api::vision::RandomSolarize();
EXPECT_NE(random_solarize, nullptr);

// Create a Map operation on ds
ds = ds->Map({random_solarize});
EXPECT_NE(ds, nullptr);

// Create an iterator over the result of the above dataset
@@ -1151,8 +1181,23 @@ TEST_F(MindDataTestPipeline, TestRandomSolarize) {
iter->GetNextRow(&row);
}

EXPECT_EQ(i, 20);
EXPECT_EQ(i, 10);

// Manually terminate the pipeline
iter->Stop();
}

TEST_F(MindDataTestPipeline, TestRandomSolarizeFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomSolarize with invalid params.";
std::vector<uint8_t> threshold = {13, 1};
std::shared_ptr<TensorOperation> random_solarize = mindspore::dataset::api::vision::RandomSolarize(threshold);
EXPECT_EQ(random_solarize, nullptr);

threshold = {1, 2, 3};
random_solarize = mindspore::dataset::api::vision::RandomSolarize(threshold);
EXPECT_EQ(random_solarize, nullptr);

threshold = {1};
random_solarize = mindspore::dataset::api::vision::RandomSolarize(threshold);
EXPECT_EQ(random_solarize, nullptr);
}

+ 3
- 1
tests/ut/cpp/dataset/random_solarize_op_test.cc View File

@@ -38,7 +38,9 @@ TEST_F(MindDataTestRandomSolarizeOp, TestOp1) {
uint32_t curr_seed = GlobalContext::config_manager()->seed();
GlobalContext::config_manager()->set_seed(0);

std::unique_ptr<RandomSolarizeOp> op(new RandomSolarizeOp(100, 100));
std::vector<uint8_t> threshold = {100, 100};
std::unique_ptr<RandomSolarizeOp> op(new RandomSolarizeOp(threshold));

EXPECT_TRUE(op->OneToOne());
Status s = op->Compute(input_tensor_, &output_tensor_);



+ 8
- 6
tests/ut/cpp/dataset/solarize_op_test.cc View File

@@ -74,7 +74,8 @@ TEST_F(MindDataTestSolarizeOp, TestOp3) {
MS_LOG(INFO) << "Doing testSolarizeOp3 - Pass in only threshold_min parameter";

// unsigned int threshold = 128;
std::unique_ptr<SolarizeOp> op(new SolarizeOp(1));
std::vector<uint8_t> threshold ={1, 255};
std::unique_ptr<SolarizeOp> op(new SolarizeOp(threshold));

std::vector<uint8_t> test_vector = {3, 4, 59, 210, 255};
std::vector<uint8_t> expected_output_vector = {252, 251, 196, 45, 0};
@@ -98,8 +99,8 @@ TEST_F(MindDataTestSolarizeOp, TestOp3) {
TEST_F(MindDataTestSolarizeOp, TestOp4) {
MS_LOG(INFO) << "Doing testSolarizeOp4 - Pass in both threshold parameters.";

// unsigned int threshold = 128;
std::unique_ptr<SolarizeOp> op(new SolarizeOp(1, 230));
std::vector<uint8_t> threshold ={1, 230};
std::unique_ptr<SolarizeOp> op(new SolarizeOp(threshold));

std::vector<uint8_t> test_vector = {3, 4, 59, 210, 255};
std::vector<uint8_t> expected_output_vector = {252, 251, 196, 45, 255};
@@ -123,8 +124,8 @@ TEST_F(MindDataTestSolarizeOp, TestOp4) {
TEST_F(MindDataTestSolarizeOp, TestOp5) {
MS_LOG(INFO) << "Doing testSolarizeOp5 - Rank 2 input tensor.";

// unsigned int threshold = 128;
std::unique_ptr<SolarizeOp> op(new SolarizeOp(1, 230));
std::vector<uint8_t> threshold ={1, 230};
std::unique_ptr<SolarizeOp> op(new SolarizeOp(threshold));

std::vector<uint8_t> test_vector = {3, 4, 59, 210, 255};
std::vector<uint8_t> expected_output_vector = {252, 251, 196, 45, 255};
@@ -149,7 +150,8 @@ TEST_F(MindDataTestSolarizeOp, TestOp5) {
TEST_F(MindDataTestSolarizeOp, TestOp6) {
MS_LOG(INFO) << "Doing testSolarizeOp6 - Bad Input.";

std::unique_ptr<SolarizeOp> op(new SolarizeOp(10, 1));
std::vector<uint8_t> threshold ={10, 1};
std::unique_ptr<SolarizeOp> op(new SolarizeOp(threshold));

std::vector<uint8_t> test_vector = {3, 4, 59, 210, 255};
std::shared_ptr<Tensor> test_input_tensor;


BIN
tests/ut/data/dataset/golden/random_solarize_02_result.npz View File


+ 45
- 24
tests/ut/python/dataset/test_random_solarize_op.py View File

@@ -17,66 +17,87 @@ Testing RandomSolarizeOp op in DE
"""
import pytest
import mindspore.dataset as ds
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
from util import visualize_list, save_and_check_md5, config_get_set_seed, config_get_set_num_parallel_workers
from util import visualize_list, save_and_check_md5, config_get_set_seed, config_get_set_num_parallel_workers, \
visualize_one_channel_dataset

GENERATE_GOLDEN = False

MNIST_DATA_DIR = "../data/dataset/testMnistData"
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"


def test_random_solarize_op(threshold=None, plot=False):
def test_random_solarize_op(threshold=(10, 150), plot=False, run_golden=True):
"""
Test RandomSolarize
"""
logger.info("Test RandomSolarize")

# First dataset
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
decode_op = vision.Decode()

original_seed = config_get_set_seed(0)
original_num_parallel_workers = config_get_set_num_parallel_workers(1)

if threshold is None:
solarize_op = vision.RandomSolarize()
else:
solarize_op = vision.RandomSolarize(threshold)

data1 = data1.map(input_columns=["image"], operations=decode_op)
data1 = data1.map(input_columns=["image"], operations=solarize_op)

# Second dataset
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
data2 = data2.map(input_columns=["image"], operations=decode_op)

if run_golden:
filename = "random_solarize_01_result.npz"
save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)

image_solarized = []
image = []

for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
image_solarized.append(item1["image"].copy())
image.append(item2["image"].copy())
if plot:
visualize_list(image, image_solarized)

ds.config.set_seed(original_seed)
ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_random_solarize_md5():
def test_random_solarize_mnist(plot=False, run_golden=True):
"""
Test RandomSolarize
Test RandomSolarize op with MNIST dataset (Grayscale images)
"""
logger.info("Test RandomSolarize")
original_seed = config_get_set_seed(0)
original_num_parallel_workers = config_get_set_num_parallel_workers(1)

data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
decode_op = vision.Decode()
random_solarize_op = vision.RandomSolarize((10, 150))
data1 = data1.map(input_columns=["image"], operations=decode_op)
data1 = data1.map(input_columns=["image"], operations=random_solarize_op)
# Compare with expected md5 from images
filename = "random_solarize_01_result.npz"
save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)
mnist_1 = de.MnistDataset(dataset_dir=MNIST_DATA_DIR, num_samples=2, shuffle=False)
mnist_2 = de.MnistDataset(dataset_dir=MNIST_DATA_DIR, num_samples=2, shuffle=False)
mnist_2 = mnist_2.map(input_columns="image", operations=vision.RandomSolarize((0, 255)))

# Restore config setting
ds.config.set_seed(original_seed)
ds.config.set_num_parallel_workers(original_num_parallel_workers)
images = []
images_trans = []
labels = []

for _, (data_orig, data_trans) in enumerate(zip(mnist_1, mnist_2)):
image_orig, label_orig = data_orig
image_trans, _ = data_trans
images.append(image_orig)
labels.append(label_orig)
images_trans.append(image_trans)

if plot:
visualize_one_channel_dataset(images, images_trans, labels)

if run_golden:
filename = "random_solarize_02_result.npz"
save_and_check_md5(mnist_2, filename, generate_golden=GENERATE_GOLDEN)


def test_random_solarize_errors():
@@ -105,8 +126,8 @@ def test_random_solarize_errors():


if __name__ == "__main__":
test_random_solarize_op((100, 100), plot=True)
test_random_solarize_op((12, 120), plot=True)
test_random_solarize_op(plot=True)
test_random_solarize_op((10, 150), plot=True, run_golden=True)
test_random_solarize_op((12, 120), plot=True, run_golden=False)
test_random_solarize_op(plot=True, run_golden=False)
test_random_solarize_mnist(plot=True, run_golden=True)
test_random_solarize_errors()
test_random_solarize_md5()

Loading…
Cancel
Save