Merge pull request !3769 from Alexey_Shevlyakov/ColorOptags/v0.7.0-beta
| @@ -32,6 +32,7 @@ | |||||
| #include "minddata/dataset/kernels/image/normalize_op.h" | #include "minddata/dataset/kernels/image/normalize_op.h" | ||||
| #include "minddata/dataset/kernels/image/pad_op.h" | #include "minddata/dataset/kernels/image/pad_op.h" | ||||
| #include "minddata/dataset/kernels/image/random_affine_op.h" | #include "minddata/dataset/kernels/image/random_affine_op.h" | ||||
| #include "minddata/dataset/kernels/image/random_color_op.h" | |||||
| #include "minddata/dataset/kernels/image/random_color_adjust_op.h" | #include "minddata/dataset/kernels/image/random_color_adjust_op.h" | ||||
| #include "minddata/dataset/kernels/image/random_crop_and_resize_op.h" | #include "minddata/dataset/kernels/image/random_crop_and_resize_op.h" | ||||
| #include "minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h" | #include "minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h" | ||||
| @@ -273,6 +274,14 @@ PYBIND_REGISTER( | |||||
| py::arg("targetWidth") = RandomResizeOp::kDefTargetWidth); | py::arg("targetWidth") = RandomResizeOp::kDefTargetWidth); | ||||
| })); | })); | ||||
| PYBIND_REGISTER(RandomColorOp, 1, ([](const py::module *m) { | |||||
| (void)py::class_<RandomColorOp, TensorOp, std::shared_ptr<RandomColorOp>>( | |||||
| *m, "RandomColorOp", | |||||
| "Tensor operation to blend an image with its grayscale version with random weights" | |||||
| "Takes min and max for the range of random weights") | |||||
| .def(py::init<float, float>(), py::arg("min"), py::arg("max")); | |||||
| })); | |||||
| PYBIND_REGISTER(RandomColorAdjustOp, 1, ([](const py::module *m) { | PYBIND_REGISTER(RandomColorAdjustOp, 1, ([](const py::module *m) { | ||||
| (void)py::class_<RandomColorAdjustOp, TensorOp, std::shared_ptr<RandomColorAdjustOp>>( | (void)py::class_<RandomColorAdjustOp, TensorOp, std::shared_ptr<RandomColorAdjustOp>>( | ||||
| *m, "RandomColorAdjustOp", | *m, "RandomColorAdjustOp", | ||||
| @@ -27,6 +27,7 @@ | |||||
| #include "minddata/dataset/kernels/data/one_hot_op.h" | #include "minddata/dataset/kernels/data/one_hot_op.h" | ||||
| #include "minddata/dataset/kernels/image/pad_op.h" | #include "minddata/dataset/kernels/image/pad_op.h" | ||||
| #include "minddata/dataset/kernels/image/random_affine_op.h" | #include "minddata/dataset/kernels/image/random_affine_op.h" | ||||
| #include "minddata/dataset/kernels/image/random_color_op.h" | |||||
| #include "minddata/dataset/kernels/image/random_color_adjust_op.h" | #include "minddata/dataset/kernels/image/random_color_adjust_op.h" | ||||
| #include "minddata/dataset/kernels/image/random_crop_op.h" | #include "minddata/dataset/kernels/image/random_crop_op.h" | ||||
| #include "minddata/dataset/kernels/image/random_horizontal_flip_op.h" | #include "minddata/dataset/kernels/image/random_horizontal_flip_op.h" | ||||
| @@ -140,6 +141,21 @@ std::shared_ptr<PadOperation> Pad(std::vector<int32_t> padding, std::vector<uint | |||||
| return op; | return op; | ||||
| } | } | ||||
| // Function to create RandomColorOperation. | |||||
| std::shared_ptr<RandomColorOperation> RandomColor(float t_lb, float t_ub) { | |||||
| auto op = std::make_shared<RandomColorOperation>(t_lb, t_ub); | |||||
| // Input validation | |||||
| if (!op->ValidateParams()) { | |||||
| return nullptr; | |||||
| } | |||||
| return op; | |||||
| } | |||||
| std::shared_ptr<TensorOp> RandomColorOperation::Build() { | |||||
| std::shared_ptr<RandomColorOp> tensor_op = std::make_shared<RandomColorOp>(t_lb_, t_ub_); | |||||
| return tensor_op; | |||||
| } | |||||
| // Function to create RandomColorAdjustOperation. | // Function to create RandomColorAdjustOperation. | ||||
| std::shared_ptr<RandomColorAdjustOperation> RandomColorAdjust(std::vector<float> brightness, | std::shared_ptr<RandomColorAdjustOperation> RandomColorAdjust(std::vector<float> brightness, | ||||
| std::vector<float> contrast, | std::vector<float> contrast, | ||||
| @@ -475,6 +491,18 @@ std::shared_ptr<TensorOp> PadOperation::Build() { | |||||
| return tensor_op; | return tensor_op; | ||||
| } | } | ||||
| // RandomColorOperation. | |||||
| RandomColorOperation::RandomColorOperation(float t_lb, float t_ub) : t_lb_(t_lb), t_ub_(t_ub) {} | |||||
| bool RandomColorOperation::ValidateParams() { | |||||
| // Do some input validation. | |||||
| if (t_lb_ > t_ub_) { | |||||
| MS_LOG(ERROR) << "RandomColor: lower bound must be less or equal to upper bound"; | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| // RandomColorAdjustOperation. | // RandomColorAdjustOperation. | ||||
| RandomColorAdjustOperation::RandomColorAdjustOperation(std::vector<float> brightness, std::vector<float> contrast, | RandomColorAdjustOperation::RandomColorAdjustOperation(std::vector<float> brightness, std::vector<float> contrast, | ||||
| std::vector<float> saturation, std::vector<float> hue) | std::vector<float> saturation, std::vector<float> hue) | ||||
| @@ -70,7 +70,7 @@ class CVTensor : public Tensor { | |||||
| /// Get a reference to the CV::Mat | /// Get a reference to the CV::Mat | ||||
| /// \return a reference to the internal CV::Mat | /// \return a reference to the internal CV::Mat | ||||
| cv::Mat mat() const { return mat_; } | |||||
| cv::Mat &mat() { return mat_; } | |||||
| /// Get a copy of the CV::Mat | /// Get a copy of the CV::Mat | ||||
| /// \return a copy of internal CV::Mat | /// \return a copy of internal CV::Mat | ||||
| @@ -57,6 +57,7 @@ class NormalizeOperation; | |||||
| class OneHotOperation; | class OneHotOperation; | ||||
| class PadOperation; | class PadOperation; | ||||
| class RandomAffineOperation; | class RandomAffineOperation; | ||||
| class RandomColorOperation; | |||||
| class RandomColorAdjustOperation; | class RandomColorAdjustOperation; | ||||
| class RandomCropOperation; | class RandomCropOperation; | ||||
| class RandomHorizontalFlipOperation; | class RandomHorizontalFlipOperation; | ||||
| @@ -162,6 +163,14 @@ std::shared_ptr<RandomAffineOperation> RandomAffine( | |||||
| InterpolationMode interpolation = InterpolationMode::kNearestNeighbour, | InterpolationMode interpolation = InterpolationMode::kNearestNeighbour, | ||||
| const std::vector<uint8_t> &fill_value = {0, 0, 0}); | const std::vector<uint8_t> &fill_value = {0, 0, 0}); | ||||
| /// \brief Blends an image with its grayscale version with random weights | |||||
| /// t and 1 - t generated from a given range. If the range is trivial | |||||
| /// then the weights are determinate and t equals the bound of the interval | |||||
| /// \param[in] t_lb lower bound on the range of random weights | |||||
| /// \param[in] t_lb upper bound on the range of random weights | |||||
| /// \return Shared pointer to the current TensorOp | |||||
| std::shared_ptr<RandomColorOperation> RandomColor(float t_lb, float t_ub); | |||||
| /// \brief Randomly adjust the brightness, contrast, saturation, and hue of the input image | /// \brief Randomly adjust the brightness, contrast, saturation, and hue of the input image | ||||
| /// \param[in] brightness Brightness adjustment factor. Must be a vector of one or two values | /// \param[in] brightness Brightness adjustment factor. Must be a vector of one or two values | ||||
| /// if it's a vector of two values it needs to be in the form of [min, max]. Default value is {1, 1} | /// if it's a vector of two values it needs to be in the form of [min, max]. Default value is {1, 1} | ||||
| @@ -417,6 +426,21 @@ class RandomAffineOperation : public TensorOperation { | |||||
| std::vector<uint8_t> fill_value_; | std::vector<uint8_t> fill_value_; | ||||
| }; | }; | ||||
| class RandomColorOperation : public TensorOperation { | |||||
| public: | |||||
| RandomColorOperation(float t_lb, float t_ub); | |||||
| ~RandomColorOperation() = default; | |||||
| std::shared_ptr<TensorOp> Build() override; | |||||
| bool ValidateParams() override; | |||||
| private: | |||||
| float t_lb_; | |||||
| float t_ub_; | |||||
| }; | |||||
| class RandomColorAdjustOperation : public TensorOperation { | class RandomColorAdjustOperation : public TensorOperation { | ||||
| public: | public: | ||||
| RandomColorAdjustOperation(std::vector<float> brightness = {1.0, 1.0}, std::vector<float> contrast = {1.0, 1.0}, | RandomColorAdjustOperation(std::vector<float> brightness = {1.0, 1.0}, std::vector<float> contrast = {1.0, 1.0}, | ||||
| @@ -44,5 +44,6 @@ add_library(kernels-image OBJECT | |||||
| uniform_aug_op.cc | uniform_aug_op.cc | ||||
| resize_with_bbox_op.cc | resize_with_bbox_op.cc | ||||
| random_resize_with_bbox_op.cc | random_resize_with_bbox_op.cc | ||||
| random_color_op.cc | |||||
| ) | ) | ||||
| add_dependencies(kernels-image kernels-soft-dvpp-image) | add_dependencies(kernels-image kernels-soft-dvpp-image) | ||||
| @@ -0,0 +1,60 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/kernels/image/random_color_op.h" | |||||
| #include "minddata/dataset/core/cv_tensor.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| RandomColorOp::RandomColorOp(float t_lb, float t_ub) : rnd_(GetSeed()), dist_(t_lb, t_ub), t_lb_(t_lb), t_ub_(t_ub) {} | |||||
| Status RandomColorOp::Compute(const std::shared_ptr<Tensor> &in, std::shared_ptr<Tensor> *out) { | |||||
| IO_CHECK(in, out); | |||||
| if (in->Rank() != 3) { | |||||
| RETURN_STATUS_UNEXPECTED("image must have 3 channels"); | |||||
| } | |||||
| // 0.5 pixel precision assuming an 8 bit image | |||||
| const auto eps = 0.00195; | |||||
| const auto t = dist_(rnd_); | |||||
| if (abs(t - 1.0) < eps) { | |||||
| // Just return input? Can we do it given that input would otherwise get consumed in CVTensor constructor anyway? | |||||
| *out = in; | |||||
| return Status::OK(); | |||||
| } | |||||
| auto cvt_in = CVTensor::AsCVTensor(in); | |||||
| auto m1 = cvt_in->mat(); | |||||
| cv::Mat gray; | |||||
| // gray is allocated without using the allocator | |||||
| cv::cvtColor(m1, gray, cv::COLOR_RGB2GRAY); | |||||
| // luminosity is not preserved, consider using weights. | |||||
| cv::Mat temp[3] = {gray, gray, gray}; | |||||
| cv::Mat cv_out; | |||||
| cv::merge(temp, 3, cv_out); | |||||
| std::shared_ptr<CVTensor> cvt_out; | |||||
| CVTensor::CreateFromMat(cv_out, &cvt_out); | |||||
| if (abs(t - 0.0) < eps) { | |||||
| // return grayscale | |||||
| *out = std::static_pointer_cast<Tensor>(cvt_out); | |||||
| return Status::OK(); | |||||
| } | |||||
| // return blended image. addWeighted takes care of overflow for uint8_t | |||||
| cv::addWeighted(m1, t, cvt_out->mat(), 1 - t, 0, cvt_out->mat()); | |||||
| *out = std::static_pointer_cast<Tensor>(cvt_out); | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,62 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RANDOM_COLOR_OP_H | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RANDOM_COLOR_OP_H | |||||
| #include <memory> | |||||
| #include <random> | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <opencv2/imgproc/imgproc.hpp> | |||||
| #include "minddata/dataset/core/tensor.h" | |||||
| #include "minddata/dataset/core/cv_tensor.h" | |||||
| #include "minddata/dataset/kernels/tensor_op.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| #include "minddata/dataset/util/random.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| /// \class RandomColorOp random_color_op.h | |||||
| /// \brief Blends an image with its grayscale version with random weights | |||||
| /// t and 1 - t generated from a given range. | |||||
| /// If the range is trivial then the weights are determinate and | |||||
| /// t equals the bound of the interval | |||||
| class RandomColorOp : public TensorOp { | |||||
| public: | |||||
| RandomColorOp() = default; | |||||
| /// \brief Constructor | |||||
| /// \param[in] t_lb lower bound for the random weights | |||||
| /// \param[in] t_ub upper bound for the random weights | |||||
| RandomColorOp(float t_lb, float t_ub); | |||||
| /// \brief the main function performing computations | |||||
| /// \param[in] in 2- or 3- dimensional tensor representing an image | |||||
| /// \param[out] out 2- or 3- dimensional tensor representing an image | |||||
| /// with the same dimensions as in | |||||
| Status Compute(const std::shared_ptr<Tensor> &in, std::shared_ptr<Tensor> *out) override; | |||||
| /// \brief returns the name of the op | |||||
| std::string Name() const override { return kRandomColorOp; } | |||||
| private: | |||||
| std::mt19937 rnd_; | |||||
| std::uniform_real_distribution<float> dist_; | |||||
| float t_lb_; | |||||
| float t_ub_; | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RANDOM_COLOR_OP_H | |||||
| @@ -129,6 +129,7 @@ constexpr char kSwapRedBlueOp[] = "SwapRedBlueOp"; | |||||
| constexpr char kUniformAugOp[] = "UniformAugOp"; | constexpr char kUniformAugOp[] = "UniformAugOp"; | ||||
| constexpr char kSoftDvppDecodeRandomCropResizeJpegOp[] = "SoftDvppDecodeRandomCropResizeJpegOp"; | constexpr char kSoftDvppDecodeRandomCropResizeJpegOp[] = "SoftDvppDecodeRandomCropResizeJpegOp"; | ||||
| constexpr char kSoftDvppDecodeReiszeJpegOp[] = "SoftDvppDecodeReiszeJpegOp"; | constexpr char kSoftDvppDecodeReiszeJpegOp[] = "SoftDvppDecodeReiszeJpegOp"; | ||||
| constexpr char kRandomColorOp[] = "RandomColorOp"; | |||||
| // text | // text | ||||
| constexpr char kBasicTokenizerOp[] = "BasicTokenizerOp"; | constexpr char kBasicTokenizerOp[] = "BasicTokenizerOp"; | ||||
| @@ -46,7 +46,8 @@ import mindspore._c_dataengine as cde | |||||
| from .utils import Inter, Border | from .utils import Inter, Border | ||||
| from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \ | from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \ | ||||
| check_mix_up_batch_c, check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \ | check_mix_up_batch_c, check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \ | ||||
| check_range, check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, \ | |||||
| check_range, check_resize, check_rescale, check_pad, check_cutout, \ | |||||
| check_uniform_augment_cpp, \ | |||||
| check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \ | check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \ | ||||
| check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER | check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER | ||||
| @@ -628,6 +629,21 @@ class CenterCrop(cde.CenterCropOp): | |||||
| super().__init__(*size) | super().__init__(*size) | ||||
| class RandomColor(cde.RandomColorOp): | |||||
| """ | |||||
| Adjust the color of the input image by a fixed or random degree. | |||||
| Args: | |||||
| degrees (sequence): Range of random color adjustment degrees. | |||||
| It should be in (min, max) format. If min=max, then it is a | |||||
| single fixed magnitude operation (default=(0.1,1.9)). | |||||
| Works with 3-channel color images. | |||||
| """ | |||||
| @check_positive_degrees | |||||
| def __init__(self, degrees=(0.1, 1.9)): | |||||
| super().__init__(*degrees) | |||||
| class RandomColorAdjust(cde.RandomColorAdjustOp): | class RandomColorAdjust(cde.RandomColorAdjustOp): | ||||
| """ | """ | ||||
| Randomly adjust the brightness, contrast, saturation, and hue of the input image. | Randomly adjust the brightness, contrast, saturation, and hue of the input image. | ||||
| @@ -609,21 +609,23 @@ def check_uniform_augment_py(method): | |||||
| def check_positive_degrees(method): | def check_positive_degrees(method): | ||||
| """A wrapper method to check degrees parameter in RandSharpness and RandColor""" | |||||
| """A wrapper method to check degrees parameter in RandomSharpness and RandomColor ops (python and cpp)""" | |||||
| @wraps(method) | @wraps(method) | ||||
| def new_method(self, *args, **kwargs): | def new_method(self, *args, **kwargs): | ||||
| [degrees], _ = parse_user_args(method, *args, **kwargs) | [degrees], _ = parse_user_args(method, *args, **kwargs) | ||||
| if isinstance(degrees, (list, tuple)): | |||||
| if degrees is not None: | |||||
| if not isinstance(degrees, (list, tuple)): | |||||
| raise TypeError("degrees must be either a tuple or a list.") | |||||
| type_check_list(degrees, (int, float), "degrees") | |||||
| if len(degrees) != 2: | if len(degrees) != 2: | ||||
| raise ValueError("Degrees must be a sequence with length 2.") | |||||
| for value in degrees: | |||||
| check_value(value, (0., FLOAT_MAX_INTEGER)) | |||||
| check_positive(degrees[0], "degrees[0]") | |||||
| raise ValueError("degrees must be a sequence with length 2.") | |||||
| for degree in degrees: | |||||
| check_value(degree, (0, FLOAT_MAX_INTEGER)) | |||||
| if degrees[0] > degrees[1]: | if degrees[0] > degrees[1]: | ||||
| raise ValueError("Degrees should be in (min,max) format. Got (max,min).") | |||||
| else: | |||||
| raise TypeError("Degrees should be a tuple or list.") | |||||
| raise ValueError("degrees should be in (min,max) format. Got (max,min).") | |||||
| return method(self, *args, **kwargs) | return method(self, *args, **kwargs) | ||||
| return new_method | return new_method | ||||
| @@ -698,4 +700,5 @@ def check_random_solarize(method): | |||||
| raise ValueError("threshold must be in min max format numbers") | raise ValueError("threshold must be in min max format numbers") | ||||
| return method(self, *args, **kwargs) | return method(self, *args, **kwargs) | ||||
| return new_method | return new_method | ||||
| @@ -39,6 +39,7 @@ SET(DE_UT_SRCS | |||||
| project_op_test.cc | project_op_test.cc | ||||
| queue_test.cc | queue_test.cc | ||||
| random_affine_op_test.cc | random_affine_op_test.cc | ||||
| random_color_op_test.cc | |||||
| random_crop_op_test.cc | random_crop_op_test.cc | ||||
| random_crop_with_bbox_op_test.cc | random_crop_with_bbox_op_test.cc | ||||
| random_crop_decode_resize_op_test.cc | random_crop_decode_resize_op_test.cc | ||||
| @@ -63,10 +63,10 @@ TEST_F(MindDataTestPipeline, TestCutOut) { | |||||
| uint64_t i = 0; | uint64_t i = 0; | ||||
| while (row.size() != 0) { | while (row.size() != 0) { | ||||
| i++; | |||||
| auto image = row["image"]; | |||||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | |||||
| iter->GetNextRow(&row); | |||||
| i++; | |||||
| auto image = row["image"]; | |||||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | |||||
| iter->GetNextRow(&row); | |||||
| } | } | ||||
| EXPECT_EQ(i, 20); | EXPECT_EQ(i, 20); | ||||
| @@ -160,8 +160,9 @@ TEST_F(MindDataTestPipeline, TestHwcToChw) { | |||||
| auto image = row["image"]; | auto image = row["image"]; | ||||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | ||||
| // check if the image is in NCHW | // check if the image is in NCHW | ||||
| EXPECT_EQ(batch_size == image->shape()[0] && 3 == image->shape()[1] | |||||
| && 2268 == image->shape()[2] && 4032 == image->shape()[3], true); | |||||
| EXPECT_EQ(batch_size == image->shape()[0] && 3 == image->shape()[1] && 2268 == image->shape()[2] && | |||||
| 4032 == image->shape()[3], | |||||
| true); | |||||
| iter->GetNextRow(&row); | iter->GetNextRow(&row); | ||||
| } | } | ||||
| EXPECT_EQ(i, 20); | EXPECT_EQ(i, 20); | ||||
| @@ -186,7 +187,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchFail1) { | |||||
| EXPECT_NE(one_hot_op, nullptr); | EXPECT_NE(one_hot_op, nullptr); | ||||
| // Create a Map operation on ds | // Create a Map operation on ds | ||||
| ds = ds->Map({one_hot_op},{"label"}); | |||||
| ds = ds->Map({one_hot_op}, {"label"}); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch(-1); | std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch(-1); | ||||
| @@ -209,7 +210,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess1) { | |||||
| EXPECT_NE(one_hot_op, nullptr); | EXPECT_NE(one_hot_op, nullptr); | ||||
| // Create a Map operation on ds | // Create a Map operation on ds | ||||
| ds = ds->Map({one_hot_op},{"label"}); | |||||
| ds = ds->Map({one_hot_op}, {"label"}); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch(0.5); | std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch(0.5); | ||||
| @@ -258,7 +259,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess2) { | |||||
| EXPECT_NE(one_hot_op, nullptr); | EXPECT_NE(one_hot_op, nullptr); | ||||
| // Create a Map operation on ds | // Create a Map operation on ds | ||||
| ds = ds->Map({one_hot_op},{"label"}); | |||||
| ds = ds->Map({one_hot_op}, {"label"}); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch(); | std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch(); | ||||
| @@ -379,10 +380,10 @@ TEST_F(MindDataTestPipeline, TestPad) { | |||||
| uint64_t i = 0; | uint64_t i = 0; | ||||
| while (row.size() != 0) { | while (row.size() != 0) { | ||||
| i++; | |||||
| auto image = row["image"]; | |||||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | |||||
| iter->GetNextRow(&row); | |||||
| i++; | |||||
| auto image = row["image"]; | |||||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | |||||
| iter->GetNextRow(&row); | |||||
| } | } | ||||
| EXPECT_EQ(i, 20); | EXPECT_EQ(i, 20); | ||||
| @@ -504,6 +505,61 @@ TEST_F(MindDataTestPipeline, TestRandomAffineSuccess2) { | |||||
| iter->Stop(); | iter->Stop(); | ||||
| } | } | ||||
| TEST_F(MindDataTestPipeline, TestRandomColor) { | |||||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomColor with non-default params."; | |||||
| // Create an ImageFolder Dataset | |||||
| std::string folder_path = datasets_root_path_ + "/testPK/data/"; | |||||
| std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); | |||||
| EXPECT_NE(ds, nullptr); | |||||
| // Create a Repeat operation on ds | |||||
| int32_t repeat_num = 2; | |||||
| ds = ds->Repeat(repeat_num); | |||||
| EXPECT_NE(ds, nullptr); | |||||
| // Create objects for the tensor ops | |||||
| std::shared_ptr<TensorOperation> random_color_op_1 = vision::RandomColor(0.0, 0.0); | |||||
| EXPECT_NE(random_color_op_1, nullptr); | |||||
| std::shared_ptr<TensorOperation> random_color_op_2 = vision::RandomColor(1.0, 0.1); | |||||
| EXPECT_EQ(random_color_op_2, nullptr); | |||||
| std::shared_ptr<TensorOperation> random_color_op_3 = vision::RandomColor(0.0, 1.1); | |||||
| EXPECT_NE(random_color_op_3, nullptr); | |||||
| // Create a Map operation on ds | |||||
| ds = ds->Map({random_color_op_1, random_color_op_3}); | |||||
| EXPECT_NE(ds, nullptr); | |||||
| // Create a Batch operation on ds | |||||
| int32_t batch_size = 1; | |||||
| ds = ds->Batch(batch_size); | |||||
| EXPECT_NE(ds, nullptr); | |||||
| // Create an iterator over the result of the above dataset | |||||
| // This will trigger the creation of the Execution Tree and launch it. | |||||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||||
| EXPECT_NE(iter, nullptr); | |||||
| // Iterate the dataset and get each row | |||||
| std::unordered_map<std::string, std::shared_ptr<Tensor>> row; | |||||
| iter->GetNextRow(&row); | |||||
| uint64_t i = 0; | |||||
| while (row.size() != 0) { | |||||
| i++; | |||||
| auto image = row["image"]; | |||||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | |||||
| iter->GetNextRow(&row); | |||||
| } | |||||
| EXPECT_EQ(i, 20); | |||||
| // Manually terminate the pipeline | |||||
| iter->Stop(); | |||||
| } | |||||
| TEST_F(MindDataTestPipeline, TestRandomColorAdjust) { | TEST_F(MindDataTestPipeline, TestRandomColorAdjust) { | ||||
| // Create an ImageFolder Dataset | // Create an ImageFolder Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testPK/data/"; | std::string folder_path = datasets_root_path_ + "/testPK/data/"; | ||||
| @@ -780,7 +836,8 @@ TEST_F(MindDataTestPipeline, TestRandomSolarize) { | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create objects for the tensor ops | // Create objects for the tensor ops | ||||
| std::shared_ptr<TensorOperation> random_solarize = mindspore::dataset::api::vision::RandomSolarize(23, 23); //vision::RandomSolarize(); | |||||
| std::shared_ptr<TensorOperation> random_solarize = | |||||
| mindspore::dataset::api::vision::RandomSolarize(23, 23); // vision::RandomSolarize(); | |||||
| EXPECT_NE(random_solarize, nullptr); | EXPECT_NE(random_solarize, nullptr); | ||||
| // Create a Map operation on ds | // Create a Map operation on ds | ||||
| @@ -0,0 +1,99 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/common.h" | |||||
| #include "common/cvop_common.h" | |||||
| #include "minddata/dataset/kernels/image/random_color_op.h" | |||||
| #include "minddata/dataset/core/cv_tensor.h" | |||||
| #include "utils/log_adapter.h" | |||||
| using namespace mindspore::dataset; | |||||
| using mindspore::LogStream; | |||||
| using mindspore::ExceptionType::NoExceptionType; | |||||
| using mindspore::MsLogLevel::INFO; | |||||
| class MindDataTestRandomColorOp : public UT::CVOP::CVOpCommon { | |||||
| public: | |||||
| MindDataTestRandomColorOp() : CVOpCommon(), shape({3, 3, 3}) { | |||||
| std::shared_ptr<Tensor> in; | |||||
| std::shared_ptr<Tensor> gray; | |||||
| (void)Tensor::CreateEmpty(shape, DataType(DataType::DE_UINT8), &in); | |||||
| (void)Tensor::CreateEmpty(shape, DataType(DataType::DE_UINT8), &input_tensor); | |||||
| Status s = in->Fill<uint8_t>(42); | |||||
| s = input_tensor->Fill<uint8_t>(42); | |||||
| cvt_in = CVTensor::AsCVTensor(in); | |||||
| cv::Mat m2; | |||||
| auto m1 = cvt_in->mat(); | |||||
| cv::cvtColor(m1, m2, cv::COLOR_RGB2GRAY); | |||||
| cv::Mat temp[3] = {m2 , m2 , m2 }; | |||||
| cv::Mat cv_out; | |||||
| cv::merge(temp, 3, cv_out); | |||||
| std::shared_ptr<CVTensor> cvt_out; | |||||
| CVTensor::CreateFromMat(cv_out, &cvt_out); | |||||
| gray_tensor = std::static_pointer_cast<Tensor>(cvt_out); | |||||
| } | |||||
| TensorShape shape; | |||||
| std::shared_ptr<Tensor> input_tensor; | |||||
| std::shared_ptr<CVTensor> cvt_in; | |||||
| std::shared_ptr<Tensor> gray_tensor; | |||||
| }; | |||||
| int64_t Compare(std::shared_ptr<Tensor> t1, std::shared_ptr<Tensor> t2) { | |||||
| auto shape = t1->shape(); | |||||
| int64_t sum = 0; | |||||
| for (auto i = 0; i < shape[0]; i++) { | |||||
| for (auto j = 0; j < shape[1]; j++) { | |||||
| for (auto k = 0; k < shape[2]; k++) { | |||||
| uint8_t value1; | |||||
| uint8_t value2; | |||||
| (void)t1->GetItemAt<uint8_t>(&value1, {i, j, k}); | |||||
| (void)t2->GetItemAt<uint8_t>(&value2, {i, j, k}); | |||||
| sum += abs(static_cast<int>(value1) - static_cast<int>(value2)); | |||||
| } | |||||
| } | |||||
| } | |||||
| return sum; | |||||
| } | |||||
| // these tests are tautological, write better tests when the requirements for the output are determined | |||||
| // e. g. how do we want to convert to gray and what does it mean to blend with a gray image (pre- post- gamma corrected, | |||||
| // what weights). | |||||
| TEST_F(MindDataTestRandomColorOp, TestOp1) { | |||||
| std::shared_ptr<Tensor> output_tensor; | |||||
| auto op = RandomColorOp(1, 1); | |||||
| auto s = op.Compute(input_tensor, &output_tensor); | |||||
| auto res = Compare(input_tensor, output_tensor); | |||||
| EXPECT_EQ(0, res); | |||||
| } | |||||
| TEST_F(MindDataTestRandomColorOp, TestOp2) { | |||||
| std::shared_ptr<Tensor> output_tensor; | |||||
| auto op = RandomColorOp(0, 0); | |||||
| auto s = op.Compute(input_tensor, &output_tensor); | |||||
| EXPECT_TRUE(s.IsOk()); | |||||
| auto res = Compare(output_tensor, gray_tensor); | |||||
| EXPECT_EQ(res, 0); | |||||
| } | |||||
| TEST_F(MindDataTestRandomColorOp, TestOp3) { | |||||
| std::shared_ptr<Tensor> output_tensor; | |||||
| auto op = RandomColorOp(0.0, 1.0); | |||||
| for (auto i = 0; i < 1; i++) { | |||||
| auto s = op.Compute(input_tensor, &output_tensor); | |||||
| EXPECT_TRUE(s.IsOk()); | |||||
| } | |||||
| } | |||||
| @@ -16,9 +16,11 @@ | |||||
| Testing RandomColor op in DE | Testing RandomColor op in DE | ||||
| """ | """ | ||||
| import numpy as np | import numpy as np | ||||
| import pytest | |||||
| import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
| import mindspore.dataset.engine as de | import mindspore.dataset.engine as de | ||||
| import mindspore.dataset.transforms.vision.c_transforms as vision | |||||
| import mindspore.dataset.transforms.vision.py_transforms as F | import mindspore.dataset.transforms.vision.py_transforms as F | ||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| from util import visualize_list, diff_mse, save_and_check_md5, \ | from util import visualize_list, diff_mse, save_and_check_md5, \ | ||||
| @@ -26,11 +28,17 @@ from util import visualize_list, diff_mse, save_and_check_md5, \ | |||||
| DATA_DIR = "../data/dataset/testImageNetData/train/" | DATA_DIR = "../data/dataset/testImageNetData/train/" | ||||
| C_DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | |||||
| C_SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | |||||
| MNIST_DATA_DIR = "../data/dataset/testMnistData" | |||||
| GENERATE_GOLDEN = False | GENERATE_GOLDEN = False | ||||
| def test_random_color(degrees=(0.1, 1.9), plot=False): | |||||
| def test_random_color_py(degrees=(0.1, 1.9), plot=False): | |||||
| """ | """ | ||||
| Test RandomColor | |||||
| Test Python RandomColor | |||||
| """ | """ | ||||
| logger.info("Test RandomColor") | logger.info("Test RandomColor") | ||||
| @@ -85,9 +93,53 @@ def test_random_color(degrees=(0.1, 1.9), plot=False): | |||||
| visualize_list(images_original, images_random_color) | visualize_list(images_original, images_random_color) | ||||
| def test_random_color_md5(): | |||||
| def test_random_color_c(degrees=(0.1, 1.9), plot=False, run_golden=True): | |||||
| """ | """ | ||||
| Test RandomColor with md5 check | |||||
| Test Cpp RandomColor | |||||
| """ | |||||
| logger.info("test_random_color_op") | |||||
| original_seed = config_get_set_seed(10) | |||||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||||
| # Decode with rgb format set to True | |||||
| data1 = ds.TFRecordDataset(C_DATA_DIR, C_SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||||
| data2 = ds.TFRecordDataset(C_DATA_DIR, C_SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||||
| # Serialize and Load dataset requires using vision.Decode instead of vision.Decode(). | |||||
| if degrees is None: | |||||
| c_op = vision.RandomColor() | |||||
| else: | |||||
| c_op = vision.RandomColor(degrees) | |||||
| data1 = data1.map(input_columns=["image"], operations=[vision.Decode()]) | |||||
| data2 = data2.map(input_columns=["image"], operations=[vision.Decode(), c_op]) | |||||
| image_random_color_op = [] | |||||
| image = [] | |||||
| for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): | |||||
| actual = item1["image"] | |||||
| expected = item2["image"] | |||||
| image.append(actual) | |||||
| image_random_color_op.append(expected) | |||||
| if run_golden: | |||||
| # Compare with expected md5 from images | |||||
| filename = "random_color_op_02_result.npz" | |||||
| save_and_check_md5(data2, filename, generate_golden=GENERATE_GOLDEN) | |||||
| if plot: | |||||
| visualize_list(image, image_random_color_op) | |||||
| # Restore configuration | |||||
| ds.config.set_seed(original_seed) | |||||
| ds.config.set_num_parallel_workers((original_num_parallel_workers)) | |||||
| def test_random_color_py_md5(): | |||||
| """ | |||||
| Test Python RandomColor with md5 check | |||||
| """ | """ | ||||
| logger.info("Test RandomColor with md5 check") | logger.info("Test RandomColor with md5 check") | ||||
| original_seed = config_get_set_seed(10) | original_seed = config_get_set_seed(10) | ||||
| @@ -110,8 +162,94 @@ def test_random_color_md5(): | |||||
| ds.config.set_num_parallel_workers((original_num_parallel_workers)) | ds.config.set_num_parallel_workers((original_num_parallel_workers)) | ||||
| def test_compare_random_color_op(degrees=None, plot=False): | |||||
| """ | |||||
| Compare Random Color op in Python and Cpp | |||||
| """ | |||||
| logger.info("test_random_color_op") | |||||
| original_seed = config_get_set_seed(5) | |||||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||||
| # Decode with rgb format set to True | |||||
| data1 = ds.TFRecordDataset(C_DATA_DIR, C_SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||||
| data2 = ds.TFRecordDataset(C_DATA_DIR, C_SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||||
| if degrees is None: | |||||
| c_op = vision.RandomColor() | |||||
| p_op = F.RandomColor() | |||||
| else: | |||||
| c_op = vision.RandomColor(degrees) | |||||
| p_op = F.RandomColor(degrees) | |||||
| transforms_random_color_py = F.ComposeOp([lambda img: img.astype(np.uint8), F.ToPIL(), | |||||
| p_op, np.array]) | |||||
| data1 = data1.map(input_columns=["image"], operations=[vision.Decode(), c_op]) | |||||
| data2 = data2.map(input_columns=["image"], operations=[vision.Decode()]) | |||||
| data2 = data2.map(input_columns=["image"], operations=transforms_random_color_py()) | |||||
| image_random_color_op = [] | |||||
| image = [] | |||||
| for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): | |||||
| actual = item1["image"] | |||||
| expected = item2["image"] | |||||
| image_random_color_op.append(actual) | |||||
| image.append(expected) | |||||
| assert actual.shape == expected.shape | |||||
| mse = diff_mse(actual, expected) | |||||
| logger.info("MSE= {}".format(str(np.mean(mse)))) | |||||
| # Restore configuration | |||||
| ds.config.set_seed(original_seed) | |||||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||||
| if plot: | |||||
| visualize_list(image, image_random_color_op) | |||||
| def test_random_color_c_errors(): | |||||
| """ | |||||
| Test that Cpp RandomColor errors with bad input | |||||
| """ | |||||
| with pytest.raises(TypeError) as error_info: | |||||
| vision.RandomColor((12)) | |||||
| assert "degrees must be either a tuple or a list." in str(error_info.value) | |||||
| with pytest.raises(TypeError) as error_info: | |||||
| vision.RandomColor(("col", 3)) | |||||
| assert "Argument degrees[0] with value col is not of type (<class 'int'>, <class 'float'>)." in str( | |||||
| error_info.value) | |||||
| with pytest.raises(ValueError) as error_info: | |||||
| vision.RandomColor((0.9, 0.1)) | |||||
| assert "degrees should be in (min,max) format. Got (max,min)." in str(error_info.value) | |||||
| with pytest.raises(ValueError) as error_info: | |||||
| vision.RandomColor((0.9,)) | |||||
| assert "degrees must be a sequence with length 2." in str(error_info.value) | |||||
| # RandomColor Cpp Op will fail with one channel input | |||||
| mnist_ds = de.MnistDataset(dataset_dir=MNIST_DATA_DIR, num_samples=2, shuffle=False) | |||||
| mnist_ds = mnist_ds.map(input_columns="image", operations=vision.RandomColor()) | |||||
| with pytest.raises(RuntimeError) as error_info: | |||||
| for _ in enumerate(mnist_ds): | |||||
| pass | |||||
| assert "Invalid number of channels in input image" in str(error_info.value) | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| test_random_color() | |||||
| test_random_color(plot=True) | |||||
| test_random_color(degrees=(0.5, 1.5), plot=True) | |||||
| test_random_color_md5() | |||||
| test_random_color_py() | |||||
| test_random_color_py(plot=True) | |||||
| test_random_color_py(degrees=(0.5, 1.5), plot=True) | |||||
| test_random_color_py_md5() | |||||
| test_random_color_c() | |||||
| test_random_color_c(plot=True) | |||||
| test_random_color_c(degrees=(0.5, 1.5), plot=True, run_golden=False) | |||||
| test_random_color_c(degrees=(0.1, 0.1), plot=True, run_golden=False) | |||||
| test_compare_random_color_op(plot=True) | |||||
| test_random_color_c_errors() | |||||