diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/image/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/image/bindings.cc index 50884f5cf8..eafdc8a1a6 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/image/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/image/bindings.cc @@ -32,6 +32,7 @@ #include "minddata/dataset/kernels/image/normalize_op.h" #include "minddata/dataset/kernels/image/pad_op.h" #include "minddata/dataset/kernels/image/random_affine_op.h" +#include "minddata/dataset/kernels/image/random_color_op.h" #include "minddata/dataset/kernels/image/random_color_adjust_op.h" #include "minddata/dataset/kernels/image/random_crop_and_resize_op.h" #include "minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h" @@ -273,6 +274,14 @@ PYBIND_REGISTER( py::arg("targetWidth") = RandomResizeOp::kDefTargetWidth); })); +PYBIND_REGISTER(RandomColorOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "RandomColorOp", + "Tensor operation to blend an image with its grayscale version with random weights" + "Takes min and max for the range of random weights") + .def(py::init(), py::arg("min"), py::arg("max")); + })); + PYBIND_REGISTER(RandomColorAdjustOp, 1, ([](const py::module *m) { (void)py::class_>( *m, "RandomColorAdjustOp", diff --git a/mindspore/ccsrc/minddata/dataset/api/transforms.cc b/mindspore/ccsrc/minddata/dataset/api/transforms.cc index ddb9cb564a..4eba7ac05b 100644 --- a/mindspore/ccsrc/minddata/dataset/api/transforms.cc +++ b/mindspore/ccsrc/minddata/dataset/api/transforms.cc @@ -27,6 +27,7 @@ #include "minddata/dataset/kernels/data/one_hot_op.h" #include "minddata/dataset/kernels/image/pad_op.h" #include "minddata/dataset/kernels/image/random_affine_op.h" +#include "minddata/dataset/kernels/image/random_color_op.h" #include "minddata/dataset/kernels/image/random_color_adjust_op.h" #include "minddata/dataset/kernels/image/random_crop_op.h" #include "minddata/dataset/kernels/image/random_horizontal_flip_op.h" @@ -140,6 +141,21 @@ std::shared_ptr Pad(std::vector padding, std::vector RandomColor(float t_lb, float t_ub) { + auto op = std::make_shared(t_lb, t_ub); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + +std::shared_ptr RandomColorOperation::Build() { + std::shared_ptr tensor_op = std::make_shared(t_lb_, t_ub_); + return tensor_op; +} + // Function to create RandomColorAdjustOperation. std::shared_ptr RandomColorAdjust(std::vector brightness, std::vector contrast, @@ -475,6 +491,18 @@ std::shared_ptr PadOperation::Build() { return tensor_op; } +// RandomColorOperation. +RandomColorOperation::RandomColorOperation(float t_lb, float t_ub) : t_lb_(t_lb), t_ub_(t_ub) {} + +bool RandomColorOperation::ValidateParams() { + // Do some input validation. + if (t_lb_ > t_ub_) { + MS_LOG(ERROR) << "RandomColor: lower bound must be less or equal to upper bound"; + return false; + } + return true; +} + // RandomColorAdjustOperation. RandomColorAdjustOperation::RandomColorAdjustOperation(std::vector brightness, std::vector contrast, std::vector saturation, std::vector hue) diff --git a/mindspore/ccsrc/minddata/dataset/core/cv_tensor.h b/mindspore/ccsrc/minddata/dataset/core/cv_tensor.h index f32d422672..b67dd87683 100644 --- a/mindspore/ccsrc/minddata/dataset/core/cv_tensor.h +++ b/mindspore/ccsrc/minddata/dataset/core/cv_tensor.h @@ -70,7 +70,7 @@ class CVTensor : public Tensor { /// Get a reference to the CV::Mat /// \return a reference to the internal CV::Mat - cv::Mat mat() const { return mat_; } + cv::Mat &mat() { return mat_; } /// Get a copy of the CV::Mat /// \return a copy of internal CV::Mat diff --git a/mindspore/ccsrc/minddata/dataset/include/transforms.h b/mindspore/ccsrc/minddata/dataset/include/transforms.h index fdf4e45f17..7b13b43d19 100644 --- a/mindspore/ccsrc/minddata/dataset/include/transforms.h +++ b/mindspore/ccsrc/minddata/dataset/include/transforms.h @@ -57,6 +57,7 @@ class NormalizeOperation; class OneHotOperation; class PadOperation; class RandomAffineOperation; +class RandomColorOperation; class RandomColorAdjustOperation; class RandomCropOperation; class RandomHorizontalFlipOperation; @@ -162,6 +163,14 @@ std::shared_ptr RandomAffine( InterpolationMode interpolation = InterpolationMode::kNearestNeighbour, const std::vector &fill_value = {0, 0, 0}); +/// \brief Blends an image with its grayscale version with random weights +/// t and 1 - t generated from a given range. If the range is trivial +/// then the weights are determinate and t equals the bound of the interval +/// \param[in] t_lb lower bound on the range of random weights +/// \param[in] t_lb upper bound on the range of random weights +/// \return Shared pointer to the current TensorOp +std::shared_ptr RandomColor(float t_lb, float t_ub); + /// \brief Randomly adjust the brightness, contrast, saturation, and hue of the input image /// \param[in] brightness Brightness adjustment factor. Must be a vector of one or two values /// if it's a vector of two values it needs to be in the form of [min, max]. Default value is {1, 1} @@ -417,6 +426,21 @@ class RandomAffineOperation : public TensorOperation { std::vector fill_value_; }; +class RandomColorOperation : public TensorOperation { + public: + RandomColorOperation(float t_lb, float t_ub); + + ~RandomColorOperation() = default; + + std::shared_ptr Build() override; + + bool ValidateParams() override; + + private: + float t_lb_; + float t_ub_; +}; + class RandomColorAdjustOperation : public TensorOperation { public: RandomColorAdjustOperation(std::vector brightness = {1.0, 1.0}, std::vector contrast = {1.0, 1.0}, diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt index 727067c9af..20733af9c8 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt @@ -44,5 +44,6 @@ add_library(kernels-image OBJECT uniform_aug_op.cc resize_with_bbox_op.cc random_resize_with_bbox_op.cc + random_color_op.cc ) add_dependencies(kernels-image kernels-soft-dvpp-image) diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_op.cc new file mode 100644 index 0000000000..7400ab1fa1 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_op.cc @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/kernels/image/random_color_op.h" +#include "minddata/dataset/core/cv_tensor.h" +namespace mindspore { +namespace dataset { + +RandomColorOp::RandomColorOp(float t_lb, float t_ub) : rnd_(GetSeed()), dist_(t_lb, t_ub), t_lb_(t_lb), t_ub_(t_ub) {} + +Status RandomColorOp::Compute(const std::shared_ptr &in, std::shared_ptr *out) { + IO_CHECK(in, out); + if (in->Rank() != 3) { + RETURN_STATUS_UNEXPECTED("image must have 3 channels"); + } + // 0.5 pixel precision assuming an 8 bit image + const auto eps = 0.00195; + const auto t = dist_(rnd_); + if (abs(t - 1.0) < eps) { + // Just return input? Can we do it given that input would otherwise get consumed in CVTensor constructor anyway? + *out = in; + return Status::OK(); + } + auto cvt_in = CVTensor::AsCVTensor(in); + auto m1 = cvt_in->mat(); + cv::Mat gray; + // gray is allocated without using the allocator + cv::cvtColor(m1, gray, cv::COLOR_RGB2GRAY); + // luminosity is not preserved, consider using weights. + cv::Mat temp[3] = {gray, gray, gray}; + cv::Mat cv_out; + cv::merge(temp, 3, cv_out); + std::shared_ptr cvt_out; + CVTensor::CreateFromMat(cv_out, &cvt_out); + if (abs(t - 0.0) < eps) { + // return grayscale + *out = std::static_pointer_cast(cvt_out); + return Status::OK(); + } + // return blended image. addWeighted takes care of overflow for uint8_t + cv::addWeighted(m1, t, cvt_out->mat(), 1 - t, 0, cvt_out->mat()); + *out = std::static_pointer_cast(cvt_out); + return Status::OK(); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_op.h new file mode 100644 index 0000000000..37e4f02655 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_op.h @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RANDOM_COLOR_OP_H +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RANDOM_COLOR_OP_H + +#include +#include +#include +#include +#include +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/random.h" + +namespace mindspore { +namespace dataset { + +/// \class RandomColorOp random_color_op.h +/// \brief Blends an image with its grayscale version with random weights +/// t and 1 - t generated from a given range. +/// If the range is trivial then the weights are determinate and +/// t equals the bound of the interval +class RandomColorOp : public TensorOp { + public: + RandomColorOp() = default; + /// \brief Constructor + /// \param[in] t_lb lower bound for the random weights + /// \param[in] t_ub upper bound for the random weights + RandomColorOp(float t_lb, float t_ub); + /// \brief the main function performing computations + /// \param[in] in 2- or 3- dimensional tensor representing an image + /// \param[out] out 2- or 3- dimensional tensor representing an image + /// with the same dimensions as in + Status Compute(const std::shared_ptr &in, std::shared_ptr *out) override; + /// \brief returns the name of the op + std::string Name() const override { return kRandomColorOp; } + + private: + std::mt19937 rnd_; + std::uniform_real_distribution dist_; + float t_lb_; + float t_ub_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RANDOM_COLOR_OP_H diff --git a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h index 1047876e32..9d7068f7cc 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h @@ -129,6 +129,7 @@ constexpr char kSwapRedBlueOp[] = "SwapRedBlueOp"; constexpr char kUniformAugOp[] = "UniformAugOp"; constexpr char kSoftDvppDecodeRandomCropResizeJpegOp[] = "SoftDvppDecodeRandomCropResizeJpegOp"; constexpr char kSoftDvppDecodeReiszeJpegOp[] = "SoftDvppDecodeReiszeJpegOp"; +constexpr char kRandomColorOp[] = "RandomColorOp"; // text constexpr char kBasicTokenizerOp[] = "BasicTokenizerOp"; diff --git a/mindspore/dataset/transforms/vision/c_transforms.py b/mindspore/dataset/transforms/vision/c_transforms.py index d5a2cd23c8..d7f944adda 100644 --- a/mindspore/dataset/transforms/vision/c_transforms.py +++ b/mindspore/dataset/transforms/vision/c_transforms.py @@ -46,7 +46,8 @@ import mindspore._c_dataengine as cde from .utils import Inter, Border from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \ check_mix_up_batch_c, check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \ - check_range, check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, \ + check_range, check_resize, check_rescale, check_pad, check_cutout, \ + check_uniform_augment_cpp, \ check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \ check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER @@ -628,6 +629,21 @@ class CenterCrop(cde.CenterCropOp): super().__init__(*size) +class RandomColor(cde.RandomColorOp): + """ + Adjust the color of the input image by a fixed or random degree. + + Args: + degrees (sequence): Range of random color adjustment degrees. + It should be in (min, max) format. If min=max, then it is a + single fixed magnitude operation (default=(0.1,1.9)). + Works with 3-channel color images. + """ + + @check_positive_degrees + def __init__(self, degrees=(0.1, 1.9)): + super().__init__(*degrees) + class RandomColorAdjust(cde.RandomColorAdjustOp): """ Randomly adjust the brightness, contrast, saturation, and hue of the input image. diff --git a/mindspore/dataset/transforms/vision/validators.py b/mindspore/dataset/transforms/vision/validators.py index 2fc0e7991b..a4badd0b47 100644 --- a/mindspore/dataset/transforms/vision/validators.py +++ b/mindspore/dataset/transforms/vision/validators.py @@ -609,21 +609,23 @@ def check_uniform_augment_py(method): def check_positive_degrees(method): - """A wrapper method to check degrees parameter in RandSharpness and RandColor""" + """A wrapper method to check degrees parameter in RandomSharpness and RandomColor ops (python and cpp)""" @wraps(method) def new_method(self, *args, **kwargs): [degrees], _ = parse_user_args(method, *args, **kwargs) - if isinstance(degrees, (list, tuple)): + + if degrees is not None: + if not isinstance(degrees, (list, tuple)): + raise TypeError("degrees must be either a tuple or a list.") + type_check_list(degrees, (int, float), "degrees") if len(degrees) != 2: - raise ValueError("Degrees must be a sequence with length 2.") - for value in degrees: - check_value(value, (0., FLOAT_MAX_INTEGER)) - check_positive(degrees[0], "degrees[0]") + raise ValueError("degrees must be a sequence with length 2.") + for degree in degrees: + check_value(degree, (0, FLOAT_MAX_INTEGER)) if degrees[0] > degrees[1]: - raise ValueError("Degrees should be in (min,max) format. Got (max,min).") - else: - raise TypeError("Degrees should be a tuple or list.") + raise ValueError("degrees should be in (min,max) format. Got (max,min).") + return method(self, *args, **kwargs) return new_method @@ -698,4 +700,5 @@ def check_random_solarize(method): raise ValueError("threshold must be in min max format numbers") return method(self, *args, **kwargs) + return new_method diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt index 3324cd5ecd..a6b16370db 100644 --- a/tests/ut/cpp/dataset/CMakeLists.txt +++ b/tests/ut/cpp/dataset/CMakeLists.txt @@ -39,6 +39,7 @@ SET(DE_UT_SRCS project_op_test.cc queue_test.cc random_affine_op_test.cc + random_color_op_test.cc random_crop_op_test.cc random_crop_with_bbox_op_test.cc random_crop_decode_resize_op_test.cc diff --git a/tests/ut/cpp/dataset/c_api_transforms_test.cc b/tests/ut/cpp/dataset/c_api_transforms_test.cc index 8eae5ac453..50183641db 100644 --- a/tests/ut/cpp/dataset/c_api_transforms_test.cc +++ b/tests/ut/cpp/dataset/c_api_transforms_test.cc @@ -63,10 +63,10 @@ TEST_F(MindDataTestPipeline, TestCutOut) { uint64_t i = 0; while (row.size() != 0) { - i++; - auto image = row["image"]; - MS_LOG(INFO) << "Tensor image shape: " << image->shape(); - iter->GetNextRow(&row); + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); } EXPECT_EQ(i, 20); @@ -160,8 +160,9 @@ TEST_F(MindDataTestPipeline, TestHwcToChw) { auto image = row["image"]; MS_LOG(INFO) << "Tensor image shape: " << image->shape(); // check if the image is in NCHW - EXPECT_EQ(batch_size == image->shape()[0] && 3 == image->shape()[1] - && 2268 == image->shape()[2] && 4032 == image->shape()[3], true); + EXPECT_EQ(batch_size == image->shape()[0] && 3 == image->shape()[1] && 2268 == image->shape()[2] && + 4032 == image->shape()[3], + true); iter->GetNextRow(&row); } EXPECT_EQ(i, 20); @@ -186,7 +187,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchFail1) { EXPECT_NE(one_hot_op, nullptr); // Create a Map operation on ds - ds = ds->Map({one_hot_op},{"label"}); + ds = ds->Map({one_hot_op}, {"label"}); EXPECT_NE(ds, nullptr); std::shared_ptr mixup_batch_op = vision::MixUpBatch(-1); @@ -209,7 +210,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess1) { EXPECT_NE(one_hot_op, nullptr); // Create a Map operation on ds - ds = ds->Map({one_hot_op},{"label"}); + ds = ds->Map({one_hot_op}, {"label"}); EXPECT_NE(ds, nullptr); std::shared_ptr mixup_batch_op = vision::MixUpBatch(0.5); @@ -258,7 +259,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess2) { EXPECT_NE(one_hot_op, nullptr); // Create a Map operation on ds - ds = ds->Map({one_hot_op},{"label"}); + ds = ds->Map({one_hot_op}, {"label"}); EXPECT_NE(ds, nullptr); std::shared_ptr mixup_batch_op = vision::MixUpBatch(); @@ -379,10 +380,10 @@ TEST_F(MindDataTestPipeline, TestPad) { uint64_t i = 0; while (row.size() != 0) { - i++; - auto image = row["image"]; - MS_LOG(INFO) << "Tensor image shape: " << image->shape(); - iter->GetNextRow(&row); + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); } EXPECT_EQ(i, 20); @@ -504,6 +505,61 @@ TEST_F(MindDataTestPipeline, TestRandomAffineSuccess2) { iter->Stop(); } +TEST_F(MindDataTestPipeline, TestRandomColor) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomColor with non-default params."; + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create a Repeat operation on ds + int32_t repeat_num = 2; + ds = ds->Repeat(repeat_num); + EXPECT_NE(ds, nullptr); + + // Create objects for the tensor ops + std::shared_ptr random_color_op_1 = vision::RandomColor(0.0, 0.0); + EXPECT_NE(random_color_op_1, nullptr); + + std::shared_ptr random_color_op_2 = vision::RandomColor(1.0, 0.1); + EXPECT_EQ(random_color_op_2, nullptr); + + std::shared_ptr random_color_op_3 = vision::RandomColor(0.0, 1.1); + EXPECT_NE(random_color_op_3, nullptr); + + // Create a Map operation on ds + ds = ds->Map({random_color_op_1, random_color_op_3}); + EXPECT_NE(ds, nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 1; + ds = ds->Batch(batch_size); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_EQ(i, 20); + + // Manually terminate the pipeline + iter->Stop(); +} + TEST_F(MindDataTestPipeline, TestRandomColorAdjust) { // Create an ImageFolder Dataset std::string folder_path = datasets_root_path_ + "/testPK/data/"; @@ -780,7 +836,8 @@ TEST_F(MindDataTestPipeline, TestRandomSolarize) { EXPECT_NE(ds, nullptr); // Create objects for the tensor ops - std::shared_ptr random_solarize = mindspore::dataset::api::vision::RandomSolarize(23, 23); //vision::RandomSolarize(); + std::shared_ptr random_solarize = + mindspore::dataset::api::vision::RandomSolarize(23, 23); // vision::RandomSolarize(); EXPECT_NE(random_solarize, nullptr); // Create a Map operation on ds diff --git a/tests/ut/cpp/dataset/random_color_op_test.cc b/tests/ut/cpp/dataset/random_color_op_test.cc new file mode 100644 index 0000000000..144174a49d --- /dev/null +++ b/tests/ut/cpp/dataset/random_color_op_test.cc @@ -0,0 +1,99 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/common.h" +#include "common/cvop_common.h" +#include "minddata/dataset/kernels/image/random_color_op.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "utils/log_adapter.h" + +using namespace mindspore::dataset; +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::INFO; + +class MindDataTestRandomColorOp : public UT::CVOP::CVOpCommon { + public: + MindDataTestRandomColorOp() : CVOpCommon(), shape({3, 3, 3}) { + std::shared_ptr in; + std::shared_ptr gray; + + (void)Tensor::CreateEmpty(shape, DataType(DataType::DE_UINT8), &in); + (void)Tensor::CreateEmpty(shape, DataType(DataType::DE_UINT8), &input_tensor); + Status s = in->Fill(42); + s = input_tensor->Fill(42); + cvt_in = CVTensor::AsCVTensor(in); + cv::Mat m2; + auto m1 = cvt_in->mat(); + cv::cvtColor(m1, m2, cv::COLOR_RGB2GRAY); + cv::Mat temp[3] = {m2 , m2 , m2 }; + cv::Mat cv_out; + cv::merge(temp, 3, cv_out); + std::shared_ptr cvt_out; + CVTensor::CreateFromMat(cv_out, &cvt_out); + gray_tensor = std::static_pointer_cast(cvt_out); + } + TensorShape shape; + std::shared_ptr input_tensor; + std::shared_ptr cvt_in; + std::shared_ptr gray_tensor; +}; + +int64_t Compare(std::shared_ptr t1, std::shared_ptr t2) { + auto shape = t1->shape(); + int64_t sum = 0; + for (auto i = 0; i < shape[0]; i++) { + for (auto j = 0; j < shape[1]; j++) { + for (auto k = 0; k < shape[2]; k++) { + uint8_t value1; + uint8_t value2; + (void)t1->GetItemAt(&value1, {i, j, k}); + (void)t2->GetItemAt(&value2, {i, j, k}); + sum += abs(static_cast(value1) - static_cast(value2)); + } + } + } + return sum; +} + +// these tests are tautological, write better tests when the requirements for the output are determined +// e. g. how do we want to convert to gray and what does it mean to blend with a gray image (pre- post- gamma corrected, +// what weights). +TEST_F(MindDataTestRandomColorOp, TestOp1) { + std::shared_ptr output_tensor; + auto op = RandomColorOp(1, 1); + auto s = op.Compute(input_tensor, &output_tensor); + auto res = Compare(input_tensor, output_tensor); + EXPECT_EQ(0, res); +} + +TEST_F(MindDataTestRandomColorOp, TestOp2) { + std::shared_ptr output_tensor; + auto op = RandomColorOp(0, 0); + auto s = op.Compute(input_tensor, &output_tensor); + EXPECT_TRUE(s.IsOk()); + auto res = Compare(output_tensor, gray_tensor); + EXPECT_EQ(res, 0); +} + +TEST_F(MindDataTestRandomColorOp, TestOp3) { + std::shared_ptr output_tensor; + auto op = RandomColorOp(0.0, 1.0); + for (auto i = 0; i < 1; i++) { + auto s = op.Compute(input_tensor, &output_tensor); + EXPECT_TRUE(s.IsOk()); + } +} \ No newline at end of file diff --git a/tests/ut/data/dataset/golden/random_color_op_02_result.npz b/tests/ut/data/dataset/golden/random_color_op_02_result.npz new file mode 100644 index 0000000000..aaaeb83449 Binary files /dev/null and b/tests/ut/data/dataset/golden/random_color_op_02_result.npz differ diff --git a/tests/ut/python/dataset/test_random_color.py b/tests/ut/python/dataset/test_random_color.py index 0015e8498f..9b8be91630 100644 --- a/tests/ut/python/dataset/test_random_color.py +++ b/tests/ut/python/dataset/test_random_color.py @@ -16,9 +16,11 @@ Testing RandomColor op in DE """ import numpy as np +import pytest import mindspore.dataset as ds import mindspore.dataset.engine as de +import mindspore.dataset.transforms.vision.c_transforms as vision import mindspore.dataset.transforms.vision.py_transforms as F from mindspore import log as logger from util import visualize_list, diff_mse, save_and_check_md5, \ @@ -26,11 +28,17 @@ from util import visualize_list, diff_mse, save_and_check_md5, \ DATA_DIR = "../data/dataset/testImageNetData/train/" +C_DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] +C_SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" + +MNIST_DATA_DIR = "../data/dataset/testMnistData" + GENERATE_GOLDEN = False -def test_random_color(degrees=(0.1, 1.9), plot=False): + +def test_random_color_py(degrees=(0.1, 1.9), plot=False): """ - Test RandomColor + Test Python RandomColor """ logger.info("Test RandomColor") @@ -85,9 +93,53 @@ def test_random_color(degrees=(0.1, 1.9), plot=False): visualize_list(images_original, images_random_color) -def test_random_color_md5(): +def test_random_color_c(degrees=(0.1, 1.9), plot=False, run_golden=True): """ - Test RandomColor with md5 check + Test Cpp RandomColor + """ + logger.info("test_random_color_op") + + original_seed = config_get_set_seed(10) + original_num_parallel_workers = config_get_set_num_parallel_workers(1) + + # Decode with rgb format set to True + data1 = ds.TFRecordDataset(C_DATA_DIR, C_SCHEMA_DIR, columns_list=["image"], shuffle=False) + data2 = ds.TFRecordDataset(C_DATA_DIR, C_SCHEMA_DIR, columns_list=["image"], shuffle=False) + + # Serialize and Load dataset requires using vision.Decode instead of vision.Decode(). + if degrees is None: + c_op = vision.RandomColor() + else: + c_op = vision.RandomColor(degrees) + + data1 = data1.map(input_columns=["image"], operations=[vision.Decode()]) + data2 = data2.map(input_columns=["image"], operations=[vision.Decode(), c_op]) + + image_random_color_op = [] + image = [] + + for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): + actual = item1["image"] + expected = item2["image"] + image.append(actual) + image_random_color_op.append(expected) + + if run_golden: + # Compare with expected md5 from images + filename = "random_color_op_02_result.npz" + save_and_check_md5(data2, filename, generate_golden=GENERATE_GOLDEN) + + if plot: + visualize_list(image, image_random_color_op) + + # Restore configuration + ds.config.set_seed(original_seed) + ds.config.set_num_parallel_workers((original_num_parallel_workers)) + + +def test_random_color_py_md5(): + """ + Test Python RandomColor with md5 check """ logger.info("Test RandomColor with md5 check") original_seed = config_get_set_seed(10) @@ -110,8 +162,94 @@ def test_random_color_md5(): ds.config.set_num_parallel_workers((original_num_parallel_workers)) +def test_compare_random_color_op(degrees=None, plot=False): + """ + Compare Random Color op in Python and Cpp + """ + + logger.info("test_random_color_op") + + original_seed = config_get_set_seed(5) + original_num_parallel_workers = config_get_set_num_parallel_workers(1) + + # Decode with rgb format set to True + data1 = ds.TFRecordDataset(C_DATA_DIR, C_SCHEMA_DIR, columns_list=["image"], shuffle=False) + data2 = ds.TFRecordDataset(C_DATA_DIR, C_SCHEMA_DIR, columns_list=["image"], shuffle=False) + + if degrees is None: + c_op = vision.RandomColor() + p_op = F.RandomColor() + else: + c_op = vision.RandomColor(degrees) + p_op = F.RandomColor(degrees) + + transforms_random_color_py = F.ComposeOp([lambda img: img.astype(np.uint8), F.ToPIL(), + p_op, np.array]) + + data1 = data1.map(input_columns=["image"], operations=[vision.Decode(), c_op]) + data2 = data2.map(input_columns=["image"], operations=[vision.Decode()]) + data2 = data2.map(input_columns=["image"], operations=transforms_random_color_py()) + + image_random_color_op = [] + image = [] + + for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): + actual = item1["image"] + expected = item2["image"] + image_random_color_op.append(actual) + image.append(expected) + assert actual.shape == expected.shape + mse = diff_mse(actual, expected) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + # Restore configuration + ds.config.set_seed(original_seed) + ds.config.set_num_parallel_workers(original_num_parallel_workers) + + if plot: + visualize_list(image, image_random_color_op) + + +def test_random_color_c_errors(): + """ + Test that Cpp RandomColor errors with bad input + """ + with pytest.raises(TypeError) as error_info: + vision.RandomColor((12)) + assert "degrees must be either a tuple or a list." in str(error_info.value) + + with pytest.raises(TypeError) as error_info: + vision.RandomColor(("col", 3)) + assert "Argument degrees[0] with value col is not of type (, )." in str( + error_info.value) + + with pytest.raises(ValueError) as error_info: + vision.RandomColor((0.9, 0.1)) + assert "degrees should be in (min,max) format. Got (max,min)." in str(error_info.value) + + with pytest.raises(ValueError) as error_info: + vision.RandomColor((0.9,)) + assert "degrees must be a sequence with length 2." in str(error_info.value) + + # RandomColor Cpp Op will fail with one channel input + mnist_ds = de.MnistDataset(dataset_dir=MNIST_DATA_DIR, num_samples=2, shuffle=False) + mnist_ds = mnist_ds.map(input_columns="image", operations=vision.RandomColor()) + + with pytest.raises(RuntimeError) as error_info: + for _ in enumerate(mnist_ds): + pass + assert "Invalid number of channels in input image" in str(error_info.value) + + if __name__ == "__main__": - test_random_color() - test_random_color(plot=True) - test_random_color(degrees=(0.5, 1.5), plot=True) - test_random_color_md5() + test_random_color_py() + test_random_color_py(plot=True) + test_random_color_py(degrees=(0.5, 1.5), plot=True) + test_random_color_py_md5() + + test_random_color_c() + test_random_color_c(plot=True) + test_random_color_c(degrees=(0.5, 1.5), plot=True, run_golden=False) + test_random_color_c(degrees=(0.1, 0.1), plot=True, run_golden=False) + test_compare_random_color_op(plot=True) + test_random_color_c_errors()