Merge pull request !3255 from alashkari/cpp_opstags/v0.6.0-beta
| @@ -54,6 +54,7 @@ | |||||
| #include "minddata/dataset/kernels/image/center_crop_op.h" | #include "minddata/dataset/kernels/image/center_crop_op.h" | ||||
| #include "minddata/dataset/kernels/image/cut_out_op.h" | #include "minddata/dataset/kernels/image/cut_out_op.h" | ||||
| #include "minddata/dataset/kernels/image/decode_op.h" | #include "minddata/dataset/kernels/image/decode_op.h" | ||||
| #include "minddata/dataset/kernels/image/equalize_op.h" | |||||
| #include "minddata/dataset/kernels/image/hwc_to_chw_op.h" | #include "minddata/dataset/kernels/image/hwc_to_chw_op.h" | ||||
| #include "minddata/dataset/kernels/image/image_utils.h" | #include "minddata/dataset/kernels/image/image_utils.h" | ||||
| #include "minddata/dataset/kernels/image/invert_op.h" | #include "minddata/dataset/kernels/image/invert_op.h" | ||||
| @@ -389,6 +390,10 @@ void bindTensorOps1(py::module *m) { | |||||
| .def(py::init<float, float, float, float, float, float>(), py::arg("meanR"), py::arg("meanG"), py::arg("meanB"), | .def(py::init<float, float, float, float, float, float>(), py::arg("meanR"), py::arg("meanG"), py::arg("meanB"), | ||||
| py::arg("stdR"), py::arg("stdG"), py::arg("stdB")); | py::arg("stdR"), py::arg("stdG"), py::arg("stdB")); | ||||
| (void)py::class_<EqualizeOp, TensorOp, std::shared_ptr<EqualizeOp>>( | |||||
| *m, "EqualizeOp", "Tensor operation to apply histogram equalization on images.") | |||||
| .def(py::init<>()); | |||||
| (void)py::class_<InvertOp, TensorOp, std::shared_ptr<InvertOp>>(*m, "InvertOp", | (void)py::class_<InvertOp, TensorOp, std::shared_ptr<InvertOp>>(*m, "InvertOp", | ||||
| "Tensor operation to apply invert on RGB images.") | "Tensor operation to apply invert on RGB images.") | ||||
| .def(py::init<>()); | .def(py::init<>()); | ||||
| @@ -5,6 +5,7 @@ add_library(kernels-image OBJECT | |||||
| center_crop_op.cc | center_crop_op.cc | ||||
| cut_out_op.cc | cut_out_op.cc | ||||
| decode_op.cc | decode_op.cc | ||||
| equalize_op.cc | |||||
| hwc_to_chw_op.cc | hwc_to_chw_op.cc | ||||
| image_utils.cc | image_utils.cc | ||||
| invert_op.cc | invert_op.cc | ||||
| @@ -0,0 +1,29 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/kernels/image/equalize_op.h" | |||||
| #include "minddata/dataset/kernels/image/image_utils.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| // only supports RGB images | |||||
| Status EqualizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||||
| IO_CHECK(input, output); | |||||
| return Equalize(input, output); | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,45 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_EQUALIZE_OP_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_EQUALIZE_OP_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/core/tensor.h" | |||||
| #include "minddata/dataset/kernels/tensor_op.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| class EqualizeOp : public TensorOp { | |||||
| public: | |||||
| EqualizeOp() {} | |||||
| ~EqualizeOp() = default; | |||||
| // Description: A function that prints info about the node | |||||
| void Print(std::ostream &out) const override { out << Name(); } | |||||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | |||||
| std::string Name() const override { return kEqualizeOp; } | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_EQUALIZE_OP_H_ | |||||
| @@ -749,6 +749,46 @@ Status AdjustHue(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> * | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Equalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||||
| try { | |||||
| std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input); | |||||
| if (!input_cv->mat().data) { | |||||
| RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); | |||||
| } | |||||
| if (input_cv->Rank() != 3 && input_cv->Rank() != 2) { | |||||
| RETURN_STATUS_UNEXPECTED("Shape not <H,W,C> or <H,W>"); | |||||
| } | |||||
| // For greyscale images, extend dimension if rank is 2 and reshape output to be of rank 2. | |||||
| if (input_cv->Rank() == 2) { | |||||
| RETURN_IF_NOT_OK(input_cv->ExpandDim(2)); | |||||
| } | |||||
| // Get number of channels and image matrix | |||||
| std::size_t num_of_channels = input_cv->shape()[2]; | |||||
| if (num_of_channels != 1 && num_of_channels != 3) { | |||||
| RETURN_STATUS_UNEXPECTED("Number of channels is not 1 or 3."); | |||||
| } | |||||
| cv::Mat image = input_cv->mat(); | |||||
| // Separate the image to channels | |||||
| std::vector<cv::Mat> planes(num_of_channels); | |||||
| cv::split(image, planes); | |||||
| // Equalize each channel separately | |||||
| std::vector<cv::Mat> image_result; | |||||
| for (std::size_t layer = 0; layer < planes.size(); layer++) { | |||||
| cv::Mat channel_result; | |||||
| cv::equalizeHist(planes[layer], channel_result); | |||||
| image_result.push_back(channel_result); | |||||
| } | |||||
| cv::Mat result; | |||||
| cv::merge(image_result, result); | |||||
| std::shared_ptr<CVTensor> output_cv = std::make_shared<CVTensor>(result); | |||||
| if (input_cv->Rank() == 2) output_cv->Squeeze(); | |||||
| (*output) = std::static_pointer_cast<Tensor>(output_cv); | |||||
| } catch (const cv::Exception &e) { | |||||
| RETURN_STATUS_UNEXPECTED("Error in equalize."); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status Erase(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t box_height, | Status Erase(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t box_height, | ||||
| int32_t box_width, int32_t num_patches, bool bounded, bool random_color, std::mt19937 *rnd, uint8_t fill_r, | int32_t box_width, int32_t num_patches, bool bounded, bool random_color, std::mt19937 *rnd, uint8_t fill_r, | ||||
| uint8_t fill_g, uint8_t fill_b) { | uint8_t fill_g, uint8_t fill_b) { | ||||
| @@ -200,6 +200,12 @@ Status AdjustSaturation(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te | |||||
| // @param output: Adjusted image of same shape and type. | // @param output: Adjusted image of same shape and type. | ||||
| Status AdjustHue(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &hue); | Status AdjustHue(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &hue); | ||||
| /// \brief Returns image with equalized histogram. | |||||
| /// \param[in] input: Tensor of shape <H,W,3>/<H,W,1>/<H,W> in RGB/Grayscale and | |||||
| /// any OpenCv compatible type, see CVTensor. | |||||
| /// \param[out] output: Equalized image of same shape and type. | |||||
| Status Equalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output); | |||||
| // Masks out a random section from the image with set dimension | // Masks out a random section from the image with set dimension | ||||
| // @param input: input Tensor | // @param input: input Tensor | ||||
| // @param output: cutOut Tensor | // @param output: cutOut Tensor | ||||
| @@ -92,6 +92,7 @@ constexpr char kBoundingBoxAugmentOp[] = "BoundingBoxAugmentOp"; | |||||
| constexpr char kDecodeOp[] = "DecodeOp"; | constexpr char kDecodeOp[] = "DecodeOp"; | ||||
| constexpr char kCenterCropOp[] = "CenterCropOp"; | constexpr char kCenterCropOp[] = "CenterCropOp"; | ||||
| constexpr char kCutOutOp[] = "CutOutOp"; | constexpr char kCutOutOp[] = "CutOutOp"; | ||||
| constexpr char kEqualizeOp[] = "EqualizeOp"; | |||||
| constexpr char kHwcToChwOp[] = "HwcToChwOp"; | constexpr char kHwcToChwOp[] = "HwcToChwOp"; | ||||
| constexpr char kInvertOp[] = "InvertOp"; | constexpr char kInvertOp[] = "InvertOp"; | ||||
| constexpr char kNormalizeOp[] = "NormalizeOp"; | constexpr char kNormalizeOp[] = "NormalizeOp"; | ||||
| @@ -89,6 +89,13 @@ class AutoContrast(cde.AutoContrastOp): | |||||
| super().__init__(cutoff, ignore) | super().__init__(cutoff, ignore) | ||||
| class Equalize(cde.EqualizeOp): | |||||
| """ | |||||
| Apply histogram equalization on input image. | |||||
| does not have input arguments. | |||||
| """ | |||||
| class Invert(cde.InvertOp): | class Invert(cde.InvertOp): | ||||
| """ | """ | ||||
| Apply invert on input image in RGB mode. | Apply invert on input image in RGB mode. | ||||
| @@ -18,6 +18,7 @@ Testing Equalize op in DE | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore.dataset.engine as de | import mindspore.dataset.engine as de | ||||
| import mindspore.dataset.transforms.vision.c_transforms as C | |||||
| import mindspore.dataset.transforms.vision.py_transforms as F | import mindspore.dataset.transforms.vision.py_transforms as F | ||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| from util import visualize_list, diff_mse, save_and_check_md5 | from util import visualize_list, diff_mse, save_and_check_md5 | ||||
| @@ -26,9 +27,9 @@ DATA_DIR = "../data/dataset/testImageNetData/train/" | |||||
| GENERATE_GOLDEN = False | GENERATE_GOLDEN = False | ||||
| def test_equalize(plot=False): | |||||
| def test_equalize_py(plot=False): | |||||
| """ | """ | ||||
| Test Equalize | |||||
| Test Equalize py op | |||||
| """ | """ | ||||
| logger.info("Test Equalize") | logger.info("Test Equalize") | ||||
| @@ -83,9 +84,141 @@ def test_equalize(plot=False): | |||||
| visualize_list(images_original, images_equalize) | visualize_list(images_original, images_equalize) | ||||
| def test_equalize_md5(): | |||||
| def test_equalize_c(plot=False): | |||||
| """ | """ | ||||
| Test Equalize with md5 check | |||||
| Test Equalize Cpp op | |||||
| """ | |||||
| logger.info("Test Equalize cpp op") | |||||
| # Original Images | |||||
| ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||||
| transforms_original = [C.Decode(), C.Resize(size=[224, 224])] | |||||
| ds_original = ds.map(input_columns="image", | |||||
| operations=transforms_original) | |||||
| ds_original = ds_original.batch(512) | |||||
| for idx, (image, _) in enumerate(ds_original): | |||||
| if idx == 0: | |||||
| images_original = image | |||||
| else: | |||||
| images_original = np.append(images_original, | |||||
| image, | |||||
| axis=0) | |||||
| # Equalize Images | |||||
| ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||||
| transform_equalize = [C.Decode(), C.Resize(size=[224, 224]), | |||||
| C.Equalize()] | |||||
| ds_equalize = ds.map(input_columns="image", | |||||
| operations=transform_equalize) | |||||
| ds_equalize = ds_equalize.batch(512) | |||||
| for idx, (image, _) in enumerate(ds_equalize): | |||||
| if idx == 0: | |||||
| images_equalize = image | |||||
| else: | |||||
| images_equalize = np.append(images_equalize, | |||||
| image, | |||||
| axis=0) | |||||
| if plot: | |||||
| visualize_list(images_original, images_equalize) | |||||
| num_samples = images_original.shape[0] | |||||
| mse = np.zeros(num_samples) | |||||
| for i in range(num_samples): | |||||
| mse[i] = diff_mse(images_equalize[i], images_original[i]) | |||||
| logger.info("MSE= {}".format(str(np.mean(mse)))) | |||||
| def test_equalize_py_c(plot=False): | |||||
| """ | |||||
| Test Equalize Cpp op and python op | |||||
| """ | |||||
| logger.info("Test Equalize cpp and python op") | |||||
| # equalize Images in cpp | |||||
| ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||||
| ds = ds.map(input_columns=["image"], | |||||
| operations=[C.Decode(), C.Resize((224, 224))]) | |||||
| ds_c_equalize = ds.map(input_columns="image", | |||||
| operations=C.Equalize()) | |||||
| ds_c_equalize = ds_c_equalize.batch(512) | |||||
| for idx, (image, _) in enumerate(ds_c_equalize): | |||||
| if idx == 0: | |||||
| images_c_equalize = image | |||||
| else: | |||||
| images_c_equalize = np.append(images_c_equalize, | |||||
| image, | |||||
| axis=0) | |||||
| # Equalize images in python | |||||
| ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||||
| ds = ds.map(input_columns=["image"], | |||||
| operations=[C.Decode(), C.Resize((224, 224))]) | |||||
| transforms_p_equalize = F.ComposeOp([lambda img: img.astype(np.uint8), | |||||
| F.ToPIL(), | |||||
| F.Equalize(), | |||||
| np.array]) | |||||
| ds_p_equalize = ds.map(input_columns="image", | |||||
| operations=transforms_p_equalize()) | |||||
| ds_p_equalize = ds_p_equalize.batch(512) | |||||
| for idx, (image, _) in enumerate(ds_p_equalize): | |||||
| if idx == 0: | |||||
| images_p_equalize = image | |||||
| else: | |||||
| images_p_equalize = np.append(images_p_equalize, | |||||
| image, | |||||
| axis=0) | |||||
| num_samples = images_c_equalize.shape[0] | |||||
| mse = np.zeros(num_samples) | |||||
| for i in range(num_samples): | |||||
| mse[i] = diff_mse(images_p_equalize[i], images_c_equalize[i]) | |||||
| logger.info("MSE= {}".format(str(np.mean(mse)))) | |||||
| if plot: | |||||
| visualize_list(images_c_equalize, images_p_equalize, visualize_mode=2) | |||||
| def test_equalize_one_channel(): | |||||
| """ | |||||
| Test Equalize cpp op with one channel image | |||||
| """ | |||||
| logger.info("Test Equalize C Op With One Channel Images") | |||||
| c_op = C.Equalize() | |||||
| try: | |||||
| ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||||
| ds = ds.map(input_columns=["image"], | |||||
| operations=[C.Decode(), | |||||
| C.Resize((224, 224)), | |||||
| lambda img: np.array(img[:, :, 0])]) | |||||
| ds.map(input_columns="image", | |||||
| operations=c_op) | |||||
| except RuntimeError as e: | |||||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||||
| assert "The shape" in str(e) | |||||
| def test_equalize_md5_py(): | |||||
| """ | |||||
| Test Equalize py op with md5 check | |||||
| """ | """ | ||||
| logger.info("Test Equalize") | logger.info("Test Equalize") | ||||
| @@ -101,6 +234,31 @@ def test_equalize_md5(): | |||||
| save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN) | save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN) | ||||
| def test_equalize_md5_c(): | |||||
| """ | |||||
| Test Equalize cpp op with md5 check | |||||
| """ | |||||
| logger.info("Test Equalize cpp op with md5 check") | |||||
| # Generate dataset | |||||
| ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||||
| transforms_equalize = [C.Decode(), | |||||
| C.Resize(size=[224, 224]), | |||||
| C.Equalize(), | |||||
| F.ToTensor()] | |||||
| data = ds.map(input_columns="image", operations=transforms_equalize) | |||||
| # Compare with expected md5 from images | |||||
| filename = "equalize_01_result_c.npz" | |||||
| save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN) | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| test_equalize(plot=True) | |||||
| test_equalize_md5() | |||||
| test_equalize_py(plot=False) | |||||
| test_equalize_c(plot=False) | |||||
| test_equalize_py_c(plot=False) | |||||
| test_equalize_one_channel() | |||||
| test_equalize_md5_py() | |||||
| test_equalize_md5_c() | |||||