| @@ -43,6 +43,7 @@ | |||||
| #include "minddata/dataset/kernels/image/random_resize_op.h" | #include "minddata/dataset/kernels/image/random_resize_op.h" | ||||
| #include "minddata/dataset/kernels/image/random_resize_with_bbox_op.h" | #include "minddata/dataset/kernels/image/random_resize_with_bbox_op.h" | ||||
| #include "minddata/dataset/kernels/image/random_rotation_op.h" | #include "minddata/dataset/kernels/image/random_rotation_op.h" | ||||
| #include "minddata/dataset/kernels/image/random_sharpness_op.h" | |||||
| #include "minddata/dataset/kernels/image/random_select_subpolicy_op.h" | #include "minddata/dataset/kernels/image/random_select_subpolicy_op.h" | ||||
| #include "minddata/dataset/kernels/image/random_solarize_op.h" | #include "minddata/dataset/kernels/image/random_solarize_op.h" | ||||
| #include "minddata/dataset/kernels/image/random_vertical_flip_op.h" | #include "minddata/dataset/kernels/image/random_vertical_flip_op.h" | ||||
| @@ -333,6 +334,15 @@ PYBIND_REGISTER(RandomRotationOp, 1, ([](const py::module *m) { | |||||
| py::arg("fillG") = RandomRotationOp::kDefFillG, py::arg("fillB") = RandomRotationOp::kDefFillB); | py::arg("fillG") = RandomRotationOp::kDefFillG, py::arg("fillB") = RandomRotationOp::kDefFillB); | ||||
| })); | })); | ||||
| PYBIND_REGISTER(RandomSharpnessOp, 1, ([](const py::module *m) { | |||||
| (void)py::class_<RandomSharpnessOp, TensorOp, std::shared_ptr<RandomSharpnessOp>>( | |||||
| *m, "RandomSharpnessOp", | |||||
| "Tensor operation to apply RandomSharpness." | |||||
| "Takes a range for degrees") | |||||
| .def(py::init<float, float>(), py::arg("startDegree") = RandomSharpnessOp::kDefStartDegree, | |||||
| py::arg("endDegree") = RandomSharpnessOp::kDefEndDegree); | |||||
| })); | |||||
| PYBIND_REGISTER(RandomSelectSubpolicyOp, 1, ([](const py::module *m) { | PYBIND_REGISTER(RandomSelectSubpolicyOp, 1, ([](const py::module *m) { | ||||
| (void)py::class_<RandomSelectSubpolicyOp, TensorOp, std::shared_ptr<RandomSelectSubpolicyOp>>( | (void)py::class_<RandomSelectSubpolicyOp, TensorOp, std::shared_ptr<RandomSelectSubpolicyOp>>( | ||||
| *m, "RandomSelectSubpolicyOp") | *m, "RandomSelectSubpolicyOp") | ||||
| @@ -31,6 +31,7 @@ | |||||
| #include "minddata/dataset/kernels/image/random_crop_op.h" | #include "minddata/dataset/kernels/image/random_crop_op.h" | ||||
| #include "minddata/dataset/kernels/image/random_horizontal_flip_op.h" | #include "minddata/dataset/kernels/image/random_horizontal_flip_op.h" | ||||
| #include "minddata/dataset/kernels/image/random_rotation_op.h" | #include "minddata/dataset/kernels/image/random_rotation_op.h" | ||||
| #include "minddata/dataset/kernels/image/random_sharpness_op.h" | |||||
| #include "minddata/dataset/kernels/image/random_solarize_op.h" | #include "minddata/dataset/kernels/image/random_solarize_op.h" | ||||
| #include "minddata/dataset/kernels/image/random_vertical_flip_op.h" | #include "minddata/dataset/kernels/image/random_vertical_flip_op.h" | ||||
| #include "minddata/dataset/kernels/image/resize_op.h" | #include "minddata/dataset/kernels/image/resize_op.h" | ||||
| @@ -209,6 +210,16 @@ std::shared_ptr<RandomSolarizeOperation> RandomSolarize(uint8_t threshold_min, u | |||||
| return op; | return op; | ||||
| } | } | ||||
| // Function to create RandomSharpnessOperation. | |||||
| std::shared_ptr<RandomSharpnessOperation> RandomSharpness(std::vector<float> degrees) { | |||||
| auto op = std::make_shared<RandomSharpnessOperation>(degrees); | |||||
| // Input validation | |||||
| if (!op->ValidateParams()) { | |||||
| return nullptr; | |||||
| } | |||||
| return op; | |||||
| } | |||||
| // Function to create RandomVerticalFlipOperation. | // Function to create RandomVerticalFlipOperation. | ||||
| std::shared_ptr<RandomVerticalFlipOperation> RandomVerticalFlip(float prob) { | std::shared_ptr<RandomVerticalFlipOperation> RandomVerticalFlip(float prob) { | ||||
| auto op = std::make_shared<RandomVerticalFlipOperation>(prob); | auto op = std::make_shared<RandomVerticalFlipOperation>(prob); | ||||
| @@ -665,6 +676,22 @@ std::shared_ptr<TensorOp> RandomRotationOperation::Build() { | |||||
| return tensor_op; | return tensor_op; | ||||
| } | } | ||||
| // Function to create RandomSharpness. | |||||
| RandomSharpnessOperation::RandomSharpnessOperation(std::vector<float> degrees) : degrees_(degrees) {} | |||||
| bool RandomSharpnessOperation::ValidateParams() { | |||||
| if (degrees_.empty() || degrees_.size() != 2) { | |||||
| MS_LOG(ERROR) << "RandomSharpness: degrees vector has incorrect size: degrees.size()"; | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| std::shared_ptr<TensorOp> RandomSharpnessOperation::Build() { | |||||
| std::shared_ptr<RandomSharpnessOp> tensor_op = std::make_shared<RandomSharpnessOp>(degrees_[0], degrees_[1]); | |||||
| return tensor_op; | |||||
| } | |||||
| // RandomSolarizeOperation. | // RandomSolarizeOperation. | ||||
| RandomSolarizeOperation::RandomSolarizeOperation(uint8_t threshold_min, uint8_t threshold_max) | RandomSolarizeOperation::RandomSolarizeOperation(uint8_t threshold_min, uint8_t threshold_max) | ||||
| : threshold_min_(threshold_min), threshold_max_(threshold_max) {} | : threshold_min_(threshold_min), threshold_max_(threshold_max) {} | ||||
| @@ -61,6 +61,7 @@ class RandomColorAdjustOperation; | |||||
| class RandomCropOperation; | class RandomCropOperation; | ||||
| class RandomHorizontalFlipOperation; | class RandomHorizontalFlipOperation; | ||||
| class RandomRotationOperation; | class RandomRotationOperation; | ||||
| class RandomSharpnessOperation; | |||||
| class RandomSolarizeOperation; | class RandomSolarizeOperation; | ||||
| class RandomVerticalFlipOperation; | class RandomVerticalFlipOperation; | ||||
| class ResizeOperation; | class ResizeOperation; | ||||
| @@ -209,6 +210,13 @@ std::shared_ptr<RandomRotationOperation> RandomRotation( | |||||
| std::vector<float> degrees, InterpolationMode resample = InterpolationMode::kNearestNeighbour, bool expand = false, | std::vector<float> degrees, InterpolationMode resample = InterpolationMode::kNearestNeighbour, bool expand = false, | ||||
| std::vector<float> center = {-1, -1}, std::vector<uint8_t> fill_value = {0, 0, 0}); | std::vector<float> center = {-1, -1}, std::vector<uint8_t> fill_value = {0, 0, 0}); | ||||
| /// \brief Function to create a RandomSharpness TensorOperation. | |||||
| /// \notes Tensor operation to perform random sharpness. | |||||
| /// \param[in] start_degree - float representing the start of the range to uniformly sample the factor from it. | |||||
| /// \param[in] end_degree - float representing the end of the range. | |||||
| /// \return Shared pointer to the current TensorOperation. | |||||
| std::shared_ptr<RandomSharpnessOperation> RandomSharpness(std::vector<float> degrees = {0.1, 1.9}); | |||||
| /// \brief Function to create a RandomSolarize TensorOperation. | /// \brief Function to create a RandomSolarize TensorOperation. | ||||
| /// \notes Invert pixels within specified range. If min=max, then it inverts all pixel above that threshold | /// \notes Invert pixels within specified range. If min=max, then it inverts all pixel above that threshold | ||||
| /// \param[in] threshold_min - lower limit | /// \param[in] threshold_min - lower limit | ||||
| @@ -468,6 +476,20 @@ class RandomRotationOperation : public TensorOperation { | |||||
| std::vector<uint8_t> fill_value_; | std::vector<uint8_t> fill_value_; | ||||
| }; | }; | ||||
| class RandomSharpnessOperation : public TensorOperation { | |||||
| public: | |||||
| explicit RandomSharpnessOperation(std::vector<float> degrees = {0.1, 1.9}); | |||||
| ~RandomSharpnessOperation() = default; | |||||
| std::shared_ptr<TensorOp> Build() override; | |||||
| bool ValidateParams() override; | |||||
| private: | |||||
| std::vector<float> degrees_; | |||||
| }; | |||||
| class RandomVerticalFlipOperation : public TensorOperation { | class RandomVerticalFlipOperation : public TensorOperation { | ||||
| public: | public: | ||||
| explicit RandomVerticalFlipOperation(float probability = 0.5); | explicit RandomVerticalFlipOperation(float probability = 0.5); | ||||
| @@ -32,9 +32,11 @@ add_library(kernels-image OBJECT | |||||
| random_solarize_op.cc | random_solarize_op.cc | ||||
| random_vertical_flip_op.cc | random_vertical_flip_op.cc | ||||
| random_vertical_flip_with_bbox_op.cc | random_vertical_flip_with_bbox_op.cc | ||||
| random_sharpness_op.cc | |||||
| rescale_op.cc | rescale_op.cc | ||||
| resize_bilinear_op.cc | resize_bilinear_op.cc | ||||
| resize_op.cc | resize_op.cc | ||||
| sharpness_op.cc | |||||
| solarize_op.cc | solarize_op.cc | ||||
| swap_red_blue_op.cc | swap_red_blue_op.cc | ||||
| uniform_aug_op.cc | uniform_aug_op.cc | ||||
| @@ -0,0 +1,51 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/kernels/image/random_sharpness_op.h" | |||||
| #include <random> | |||||
| #include "minddata/dataset/kernels/image/sharpness_op.h" | |||||
| #include "minddata/dataset/core/cv_tensor.h" | |||||
| #include "minddata/dataset/util/random.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| const float RandomSharpnessOp::kDefStartDegree = 0.1; | |||||
| const float RandomSharpnessOp::kDefEndDegree = 1.9; | |||||
| /// constructor | |||||
| RandomSharpnessOp::RandomSharpnessOp(float start_degree, float end_degree) | |||||
| : start_degree_(start_degree), end_degree_(end_degree) { | |||||
| rnd_.seed(GetSeed()); | |||||
| } | |||||
| /// main function call for random sharpness : Generate the random degrees | |||||
| Status RandomSharpnessOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||||
| IO_CHECK(input, output); | |||||
| float random_double = distribution_(rnd_); | |||||
| /// get the degree sharpness range | |||||
| /// the way this op works (uniform distribution) | |||||
| /// assumption here is that mDegreesEnd > mDegreeStart so we always get positive number | |||||
| float degree_range = (end_degree_ - start_degree_) / 2; | |||||
| float mid = (end_degree_ + start_degree_) / 2; | |||||
| alpha_ = mid + random_double * degree_range; | |||||
| SharpnessOp::Compute(input, output); | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,56 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_SHARPNESS_OP_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_SHARPNESS_OP_H_ | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include "minddata/dataset/kernels/image/sharpness_op.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| class RandomSharpnessOp : public SharpnessOp { | |||||
| public: | |||||
| static const float kDefStartDegree; | |||||
| static const float kDefEndDegree; | |||||
| /// Adjust the sharpness of the input image by a random degree within the given range. | |||||
| /// \@param[in] start_degree A float indicating the beginning of the range. | |||||
| /// \@param[in] end_degree A float indicating the end of the range. | |||||
| explicit RandomSharpnessOp(float start_degree = kDefStartDegree, const float end_degree = kDefEndDegree); | |||||
| ~RandomSharpnessOp() override = default; | |||||
| void Print(std::ostream &out) const override { out << Name(); } | |||||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | |||||
| std::string Name() const override { return kRandomSharpnessOp; } | |||||
| protected: | |||||
| float start_degree_; | |||||
| float end_degree_; | |||||
| std::uniform_real_distribution<float> distribution_{-1.0, 1.0}; | |||||
| std::mt19937 rnd_; | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_SHARPNESS_OP_H_ | |||||
| @@ -0,0 +1,84 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/kernels/image/sharpness_op.h" | |||||
| #include "minddata/dataset/kernels/image/image_utils.h" | |||||
| #include "minddata/dataset/core/cv_tensor.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| const float SharpnessOp::kDefAlpha = 1.0; | |||||
| Status SharpnessOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||||
| IO_CHECK(input, output); | |||||
| try { | |||||
| std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input); | |||||
| cv::Mat input_img = input_cv->mat(); | |||||
| if (!input_cv->mat().data) { | |||||
| RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); | |||||
| } | |||||
| if (input_cv->Rank() != 3 && input_cv->Rank() != 2) { | |||||
| RETURN_STATUS_UNEXPECTED("Shape not <H,W,C> or <H,W>"); | |||||
| } | |||||
| /// Get number of channels and image matrix | |||||
| std::size_t num_of_channels = input_cv->shape()[2]; | |||||
| if (num_of_channels != 1 && num_of_channels != 3) { | |||||
| RETURN_STATUS_UNEXPECTED("Number of channels is not 1 or 3."); | |||||
| } | |||||
| /// creating a smoothing filter. 1, 1, 1, | |||||
| /// 1, 5, 1, | |||||
| /// 1, 1, 1 | |||||
| float filterSum = 13.0; | |||||
| cv::Mat filter = cv::Mat(3, 3, CV_32F, cv::Scalar::all(1.0 / filterSum)); | |||||
| filter.at<float>(1, 1) = 5.0 / filterSum; | |||||
| /// applying filter on channels | |||||
| cv::Mat result = cv::Mat(); | |||||
| cv::filter2D(input_img, result, -1, filter); | |||||
| int height = input_cv->shape()[0]; | |||||
| int width = input_cv->shape()[1]; | |||||
| /// restoring the edges | |||||
| input_img.row(0).copyTo(result.row(0)); | |||||
| input_img.row(height - 1).copyTo(result.row(height - 1)); | |||||
| input_img.col(0).copyTo(result.col(0)); | |||||
| input_img.col(width - 1).copyTo(result.col(width - 1)); | |||||
| /// blend based on alpha : (alpha_ *input_img) + ((1.0-alpha_) * result); | |||||
| cv::addWeighted(input_img, alpha_, result, 1.0 - alpha_, 0.0, result); | |||||
| std::shared_ptr<CVTensor> output_cv; | |||||
| RETURN_IF_NOT_OK(CVTensor::CreateFromMat(result, &output_cv)); | |||||
| RETURN_UNEXPECTED_IF_NULL(output_cv); | |||||
| *output = std::static_pointer_cast<Tensor>(output_cv); | |||||
| } | |||||
| catch (const cv::Exception &e) { | |||||
| RETURN_STATUS_UNEXPECTED("OpenCV error in random sharpness"); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,53 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_SHARPNESS_OP_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_SHARPNESS_OP_H_ | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include "minddata/dataset/core/tensor.h" | |||||
| #include "minddata/dataset/kernels/tensor_op.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| class SharpnessOp : public TensorOp { | |||||
| public: | |||||
| /// Default values, also used by bindings.cc | |||||
| static const float kDefAlpha; | |||||
| /// This class can be used to adjust the sharpness of an image. | |||||
| /// \@param[in] alpha A float indicating the enhancement factor. | |||||
| /// a factor of 0.0 gives a blurred image, a factor of 1.0 gives the | |||||
| /// original image, and a factor of 2.0 gives a sharpened image. | |||||
| explicit SharpnessOp(const float alpha = kDefAlpha) : alpha_(alpha) {} | |||||
| ~SharpnessOp() override = default; | |||||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | |||||
| std::string Name() const override { return kSharpnessOp; } | |||||
| protected: | |||||
| float alpha_; | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_SHARPNESS_OP_H_ | |||||
| @@ -114,6 +114,7 @@ constexpr char kRandomResizeOp[] = "RandomResizeOp"; | |||||
| constexpr char kRandomResizeWithBBoxOp[] = "RandomResizeWithBBoxOp"; | constexpr char kRandomResizeWithBBoxOp[] = "RandomResizeWithBBoxOp"; | ||||
| constexpr char kRandomRotationOp[] = "RandomRotationOp"; | constexpr char kRandomRotationOp[] = "RandomRotationOp"; | ||||
| constexpr char kRandomSolarizeOp[] = "RandomSolarizeOp"; | constexpr char kRandomSolarizeOp[] = "RandomSolarizeOp"; | ||||
| constexpr char kRandomSharpnessOp[] = "RandomSharpnessOp"; | |||||
| constexpr char kRandomVerticalFlipOp[] = "RandomVerticalFlipOp"; | constexpr char kRandomVerticalFlipOp[] = "RandomVerticalFlipOp"; | ||||
| constexpr char kRandomVerticalFlipWithBBoxOp[] = "RandomVerticalFlipWithBBoxOp"; | constexpr char kRandomVerticalFlipWithBBoxOp[] = "RandomVerticalFlipWithBBoxOp"; | ||||
| constexpr char kRescaleOp[] = "RescaleOp"; | constexpr char kRescaleOp[] = "RescaleOp"; | ||||
| @@ -121,6 +122,7 @@ constexpr char kResizeBilinearOp[] = "ResizeBilinearOp"; | |||||
| constexpr char kResizeOp[] = "ResizeOp"; | constexpr char kResizeOp[] = "ResizeOp"; | ||||
| constexpr char kResizeWithBBoxOp[] = "ResizeWithBBoxOp"; | constexpr char kResizeWithBBoxOp[] = "ResizeWithBBoxOp"; | ||||
| constexpr char kSolarizeOp[] = "SolarizeOp"; | constexpr char kSolarizeOp[] = "SolarizeOp"; | ||||
| constexpr char kSharpnessOp[] = "SharpnessOp"; | |||||
| constexpr char kSwapRedBlueOp[] = "SwapRedBlueOp"; | constexpr char kSwapRedBlueOp[] = "SwapRedBlueOp"; | ||||
| constexpr char kUniformAugOp[] = "UniformAugOp"; | constexpr char kUniformAugOp[] = "UniformAugOp"; | ||||
| constexpr char kSoftDvppDecodeRandomCropResizeJpegOp[] = "SoftDvppDecodeRandomCropResizeJpegOp"; | constexpr char kSoftDvppDecodeRandomCropResizeJpegOp[] = "SoftDvppDecodeRandomCropResizeJpegOp"; | ||||
| @@ -48,7 +48,7 @@ from .validators import check_prob, check_crop, check_resize_interpolation, chec | |||||
| check_mix_up_batch_c, check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \ | check_mix_up_batch_c, check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \ | ||||
| check_range, check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, \ | check_range, check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, \ | ||||
| check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \ | check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \ | ||||
| check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, FLOAT_MAX_INTEGER | |||||
| check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER | |||||
| DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR, | DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR, | ||||
| Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR, | Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR, | ||||
| @@ -90,6 +90,31 @@ class AutoContrast(cde.AutoContrastOp): | |||||
| super().__init__(cutoff, ignore) | super().__init__(cutoff, ignore) | ||||
| class RandomSharpness(cde.RandomSharpnessOp): | |||||
| """ | |||||
| Adjust the sharpness of the input image by a fixed or random degree. degree of 0.0 gives a blurred image, | |||||
| a degree of 1.0 gives the original image, and a degree of 2.0 gives a sharpened image. | |||||
| Args: | |||||
| degrees (sequence): Range of random sharpness adjustment degrees. | |||||
| it should be in (min, max) format. If min=max, then it is a | |||||
| single fixed magnitude operation (default = (0.1, 1.9)). | |||||
| Raises: | |||||
| TypeError : If degrees is not a list or tuple. | |||||
| ValueError: If degrees is not positive. | |||||
| ValueError: If degrees is in (max, min) format instead of (min, max). | |||||
| Examples: | |||||
| >>>c_transform.RandomSharpness(degrees=(0.2,1.9)) | |||||
| """ | |||||
| @check_positive_degrees | |||||
| def __init__(self, degrees=(0.1, 1.9)): | |||||
| self.degrees = degrees | |||||
| super().__init__(*degrees) | |||||
| class Equalize(cde.EqualizeOp): | class Equalize(cde.EqualizeOp): | ||||
| """ | """ | ||||
| Apply histogram equalization on input image. | Apply histogram equalization on input image. | ||||
| @@ -614,14 +614,16 @@ def check_positive_degrees(method): | |||||
| @wraps(method) | @wraps(method) | ||||
| def new_method(self, *args, **kwargs): | def new_method(self, *args, **kwargs): | ||||
| [degrees], _ = parse_user_args(method, *args, **kwargs) | [degrees], _ = parse_user_args(method, *args, **kwargs) | ||||
| if isinstance(degrees, (list, tuple)): | if isinstance(degrees, (list, tuple)): | ||||
| if len(degrees) != 2: | if len(degrees) != 2: | ||||
| raise ValueError("Degrees must be a sequence with length 2.") | raise ValueError("Degrees must be a sequence with length 2.") | ||||
| for value in degrees: | |||||
| check_value(value, (0., FLOAT_MAX_INTEGER)) | |||||
| check_positive(degrees[0], "degrees[0]") | check_positive(degrees[0], "degrees[0]") | ||||
| if degrees[0] > degrees[1]: | if degrees[0] > degrees[1]: | ||||
| raise ValueError("Degrees should be in (min,max) format. Got (max,min).") | raise ValueError("Degrees should be in (min,max) format. Got (max,min).") | ||||
| else: | |||||
| raise TypeError("Degrees should be a tuple or list.") | |||||
| return method(self, *args, **kwargs) | return method(self, *args, **kwargs) | ||||
| return new_method | return new_method | ||||
| @@ -34,12 +34,12 @@ | |||||
| #include "minddata/dataset/include/samplers.h" | #include "minddata/dataset/include/samplers.h" | ||||
| using namespace mindspore::dataset::api; | using namespace mindspore::dataset::api; | ||||
| using mindspore::MsLogLevel::ERROR; | |||||
| using mindspore::ExceptionType::NoExceptionType; | |||||
| using mindspore::LogStream; | using mindspore::LogStream; | ||||
| using mindspore::dataset::Tensor; | |||||
| using mindspore::dataset::Status; | |||||
| using mindspore::dataset::BorderType; | using mindspore::dataset::BorderType; | ||||
| using mindspore::dataset::Status; | |||||
| using mindspore::dataset::Tensor; | |||||
| using mindspore::ExceptionType::NoExceptionType; | |||||
| using mindspore::MsLogLevel::ERROR; | |||||
| class MindDataTestPipeline : public UT::DatasetOpTesting { | class MindDataTestPipeline : public UT::DatasetOpTesting { | ||||
| protected: | protected: | ||||
| @@ -308,10 +308,10 @@ TEST_F(MindDataTestPipeline, TestPad) { | |||||
| uint64_t i = 0; | uint64_t i = 0; | ||||
| while (row.size() != 0) { | while (row.size() != 0) { | ||||
| i++; | |||||
| auto image = row["image"]; | |||||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | |||||
| iter->GetNextRow(&row); | |||||
| i++; | |||||
| auto image = row["image"]; | |||||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | |||||
| iter->GetNextRow(&row); | |||||
| } | } | ||||
| EXPECT_EQ(i, 20); | EXPECT_EQ(i, 20); | ||||
| @@ -358,10 +358,10 @@ TEST_F(MindDataTestPipeline, TestCutOut) { | |||||
| uint64_t i = 0; | uint64_t i = 0; | ||||
| while (row.size() != 0) { | while (row.size() != 0) { | ||||
| i++; | |||||
| auto image = row["image"]; | |||||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | |||||
| iter->GetNextRow(&row); | |||||
| i++; | |||||
| auto image = row["image"]; | |||||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | |||||
| iter->GetNextRow(&row); | |||||
| } | } | ||||
| EXPECT_EQ(i, 20); | EXPECT_EQ(i, 20); | ||||
| @@ -527,12 +527,12 @@ TEST_F(MindDataTestPipeline, TestRandomColorAdjust) { | |||||
| std::shared_ptr<TensorOperation> random_color_adjust1 = vision::RandomColorAdjust({1.0}, {0.0}, {0.5}, {0.5}); | std::shared_ptr<TensorOperation> random_color_adjust1 = vision::RandomColorAdjust({1.0}, {0.0}, {0.5}, {0.5}); | ||||
| EXPECT_NE(random_color_adjust1, nullptr); | EXPECT_NE(random_color_adjust1, nullptr); | ||||
| std::shared_ptr<TensorOperation> random_color_adjust2 = vision::RandomColorAdjust({1.0, 1.0}, {0.0, 0.0}, {0.5, 0.5}, | |||||
| {0.5, 0.5}); | |||||
| std::shared_ptr<TensorOperation> random_color_adjust2 = | |||||
| vision::RandomColorAdjust({1.0, 1.0}, {0.0, 0.0}, {0.5, 0.5}, {0.5, 0.5}); | |||||
| EXPECT_NE(random_color_adjust2, nullptr); | EXPECT_NE(random_color_adjust2, nullptr); | ||||
| std::shared_ptr<TensorOperation> random_color_adjust3 = vision::RandomColorAdjust({0.5, 1.0}, {0.0, 0.5}, {0.25, 0.5}, | |||||
| {0.25, 0.5}); | |||||
| std::shared_ptr<TensorOperation> random_color_adjust3 = | |||||
| vision::RandomColorAdjust({0.5, 1.0}, {0.0, 0.5}, {0.25, 0.5}, {0.25, 0.5}); | |||||
| EXPECT_NE(random_color_adjust3, nullptr); | EXPECT_NE(random_color_adjust3, nullptr); | ||||
| std::shared_ptr<TensorOperation> random_color_adjust4 = vision::RandomColorAdjust(); | std::shared_ptr<TensorOperation> random_color_adjust4 = vision::RandomColorAdjust(); | ||||
| @@ -558,10 +558,68 @@ TEST_F(MindDataTestPipeline, TestRandomColorAdjust) { | |||||
| uint64_t i = 0; | uint64_t i = 0; | ||||
| while (row.size() != 0) { | while (row.size() != 0) { | ||||
| i++; | |||||
| auto image = row["image"]; | |||||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | |||||
| i++; | |||||
| auto image = row["image"]; | |||||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | |||||
| iter->GetNextRow(&row); | |||||
| } | |||||
| EXPECT_EQ(i, 20); | |||||
| // Manually terminate the pipeline | |||||
| iter->Stop(); | |||||
| } | |||||
| TEST_F(MindDataTestPipeline, TestRandomSharpness) { | |||||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomSharpness."; | |||||
| // Create an ImageFolder Dataset | |||||
| std::string folder_path = datasets_root_path_ + "/testPK/data/"; | |||||
| std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); | |||||
| EXPECT_NE(ds, nullptr); | |||||
| // Create a Repeat operation on ds | |||||
| int32_t repeat_num = 2; | |||||
| ds = ds->Repeat(repeat_num); | |||||
| EXPECT_NE(ds, nullptr); | |||||
| // Create objects for the tensor ops | |||||
| std::shared_ptr<TensorOperation> random_sharpness_op_1 = vision::RandomSharpness({0.4, 2.3}); | |||||
| EXPECT_NE(random_sharpness_op_1, nullptr); | |||||
| std::shared_ptr<TensorOperation> random_sharpness_op_2 = vision::RandomSharpness({}); | |||||
| EXPECT_EQ(random_sharpness_op_2, nullptr); | |||||
| std::shared_ptr<TensorOperation> random_sharpness_op_3 = vision::RandomSharpness(); | |||||
| EXPECT_NE(random_sharpness_op_3, nullptr); | |||||
| std::shared_ptr<TensorOperation> random_sharpness_op_4 = vision::RandomSharpness({0.1}); | |||||
| EXPECT_EQ(random_sharpness_op_4, nullptr); | |||||
| // Create a Map operation on ds | |||||
| ds = ds->Map({random_sharpness_op_1, random_sharpness_op_3}); | |||||
| EXPECT_NE(ds, nullptr); | |||||
| // Create a Batch operation on ds | |||||
| int32_t batch_size = 1; | |||||
| ds = ds->Batch(batch_size); | |||||
| EXPECT_NE(ds, nullptr); | |||||
| // Create an iterator over the result of the above dataset | |||||
| // This will trigger the creation of the Execution Tree and launch it. | |||||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||||
| EXPECT_NE(iter, nullptr); | |||||
| // Iterate the dataset and get each row | |||||
| std::unordered_map<std::string, std::shared_ptr<Tensor>> row; | |||||
| iter->GetNextRow(&row); | iter->GetNextRow(&row); | ||||
| uint64_t i = 0; | |||||
| while (row.size() != 0) { | |||||
| i++; | |||||
| auto image = row["image"]; | |||||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | |||||
| iter->GetNextRow(&row); | |||||
| } | } | ||||
| EXPECT_EQ(i, 20); | EXPECT_EQ(i, 20); | ||||
| @@ -146,6 +146,14 @@ void CVOpCommon::CheckImageShapeAndData(const std::shared_ptr<Tensor> &output_te | |||||
| expect_image_path = dir_path + "imagefolder/apple_expect_random_solarize.jpg"; | expect_image_path = dir_path + "imagefolder/apple_expect_random_solarize.jpg"; | ||||
| actual_image_path = dir_path + "imagefolder/apple_actual_random_solarize.jpg"; | actual_image_path = dir_path + "imagefolder/apple_actual_random_solarize.jpg"; | ||||
| break; | break; | ||||
| case kInvert: | |||||
| expect_image_path = dir_path + "imagefolder/apple_expect_invert.jpg"; | |||||
| actual_image_path = dir_path + "imagefolder/apple_actual_invert.jpg"; | |||||
| break; | |||||
| case kRandomSharpness: | |||||
| expect_image_path = dir_path + "imagefolder/apple_expect_random_sharpness.jpg"; | |||||
| actual_image_path = dir_path + "imagefolder/apple_actual_random_sharpness.jpg"; | |||||
| break; | |||||
| default: | default: | ||||
| MS_LOG(INFO) << "Not pass verification! Operation type does not exists."; | MS_LOG(INFO) << "Not pass verification! Operation type does not exists."; | ||||
| EXPECT_EQ(0, 1); | EXPECT_EQ(0, 1); | ||||
| @@ -39,6 +39,8 @@ class CVOpCommon : public Common { | |||||
| kRandomSolarize, | kRandomSolarize, | ||||
| kTemplate, | kTemplate, | ||||
| kCrop, | kCrop, | ||||
| kRandomSharpness, | |||||
| kInvert, | |||||
| kRandomAffine, | kRandomAffine, | ||||
| kAutoContrast, | kAutoContrast, | ||||
| kEqualize | kEqualize | ||||
| @@ -0,0 +1,40 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/kernels/image/invert_op.h" | |||||
| #include "common/common.h" | |||||
| #include "common/cvop_common.h" | |||||
| #include "utils/log_adapter.h" | |||||
| using namespace mindspore::dataset; | |||||
| using mindspore::MsLogLevel::INFO; | |||||
| using mindspore::ExceptionType::NoExceptionType; | |||||
| using mindspore::LogStream; | |||||
| class MindDataTestInvert : public UT::CVOP::CVOpCommon { | |||||
| public: | |||||
| MindDataTestInvert() : CVOpCommon() {} | |||||
| }; | |||||
| TEST_F(MindDataTestInvert, TestOp) { | |||||
| MS_LOG(INFO) << "Doing test Invert."; | |||||
| std::shared_ptr<Tensor> output_tensor; | |||||
| std::unique_ptr<InvertOp> op(new InvertOp()); | |||||
| EXPECT_TRUE(op->OneToOne()); | |||||
| Status st = op->Compute(input_tensor_, &output_tensor); | |||||
| EXPECT_TRUE(st.IsOk()); | |||||
| CheckImageShapeAndData(output_tensor, kInvert); | |||||
| MS_LOG(INFO) << "testInvert end."; | |||||
| } | |||||
| @@ -0,0 +1,52 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/kernels/image/random_sharpness_op.h" | |||||
| #include "common/common.h" | |||||
| #include "common/cvop_common.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "minddata/dataset/core/config_manager.h" | |||||
| #include "minddata/dataset/core/global_context.h" | |||||
| using namespace mindspore::dataset; | |||||
| using mindspore::MsLogLevel::INFO; | |||||
| using mindspore::ExceptionType::NoExceptionType; | |||||
| using mindspore::LogStream; | |||||
| class MindDataTestRandomSharpness : public UT::CVOP::CVOpCommon { | |||||
| public: | |||||
| MindDataTestRandomSharpness() : CVOpCommon() {} | |||||
| }; | |||||
| TEST_F(MindDataTestRandomSharpness, TestOp) { | |||||
| MS_LOG(INFO) << "Doing test RandomSharpness."; | |||||
| // setting seed here | |||||
| u_int32_t curr_seed = GlobalContext::config_manager()->seed(); | |||||
| GlobalContext::config_manager()->set_seed(120); | |||||
| // Sharpness with a factor in range [0.2,1.8] | |||||
| float start_degree = 0.2; | |||||
| float end_degree = 1.8; | |||||
| std::shared_ptr<Tensor> output_tensor; | |||||
| // sharpening | |||||
| std::unique_ptr<RandomSharpnessOp> op(new RandomSharpnessOp(start_degree, end_degree)); | |||||
| EXPECT_TRUE(op->OneToOne()); | |||||
| Status st = op->Compute(input_tensor_, &output_tensor); | |||||
| EXPECT_TRUE(st.IsOk()); | |||||
| CheckImageShapeAndData(output_tensor, kRandomSharpness); | |||||
| // restoring the seed | |||||
| GlobalContext::config_manager()->set_seed(curr_seed); | |||||
| MS_LOG(INFO) << "testRandomSharpness end."; | |||||
| } | |||||
| @@ -19,20 +19,22 @@ import numpy as np | |||||
| import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
| import mindspore.dataset.engine as de | import mindspore.dataset.engine as de | ||||
| import mindspore.dataset.transforms.vision.py_transforms as F | import mindspore.dataset.transforms.vision.py_transforms as F | ||||
| import mindspore.dataset.transforms.vision.c_transforms as C | |||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| from util import visualize_list, diff_mse, save_and_check_md5, \ | |||||
| from util import visualize_list, visualize_one_channel_dataset, diff_mse, save_and_check_md5, \ | |||||
| config_get_set_seed, config_get_set_num_parallel_workers | config_get_set_seed, config_get_set_num_parallel_workers | ||||
| DATA_DIR = "../data/dataset/testImageNetData/train/" | DATA_DIR = "../data/dataset/testImageNetData/train/" | ||||
| MNIST_DATA_DIR = "../data/dataset/testMnistData" | |||||
| GENERATE_GOLDEN = False | GENERATE_GOLDEN = False | ||||
| def test_random_sharpness(degrees=(0.1, 1.9), plot=False): | |||||
| def test_random_sharpness_py(degrees=(0.7, 0.7), plot=False): | |||||
| """ | """ | ||||
| Test RandomSharpness | |||||
| Test RandomSharpness python op | |||||
| """ | """ | ||||
| logger.info("Test RandomSharpness") | |||||
| logger.info("Test RandomSharpness python op") | |||||
| # Original Images | # Original Images | ||||
| data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | ||||
| @@ -54,12 +56,16 @@ def test_random_sharpness(degrees=(0.1, 1.9), plot=False): | |||||
| np.transpose(image, (0, 2, 3, 1)), | np.transpose(image, (0, 2, 3, 1)), | ||||
| axis=0) | axis=0) | ||||
| # Random Sharpness Adjusted Images | |||||
| # Random Sharpness Adjusted Images | |||||
| data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | ||||
| py_op = F.RandomSharpness() | |||||
| if degrees is not None: | |||||
| py_op = F.RandomSharpness(degrees) | |||||
| transforms_random_sharpness = F.ComposeOp([F.Decode(), | transforms_random_sharpness = F.ComposeOp([F.Decode(), | ||||
| F.Resize((224, 224)), | F.Resize((224, 224)), | ||||
| F.RandomSharpness(degrees=degrees), | |||||
| py_op, | |||||
| F.ToTensor()]) | F.ToTensor()]) | ||||
| ds_random_sharpness = data.map(input_columns="image", | ds_random_sharpness = data.map(input_columns="image", | ||||
| @@ -86,11 +92,11 @@ def test_random_sharpness(degrees=(0.1, 1.9), plot=False): | |||||
| visualize_list(images_original, images_random_sharpness) | visualize_list(images_original, images_random_sharpness) | ||||
| def test_random_sharpness_md5(): | |||||
| def test_random_sharpness_py_md5(): | |||||
| """ | """ | ||||
| Test RandomSharpness with md5 comparison | |||||
| Test RandomSharpness python op with md5 comparison | |||||
| """ | """ | ||||
| logger.info("Test RandomSharpness with md5 comparison") | |||||
| logger.info("Test RandomSharpness python op with md5 comparison") | |||||
| original_seed = config_get_set_seed(5) | original_seed = config_get_set_seed(5) | ||||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | original_num_parallel_workers = config_get_set_num_parallel_workers(1) | ||||
| @@ -107,7 +113,7 @@ def test_random_sharpness_md5(): | |||||
| data = data.map(input_columns=["image"], operations=transform()) | data = data.map(input_columns=["image"], operations=transform()) | ||||
| # check results with md5 comparison | # check results with md5 comparison | ||||
| filename = "random_sharpness_01_result.npz" | |||||
| filename = "random_sharpness_py_01_result.npz" | |||||
| save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN) | save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN) | ||||
| # Restore configuration | # Restore configuration | ||||
| @@ -115,8 +121,230 @@ def test_random_sharpness_md5(): | |||||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | ds.config.set_num_parallel_workers(original_num_parallel_workers) | ||||
| def test_random_sharpness_c(degrees=(1.6, 1.6), plot=False): | |||||
| """ | |||||
| Test RandomSharpness cpp op | |||||
| """ | |||||
| print(degrees) | |||||
| logger.info("Test RandomSharpness cpp op") | |||||
| # Original Images | |||||
| data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||||
| transforms_original = [C.Decode(), | |||||
| C.Resize((224, 224))] | |||||
| ds_original = data.map(input_columns="image", | |||||
| operations=transforms_original) | |||||
| ds_original = ds_original.batch(512) | |||||
| for idx, (image, _) in enumerate(ds_original): | |||||
| if idx == 0: | |||||
| images_original = image | |||||
| else: | |||||
| images_original = np.append(images_original, | |||||
| image, | |||||
| axis=0) | |||||
| # Random Sharpness Adjusted Images | |||||
| data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||||
| c_op = C.RandomSharpness() | |||||
| if degrees is not None: | |||||
| c_op = C.RandomSharpness(degrees) | |||||
| transforms_random_sharpness = [C.Decode(), | |||||
| C.Resize((224, 224)), | |||||
| c_op] | |||||
| ds_random_sharpness = data.map(input_columns="image", | |||||
| operations=transforms_random_sharpness) | |||||
| ds_random_sharpness = ds_random_sharpness.batch(512) | |||||
| for idx, (image, _) in enumerate(ds_random_sharpness): | |||||
| if idx == 0: | |||||
| images_random_sharpness = image | |||||
| else: | |||||
| images_random_sharpness = np.append(images_random_sharpness, | |||||
| image, | |||||
| axis=0) | |||||
| num_samples = images_original.shape[0] | |||||
| mse = np.zeros(num_samples) | |||||
| for i in range(num_samples): | |||||
| mse[i] = diff_mse(images_random_sharpness[i], images_original[i]) | |||||
| logger.info("MSE= {}".format(str(np.mean(mse)))) | |||||
| if plot: | |||||
| visualize_list(images_original, images_random_sharpness) | |||||
| def test_random_sharpness_c_md5(): | |||||
| """ | |||||
| Test RandomSharpness cpp op with md5 comparison | |||||
| """ | |||||
| logger.info("Test RandomSharpness cpp op with md5 comparison") | |||||
| original_seed = config_get_set_seed(200) | |||||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||||
| # define map operations | |||||
| transforms = [ | |||||
| C.Decode(), | |||||
| C.RandomSharpness((0.1, 1.9)) | |||||
| ] | |||||
| # Generate dataset | |||||
| data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||||
| data = data.map(input_columns=["image"], operations=transforms) | |||||
| # check results with md5 comparison | |||||
| filename = "random_sharpness_cpp_01_result.npz" | |||||
| save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN) | |||||
| # Restore configuration | |||||
| ds.config.set_seed(original_seed) | |||||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||||
| def test_random_sharpness_c_py(degrees=(1.0, 1.0), plot=False): | |||||
| """ | |||||
| Test Random Sharpness C and python Op | |||||
| """ | |||||
| logger.info("Test RandomSharpness C and python Op") | |||||
| # RandomSharpness Images | |||||
| data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||||
| data = data.map(input_columns=["image"], | |||||
| operations=[C.Decode(), | |||||
| C.Resize((200, 300))]) | |||||
| python_op = F.RandomSharpness(degrees) | |||||
| c_op = C.RandomSharpness(degrees) | |||||
| transforms_op = F.ComposeOp([lambda img: F.ToPIL()(img.astype(np.uint8)), | |||||
| python_op, | |||||
| np.array])() | |||||
| ds_random_sharpness_py = data.map(input_columns="image", | |||||
| operations=transforms_op) | |||||
| ds_random_sharpness_py = ds_random_sharpness_py.batch(512) | |||||
| for idx, (image, _) in enumerate(ds_random_sharpness_py): | |||||
| if idx == 0: | |||||
| images_random_sharpness_py = image | |||||
| else: | |||||
| images_random_sharpness_py = np.append(images_random_sharpness_py, | |||||
| image, | |||||
| axis=0) | |||||
| data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||||
| data = data.map(input_columns=["image"], | |||||
| operations=[C.Decode(), | |||||
| C.Resize((200, 300))]) | |||||
| ds_images_random_sharpness_c = data.map(input_columns="image", | |||||
| operations=c_op) | |||||
| ds_images_random_sharpness_c = ds_images_random_sharpness_c.batch(512) | |||||
| for idx, (image, _) in enumerate(ds_images_random_sharpness_c): | |||||
| if idx == 0: | |||||
| images_random_sharpness_c = image | |||||
| else: | |||||
| images_random_sharpness_c = np.append(images_random_sharpness_c, | |||||
| image, | |||||
| axis=0) | |||||
| num_samples = images_random_sharpness_c.shape[0] | |||||
| mse = np.zeros(num_samples) | |||||
| for i in range(num_samples): | |||||
| mse[i] = diff_mse(images_random_sharpness_c[i], images_random_sharpness_py[i]) | |||||
| logger.info("MSE= {}".format(str(np.mean(mse)))) | |||||
| if plot: | |||||
| visualize_list(images_random_sharpness_c, images_random_sharpness_py, visualize_mode=2) | |||||
| def test_random_sharpness_one_channel_c(degrees=(1.4, 1.4), plot=False): | |||||
| """ | |||||
| Test Random Sharpness cpp op with one channel | |||||
| """ | |||||
| logger.info("Test RandomSharpness C Op With MNIST Dataset (Grayscale images)") | |||||
| c_op = C.RandomSharpness() | |||||
| if degrees is not None: | |||||
| c_op = C.RandomSharpness(degrees) | |||||
| # RandomSharpness Images | |||||
| data = de.MnistDataset(dataset_dir=MNIST_DATA_DIR, num_samples=2, shuffle=False) | |||||
| ds_random_sharpness_c = data.map(input_columns="image", operations=c_op) | |||||
| # Original images | |||||
| data = de.MnistDataset(dataset_dir=MNIST_DATA_DIR, num_samples=2, shuffle=False) | |||||
| images = [] | |||||
| images_trans = [] | |||||
| labels = [] | |||||
| for _, (data_orig, data_trans) in enumerate(zip(data, ds_random_sharpness_c)): | |||||
| image_orig, label_orig = data_orig | |||||
| image_trans, _ = data_trans | |||||
| images.append(image_orig) | |||||
| labels.append(label_orig) | |||||
| images_trans.append(image_trans) | |||||
| if plot: | |||||
| visualize_one_channel_dataset(images, images_trans, labels) | |||||
| def test_random_sharpness_invalid_params(): | |||||
| """ | |||||
| Test RandomSharpness with invalid input parameters. | |||||
| """ | |||||
| logger.info("Test RandomSharpness with invalid input parameters.") | |||||
| try: | |||||
| data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||||
| data = data.map(input_columns=["image"], | |||||
| operations=[C.Decode(), | |||||
| C.Resize((224, 224)), | |||||
| C.RandomSharpness(10)]) | |||||
| except TypeError as error: | |||||
| logger.info("Got an exception in DE: {}".format(str(error))) | |||||
| assert "tuple" in str(error) | |||||
| try: | |||||
| data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||||
| data = data.map(input_columns=["image"], | |||||
| operations=[C.Decode(), | |||||
| C.Resize((224, 224)), | |||||
| C.RandomSharpness((-10, 10))]) | |||||
| except ValueError as error: | |||||
| logger.info("Got an exception in DE: {}".format(str(error))) | |||||
| assert "interval" in str(error) | |||||
| try: | |||||
| data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||||
| data = data.map(input_columns=["image"], | |||||
| operations=[C.Decode(), | |||||
| C.Resize((224, 224)), | |||||
| C.RandomSharpness((10, 5))]) | |||||
| except ValueError as error: | |||||
| logger.info("Got an exception in DE: {}".format(str(error))) | |||||
| assert "(min,max)" in str(error) | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| test_random_sharpness() | |||||
| test_random_sharpness(plot=True) | |||||
| test_random_sharpness(degrees=(0.5, 1.5), plot=True) | |||||
| test_random_sharpness_md5() | |||||
| test_random_sharpness_py(plot=True) | |||||
| test_random_sharpness_py(None, plot=True) # test with default values | |||||
| test_random_sharpness_py_md5() | |||||
| test_random_sharpness_c(plot=True) | |||||
| test_random_sharpness_c(None, plot=True) # test with default values | |||||
| test_random_sharpness_c_md5() | |||||
| test_random_sharpness_c_py(degrees=[1.5, 1.5], plot=True) | |||||
| test_random_sharpness_c_py(degrees=[1, 1], plot=True) | |||||
| test_random_sharpness_c_py(degrees=[10, 10], plot=True) | |||||
| test_random_sharpness_one_channel_c(degrees=[1.7, 1.7], plot=True) | |||||
| test_random_sharpness_one_channel_c(degrees=None, plot=True) # test with default values | |||||
| test_random_sharpness_invalid_params() | |||||