| @@ -422,7 +422,7 @@ PYBIND_REGISTER( | |||
| PYBIND_REGISTER(RandomSolarizeOp, 1, ([](const py::module *m) { | |||
| (void)py::class_<RandomSolarizeOp, TensorOp, std::shared_ptr<RandomSolarizeOp>>(*m, | |||
| "RandomSolarizeOp") | |||
| .def(py::init<uint8_t, uint8_t>()); | |||
| .def(py::init<std::vector<uint8_t>>()); | |||
| })); | |||
| } // namespace dataset | |||
| @@ -241,8 +241,8 @@ std::shared_ptr<RandomRotationOperation> RandomRotation(std::vector<float> degre | |||
| } | |||
| // Function to create RandomSolarizeOperation. | |||
| std::shared_ptr<RandomSolarizeOperation> RandomSolarize(uint8_t threshold_min, uint8_t threshold_max) { | |||
| auto op = std::make_shared<RandomSolarizeOperation>(threshold_min, threshold_max); | |||
| std::shared_ptr<RandomSolarizeOperation> RandomSolarize(std::vector<uint8_t> threshold) { | |||
| auto op = std::make_shared<RandomSolarizeOperation>(threshold); | |||
| // Input validation | |||
| if (!op->ValidateParams()) { | |||
| return nullptr; | |||
| @@ -811,19 +811,22 @@ std::shared_ptr<TensorOp> RandomSharpnessOperation::Build() { | |||
| } | |||
| // RandomSolarizeOperation. | |||
| RandomSolarizeOperation::RandomSolarizeOperation(uint8_t threshold_min, uint8_t threshold_max) | |||
| : threshold_min_(threshold_min), threshold_max_(threshold_max) {} | |||
| RandomSolarizeOperation::RandomSolarizeOperation(std::vector<uint8_t> threshold) : threshold_(threshold) {} | |||
| bool RandomSolarizeOperation::ValidateParams() { | |||
| if (threshold_max_ < threshold_min_) { | |||
| MS_LOG(ERROR) << "RandomSolarize: threshold_max must be greater or equal to threshold_min"; | |||
| if (threshold_.size() != 2) { | |||
| MS_LOG(ERROR) << "RandomSolarize: threshold vector has incorrect size: " << threshold_.size(); | |||
| return false; | |||
| } | |||
| if (threshold_.at(0) > threshold_.at(1)) { | |||
| MS_LOG(ERROR) << "RandomSolarize: threshold must be passed in a min, max format"; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| std::shared_ptr<TensorOp> RandomSolarizeOperation::Build() { | |||
| std::shared_ptr<RandomSolarizeOp> tensor_op = std::make_shared<RandomSolarizeOp>(threshold_min_, threshold_max_); | |||
| std::shared_ptr<RandomSolarizeOp> tensor_op = std::make_shared<RandomSolarizeOp>(threshold_); | |||
| return tensor_op; | |||
| } | |||
| @@ -249,10 +249,9 @@ std::shared_ptr<RandomSharpnessOperation> RandomSharpness(std::vector<float> deg | |||
| /// \brief Function to create a RandomSolarize TensorOperation. | |||
| /// \notes Invert pixels within specified range. If min=max, then it inverts all pixel above that threshold | |||
| /// \param[in] threshold_min - lower limit | |||
| /// \param[in] threshold_max - upper limit | |||
| /// \param[in] threshold - a vector with two elements specifying the pixel range to invert. | |||
| /// \return Shared pointer to the current TensorOperation. | |||
| std::shared_ptr<RandomSolarizeOperation> RandomSolarize(uint8_t threshold_min = 0, uint8_t threshold_max = 255); | |||
| std::shared_ptr<RandomSolarizeOperation> RandomSolarize(std::vector<uint8_t> threshold = {0, 255}); | |||
| /// \brief Function to create a RandomVerticalFlip TensorOperation. | |||
| /// \notes Tensor operation to perform random vertical flip. | |||
| @@ -657,7 +656,7 @@ class SwapRedBlueOperation : public TensorOperation { | |||
| class RandomSolarizeOperation : public TensorOperation { | |||
| public: | |||
| explicit RandomSolarizeOperation(uint8_t threshold_min, uint8_t threshold_max); | |||
| explicit RandomSolarizeOperation(std::vector<uint8_t> threshold); | |||
| ~RandomSolarizeOperation() = default; | |||
| @@ -666,8 +665,7 @@ class RandomSolarizeOperation : public TensorOperation { | |||
| bool ValidateParams() override; | |||
| private: | |||
| uint8_t threshold_min_; | |||
| uint8_t threshold_max_; | |||
| std::vector<uint8_t> threshold_; | |||
| }; | |||
| } // namespace vision | |||
| } // namespace api | |||
| @@ -13,6 +13,8 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include "minddata/dataset/kernels/image/random_solarize_op.h" | |||
| #include "minddata/dataset/kernels/image/solarize_op.h" | |||
| #include "minddata/dataset/kernels/image/image_utils.h" | |||
| @@ -24,6 +26,9 @@ namespace dataset { | |||
| Status RandomSolarizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||
| IO_CHECK(input, output); | |||
| uint8_t threshold_min_ = threshold_[0], threshold_max_ = threshold_[1]; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(threshold_min_ <= threshold_max_, | |||
| "threshold_min must be smaller or equal to threshold_max."); | |||
| @@ -35,7 +40,8 @@ Status RandomSolarizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shar | |||
| threshold_min = threshold_max; | |||
| threshold_max = temp; | |||
| } | |||
| std::unique_ptr<SolarizeOp> op(new SolarizeOp(threshold_min, threshold_max)); | |||
| std::vector<uint8_t> inputs = {threshold_min, threshold_max}; | |||
| std::unique_ptr<SolarizeOp> op(new SolarizeOp(inputs)); | |||
| return op->Compute(input, output); | |||
| } | |||
| } // namespace dataset | |||
| @@ -19,6 +19,7 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "minddata/dataset/core/tensor.h" | |||
| #include "minddata/dataset/kernels/tensor_op.h" | |||
| @@ -31,10 +32,7 @@ namespace dataset { | |||
| class RandomSolarizeOp : public SolarizeOp { | |||
| public: | |||
| // Pick a random threshold value to solarize the image with | |||
| explicit RandomSolarizeOp(uint8_t threshold_min = 0, uint8_t threshold_max = 255) | |||
| : threshold_min_(threshold_min), threshold_max_(threshold_max) { | |||
| rnd_.seed(GetSeed()); | |||
| } | |||
| explicit RandomSolarizeOp(std::vector<uint8_t> threshold = {0, 255}) : threshold_(threshold) { rnd_.seed(GetSeed()); } | |||
| ~RandomSolarizeOp() = default; | |||
| @@ -43,8 +41,7 @@ class RandomSolarizeOp : public SolarizeOp { | |||
| std::string Name() const override { return kRandomSolarizeOp; } | |||
| private: | |||
| uint8_t threshold_min_; | |||
| uint8_t threshold_max_; | |||
| std::vector<uint8_t> threshold_; | |||
| std::mt19937 rnd_; | |||
| }; | |||
| } // namespace dataset | |||
| @@ -27,6 +27,8 @@ const uint8_t kPixelValue = 255; | |||
| Status SolarizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||
| IO_CHECK(input, output); | |||
| uint8_t threshold_min_ = threshold_[0], threshold_max_ = threshold_[1]; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(threshold_min_ <= threshold_max_, | |||
| "threshold_min must be smaller or equal to threshold_max."); | |||
| @@ -19,6 +19,7 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "minddata/dataset/core/tensor.h" | |||
| #include "minddata/dataset/kernels/tensor_op.h" | |||
| @@ -28,8 +29,7 @@ namespace mindspore { | |||
| namespace dataset { | |||
| class SolarizeOp : public TensorOp { | |||
| public: | |||
| explicit SolarizeOp(uint8_t threshold_min = 0, uint8_t threshold_max = 255) | |||
| : threshold_min_(threshold_min), threshold_max_(threshold_max) {} | |||
| explicit SolarizeOp(std::vector<uint8_t> threshold = {0, 255}) : threshold_(threshold) {} | |||
| ~SolarizeOp() = default; | |||
| @@ -38,8 +38,7 @@ class SolarizeOp : public TensorOp { | |||
| std::string Name() const override { return kSolarizeOp; } | |||
| private: | |||
| uint8_t threshold_min_; | |||
| uint8_t threshold_max_; | |||
| std::vector<uint8_t> threshold_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -1045,4 +1045,4 @@ class RandomSolarize(cde.RandomSolarizeOp): | |||
| @check_random_solarize | |||
| def __init__(self, threshold=(0, 255)): | |||
| super().__init__(*threshold) | |||
| super().__init__(threshold) | |||
| @@ -1109,29 +1109,59 @@ TEST_F(MindDataTestPipeline, TestUniformAugWithOps) { | |||
| iter->Stop(); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestRandomSolarize) { | |||
| TEST_F(MindDataTestPipeline, TestRandomSolarizeSucess1) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomSolarize."; | |||
| // Create an ImageFolder Dataset | |||
| std::string folder_path = datasets_root_path_ + "/testPK/data/"; | |||
| std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create a Repeat operation on ds | |||
| int32_t repeat_num = 2; | |||
| ds = ds->Repeat(repeat_num); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create objects for the tensor ops | |||
| std::shared_ptr<TensorOperation> random_solarize = | |||
| mindspore::dataset::api::vision::RandomSolarize(23, 23); // vision::RandomSolarize(); | |||
| std::vector<uint8_t> threshold = {10, 100}; | |||
| std::shared_ptr<TensorOperation> random_solarize = mindspore::dataset::api::vision::RandomSolarize(threshold); | |||
| EXPECT_NE(random_solarize, nullptr); | |||
| // Create a Map operation on ds | |||
| ds = ds->Map({random_solarize}); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create a Batch operation on ds | |||
| int32_t batch_size = 1; | |||
| ds = ds->Batch(batch_size); | |||
| // Create an iterator over the result of the above dataset | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row | |||
| std::unordered_map<std::string, std::shared_ptr<Tensor>> row; | |||
| iter->GetNextRow(&row); | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| i++; | |||
| auto image = row["image"]; | |||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | |||
| iter->GetNextRow(&row); | |||
| } | |||
| EXPECT_EQ(i, 10); | |||
| // Manually terminate the pipeline | |||
| iter->Stop(); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestRandomSolarizeSucess2) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomSolarize with default params."; | |||
| // Create an ImageFolder Dataset | |||
| std::string folder_path = datasets_root_path_ + "/testPK/data/"; | |||
| std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create objects for the tensor ops | |||
| std::shared_ptr<TensorOperation> random_solarize = mindspore::dataset::api::vision::RandomSolarize(); | |||
| EXPECT_NE(random_solarize, nullptr); | |||
| // Create a Map operation on ds | |||
| ds = ds->Map({random_solarize}); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset | |||
| @@ -1151,8 +1181,23 @@ TEST_F(MindDataTestPipeline, TestRandomSolarize) { | |||
| iter->GetNextRow(&row); | |||
| } | |||
| EXPECT_EQ(i, 20); | |||
| EXPECT_EQ(i, 10); | |||
| // Manually terminate the pipeline | |||
| iter->Stop(); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestRandomSolarizeFail) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomSolarize with invalid params."; | |||
| std::vector<uint8_t> threshold = {13, 1}; | |||
| std::shared_ptr<TensorOperation> random_solarize = mindspore::dataset::api::vision::RandomSolarize(threshold); | |||
| EXPECT_EQ(random_solarize, nullptr); | |||
| threshold = {1, 2, 3}; | |||
| random_solarize = mindspore::dataset::api::vision::RandomSolarize(threshold); | |||
| EXPECT_EQ(random_solarize, nullptr); | |||
| threshold = {1}; | |||
| random_solarize = mindspore::dataset::api::vision::RandomSolarize(threshold); | |||
| EXPECT_EQ(random_solarize, nullptr); | |||
| } | |||
| @@ -38,7 +38,9 @@ TEST_F(MindDataTestRandomSolarizeOp, TestOp1) { | |||
| uint32_t curr_seed = GlobalContext::config_manager()->seed(); | |||
| GlobalContext::config_manager()->set_seed(0); | |||
| std::unique_ptr<RandomSolarizeOp> op(new RandomSolarizeOp(100, 100)); | |||
| std::vector<uint8_t> threshold = {100, 100}; | |||
| std::unique_ptr<RandomSolarizeOp> op(new RandomSolarizeOp(threshold)); | |||
| EXPECT_TRUE(op->OneToOne()); | |||
| Status s = op->Compute(input_tensor_, &output_tensor_); | |||
| @@ -74,7 +74,8 @@ TEST_F(MindDataTestSolarizeOp, TestOp3) { | |||
| MS_LOG(INFO) << "Doing testSolarizeOp3 - Pass in only threshold_min parameter"; | |||
| // unsigned int threshold = 128; | |||
| std::unique_ptr<SolarizeOp> op(new SolarizeOp(1)); | |||
| std::vector<uint8_t> threshold ={1, 255}; | |||
| std::unique_ptr<SolarizeOp> op(new SolarizeOp(threshold)); | |||
| std::vector<uint8_t> test_vector = {3, 4, 59, 210, 255}; | |||
| std::vector<uint8_t> expected_output_vector = {252, 251, 196, 45, 0}; | |||
| @@ -98,8 +99,8 @@ TEST_F(MindDataTestSolarizeOp, TestOp3) { | |||
| TEST_F(MindDataTestSolarizeOp, TestOp4) { | |||
| MS_LOG(INFO) << "Doing testSolarizeOp4 - Pass in both threshold parameters."; | |||
| // unsigned int threshold = 128; | |||
| std::unique_ptr<SolarizeOp> op(new SolarizeOp(1, 230)); | |||
| std::vector<uint8_t> threshold ={1, 230}; | |||
| std::unique_ptr<SolarizeOp> op(new SolarizeOp(threshold)); | |||
| std::vector<uint8_t> test_vector = {3, 4, 59, 210, 255}; | |||
| std::vector<uint8_t> expected_output_vector = {252, 251, 196, 45, 255}; | |||
| @@ -123,8 +124,8 @@ TEST_F(MindDataTestSolarizeOp, TestOp4) { | |||
| TEST_F(MindDataTestSolarizeOp, TestOp5) { | |||
| MS_LOG(INFO) << "Doing testSolarizeOp5 - Rank 2 input tensor."; | |||
| // unsigned int threshold = 128; | |||
| std::unique_ptr<SolarizeOp> op(new SolarizeOp(1, 230)); | |||
| std::vector<uint8_t> threshold ={1, 230}; | |||
| std::unique_ptr<SolarizeOp> op(new SolarizeOp(threshold)); | |||
| std::vector<uint8_t> test_vector = {3, 4, 59, 210, 255}; | |||
| std::vector<uint8_t> expected_output_vector = {252, 251, 196, 45, 255}; | |||
| @@ -149,7 +150,8 @@ TEST_F(MindDataTestSolarizeOp, TestOp5) { | |||
| TEST_F(MindDataTestSolarizeOp, TestOp6) { | |||
| MS_LOG(INFO) << "Doing testSolarizeOp6 - Bad Input."; | |||
| std::unique_ptr<SolarizeOp> op(new SolarizeOp(10, 1)); | |||
| std::vector<uint8_t> threshold ={10, 1}; | |||
| std::unique_ptr<SolarizeOp> op(new SolarizeOp(threshold)); | |||
| std::vector<uint8_t> test_vector = {3, 4, 59, 210, 255}; | |||
| std::shared_ptr<Tensor> test_input_tensor; | |||
| @@ -17,66 +17,87 @@ Testing RandomSolarizeOp op in DE | |||
| """ | |||
| import pytest | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.engine as de | |||
| import mindspore.dataset.transforms.vision.c_transforms as vision | |||
| from mindspore import log as logger | |||
| from util import visualize_list, save_and_check_md5, config_get_set_seed, config_get_set_num_parallel_workers | |||
| from util import visualize_list, save_and_check_md5, config_get_set_seed, config_get_set_num_parallel_workers, \ | |||
| visualize_one_channel_dataset | |||
| GENERATE_GOLDEN = False | |||
| MNIST_DATA_DIR = "../data/dataset/testMnistData" | |||
| DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | |||
| SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | |||
| def test_random_solarize_op(threshold=None, plot=False): | |||
| def test_random_solarize_op(threshold=(10, 150), plot=False, run_golden=True): | |||
| """ | |||
| Test RandomSolarize | |||
| """ | |||
| logger.info("Test RandomSolarize") | |||
| # First dataset | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"]) | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| decode_op = vision.Decode() | |||
| original_seed = config_get_set_seed(0) | |||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||
| if threshold is None: | |||
| solarize_op = vision.RandomSolarize() | |||
| else: | |||
| solarize_op = vision.RandomSolarize(threshold) | |||
| data1 = data1.map(input_columns=["image"], operations=decode_op) | |||
| data1 = data1.map(input_columns=["image"], operations=solarize_op) | |||
| # Second dataset | |||
| data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"]) | |||
| data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data2 = data2.map(input_columns=["image"], operations=decode_op) | |||
| if run_golden: | |||
| filename = "random_solarize_01_result.npz" | |||
| save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN) | |||
| image_solarized = [] | |||
| image = [] | |||
| for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): | |||
| image_solarized.append(item1["image"].copy()) | |||
| image.append(item2["image"].copy()) | |||
| if plot: | |||
| visualize_list(image, image_solarized) | |||
| ds.config.set_seed(original_seed) | |||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||
| def test_random_solarize_md5(): | |||
| def test_random_solarize_mnist(plot=False, run_golden=True): | |||
| """ | |||
| Test RandomSolarize | |||
| Test RandomSolarize op with MNIST dataset (Grayscale images) | |||
| """ | |||
| logger.info("Test RandomSolarize") | |||
| original_seed = config_get_set_seed(0) | |||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| decode_op = vision.Decode() | |||
| random_solarize_op = vision.RandomSolarize((10, 150)) | |||
| data1 = data1.map(input_columns=["image"], operations=decode_op) | |||
| data1 = data1.map(input_columns=["image"], operations=random_solarize_op) | |||
| # Compare with expected md5 from images | |||
| filename = "random_solarize_01_result.npz" | |||
| save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN) | |||
| mnist_1 = de.MnistDataset(dataset_dir=MNIST_DATA_DIR, num_samples=2, shuffle=False) | |||
| mnist_2 = de.MnistDataset(dataset_dir=MNIST_DATA_DIR, num_samples=2, shuffle=False) | |||
| mnist_2 = mnist_2.map(input_columns="image", operations=vision.RandomSolarize((0, 255))) | |||
| # Restore config setting | |||
| ds.config.set_seed(original_seed) | |||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||
| images = [] | |||
| images_trans = [] | |||
| labels = [] | |||
| for _, (data_orig, data_trans) in enumerate(zip(mnist_1, mnist_2)): | |||
| image_orig, label_orig = data_orig | |||
| image_trans, _ = data_trans | |||
| images.append(image_orig) | |||
| labels.append(label_orig) | |||
| images_trans.append(image_trans) | |||
| if plot: | |||
| visualize_one_channel_dataset(images, images_trans, labels) | |||
| if run_golden: | |||
| filename = "random_solarize_02_result.npz" | |||
| save_and_check_md5(mnist_2, filename, generate_golden=GENERATE_GOLDEN) | |||
| def test_random_solarize_errors(): | |||
| @@ -105,8 +126,8 @@ def test_random_solarize_errors(): | |||
| if __name__ == "__main__": | |||
| test_random_solarize_op((100, 100), plot=True) | |||
| test_random_solarize_op((12, 120), plot=True) | |||
| test_random_solarize_op(plot=True) | |||
| test_random_solarize_op((10, 150), plot=True, run_golden=True) | |||
| test_random_solarize_op((12, 120), plot=True, run_golden=False) | |||
| test_random_solarize_op(plot=True, run_golden=False) | |||
| test_random_solarize_mnist(plot=True, run_golden=True) | |||
| test_random_solarize_errors() | |||
| test_random_solarize_md5() | |||