Merge pull request !3255 from alashkari/cpp_opstags/v0.6.0-beta
| @@ -54,6 +54,7 @@ | |||
| #include "minddata/dataset/kernels/image/center_crop_op.h" | |||
| #include "minddata/dataset/kernels/image/cut_out_op.h" | |||
| #include "minddata/dataset/kernels/image/decode_op.h" | |||
| #include "minddata/dataset/kernels/image/equalize_op.h" | |||
| #include "minddata/dataset/kernels/image/hwc_to_chw_op.h" | |||
| #include "minddata/dataset/kernels/image/image_utils.h" | |||
| #include "minddata/dataset/kernels/image/invert_op.h" | |||
| @@ -389,6 +390,10 @@ void bindTensorOps1(py::module *m) { | |||
| .def(py::init<float, float, float, float, float, float>(), py::arg("meanR"), py::arg("meanG"), py::arg("meanB"), | |||
| py::arg("stdR"), py::arg("stdG"), py::arg("stdB")); | |||
| (void)py::class_<EqualizeOp, TensorOp, std::shared_ptr<EqualizeOp>>( | |||
| *m, "EqualizeOp", "Tensor operation to apply histogram equalization on images.") | |||
| .def(py::init<>()); | |||
| (void)py::class_<InvertOp, TensorOp, std::shared_ptr<InvertOp>>(*m, "InvertOp", | |||
| "Tensor operation to apply invert on RGB images.") | |||
| .def(py::init<>()); | |||
| @@ -5,6 +5,7 @@ add_library(kernels-image OBJECT | |||
| center_crop_op.cc | |||
| cut_out_op.cc | |||
| decode_op.cc | |||
| equalize_op.cc | |||
| hwc_to_chw_op.cc | |||
| image_utils.cc | |||
| invert_op.cc | |||
| @@ -0,0 +1,29 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/kernels/image/equalize_op.h" | |||
| #include "minddata/dataset/kernels/image/image_utils.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // only supports RGB images | |||
| Status EqualizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||
| IO_CHECK(input, output); | |||
| return Equalize(input, output); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,45 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_EQUALIZE_OP_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_EQUALIZE_OP_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "minddata/dataset/core/tensor.h" | |||
| #include "minddata/dataset/kernels/tensor_op.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class EqualizeOp : public TensorOp { | |||
| public: | |||
| EqualizeOp() {} | |||
| ~EqualizeOp() = default; | |||
| // Description: A function that prints info about the node | |||
| void Print(std::ostream &out) const override { out << Name(); } | |||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | |||
| std::string Name() const override { return kEqualizeOp; } | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_EQUALIZE_OP_H_ | |||
| @@ -749,6 +749,46 @@ Status AdjustHue(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> * | |||
| return Status::OK(); | |||
| } | |||
| Status Equalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||
| try { | |||
| std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input); | |||
| if (!input_cv->mat().data) { | |||
| RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); | |||
| } | |||
| if (input_cv->Rank() != 3 && input_cv->Rank() != 2) { | |||
| RETURN_STATUS_UNEXPECTED("Shape not <H,W,C> or <H,W>"); | |||
| } | |||
| // For greyscale images, extend dimension if rank is 2 and reshape output to be of rank 2. | |||
| if (input_cv->Rank() == 2) { | |||
| RETURN_IF_NOT_OK(input_cv->ExpandDim(2)); | |||
| } | |||
| // Get number of channels and image matrix | |||
| std::size_t num_of_channels = input_cv->shape()[2]; | |||
| if (num_of_channels != 1 && num_of_channels != 3) { | |||
| RETURN_STATUS_UNEXPECTED("Number of channels is not 1 or 3."); | |||
| } | |||
| cv::Mat image = input_cv->mat(); | |||
| // Separate the image to channels | |||
| std::vector<cv::Mat> planes(num_of_channels); | |||
| cv::split(image, planes); | |||
| // Equalize each channel separately | |||
| std::vector<cv::Mat> image_result; | |||
| for (std::size_t layer = 0; layer < planes.size(); layer++) { | |||
| cv::Mat channel_result; | |||
| cv::equalizeHist(planes[layer], channel_result); | |||
| image_result.push_back(channel_result); | |||
| } | |||
| cv::Mat result; | |||
| cv::merge(image_result, result); | |||
| std::shared_ptr<CVTensor> output_cv = std::make_shared<CVTensor>(result); | |||
| if (input_cv->Rank() == 2) output_cv->Squeeze(); | |||
| (*output) = std::static_pointer_cast<Tensor>(output_cv); | |||
| } catch (const cv::Exception &e) { | |||
| RETURN_STATUS_UNEXPECTED("Error in equalize."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status Erase(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t box_height, | |||
| int32_t box_width, int32_t num_patches, bool bounded, bool random_color, std::mt19937 *rnd, uint8_t fill_r, | |||
| uint8_t fill_g, uint8_t fill_b) { | |||
| @@ -200,6 +200,12 @@ Status AdjustSaturation(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te | |||
| // @param output: Adjusted image of same shape and type. | |||
| Status AdjustHue(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &hue); | |||
| /// \brief Returns image with equalized histogram. | |||
| /// \param[in] input: Tensor of shape <H,W,3>/<H,W,1>/<H,W> in RGB/Grayscale and | |||
| /// any OpenCv compatible type, see CVTensor. | |||
| /// \param[out] output: Equalized image of same shape and type. | |||
| Status Equalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output); | |||
| // Masks out a random section from the image with set dimension | |||
| // @param input: input Tensor | |||
| // @param output: cutOut Tensor | |||
| @@ -92,6 +92,7 @@ constexpr char kBoundingBoxAugmentOp[] = "BoundingBoxAugmentOp"; | |||
| constexpr char kDecodeOp[] = "DecodeOp"; | |||
| constexpr char kCenterCropOp[] = "CenterCropOp"; | |||
| constexpr char kCutOutOp[] = "CutOutOp"; | |||
| constexpr char kEqualizeOp[] = "EqualizeOp"; | |||
| constexpr char kHwcToChwOp[] = "HwcToChwOp"; | |||
| constexpr char kInvertOp[] = "InvertOp"; | |||
| constexpr char kNormalizeOp[] = "NormalizeOp"; | |||
| @@ -89,6 +89,13 @@ class AutoContrast(cde.AutoContrastOp): | |||
| super().__init__(cutoff, ignore) | |||
| class Equalize(cde.EqualizeOp): | |||
| """ | |||
| Apply histogram equalization on input image. | |||
| does not have input arguments. | |||
| """ | |||
| class Invert(cde.InvertOp): | |||
| """ | |||
| Apply invert on input image in RGB mode. | |||
| @@ -18,6 +18,7 @@ Testing Equalize op in DE | |||
| import numpy as np | |||
| import mindspore.dataset.engine as de | |||
| import mindspore.dataset.transforms.vision.c_transforms as C | |||
| import mindspore.dataset.transforms.vision.py_transforms as F | |||
| from mindspore import log as logger | |||
| from util import visualize_list, diff_mse, save_and_check_md5 | |||
| @@ -26,9 +27,9 @@ DATA_DIR = "../data/dataset/testImageNetData/train/" | |||
| GENERATE_GOLDEN = False | |||
| def test_equalize(plot=False): | |||
| def test_equalize_py(plot=False): | |||
| """ | |||
| Test Equalize | |||
| Test Equalize py op | |||
| """ | |||
| logger.info("Test Equalize") | |||
| @@ -83,9 +84,141 @@ def test_equalize(plot=False): | |||
| visualize_list(images_original, images_equalize) | |||
| def test_equalize_md5(): | |||
| def test_equalize_c(plot=False): | |||
| """ | |||
| Test Equalize with md5 check | |||
| Test Equalize Cpp op | |||
| """ | |||
| logger.info("Test Equalize cpp op") | |||
| # Original Images | |||
| ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||
| transforms_original = [C.Decode(), C.Resize(size=[224, 224])] | |||
| ds_original = ds.map(input_columns="image", | |||
| operations=transforms_original) | |||
| ds_original = ds_original.batch(512) | |||
| for idx, (image, _) in enumerate(ds_original): | |||
| if idx == 0: | |||
| images_original = image | |||
| else: | |||
| images_original = np.append(images_original, | |||
| image, | |||
| axis=0) | |||
| # Equalize Images | |||
| ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||
| transform_equalize = [C.Decode(), C.Resize(size=[224, 224]), | |||
| C.Equalize()] | |||
| ds_equalize = ds.map(input_columns="image", | |||
| operations=transform_equalize) | |||
| ds_equalize = ds_equalize.batch(512) | |||
| for idx, (image, _) in enumerate(ds_equalize): | |||
| if idx == 0: | |||
| images_equalize = image | |||
| else: | |||
| images_equalize = np.append(images_equalize, | |||
| image, | |||
| axis=0) | |||
| if plot: | |||
| visualize_list(images_original, images_equalize) | |||
| num_samples = images_original.shape[0] | |||
| mse = np.zeros(num_samples) | |||
| for i in range(num_samples): | |||
| mse[i] = diff_mse(images_equalize[i], images_original[i]) | |||
| logger.info("MSE= {}".format(str(np.mean(mse)))) | |||
| def test_equalize_py_c(plot=False): | |||
| """ | |||
| Test Equalize Cpp op and python op | |||
| """ | |||
| logger.info("Test Equalize cpp and python op") | |||
| # equalize Images in cpp | |||
| ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||
| ds = ds.map(input_columns=["image"], | |||
| operations=[C.Decode(), C.Resize((224, 224))]) | |||
| ds_c_equalize = ds.map(input_columns="image", | |||
| operations=C.Equalize()) | |||
| ds_c_equalize = ds_c_equalize.batch(512) | |||
| for idx, (image, _) in enumerate(ds_c_equalize): | |||
| if idx == 0: | |||
| images_c_equalize = image | |||
| else: | |||
| images_c_equalize = np.append(images_c_equalize, | |||
| image, | |||
| axis=0) | |||
| # Equalize images in python | |||
| ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||
| ds = ds.map(input_columns=["image"], | |||
| operations=[C.Decode(), C.Resize((224, 224))]) | |||
| transforms_p_equalize = F.ComposeOp([lambda img: img.astype(np.uint8), | |||
| F.ToPIL(), | |||
| F.Equalize(), | |||
| np.array]) | |||
| ds_p_equalize = ds.map(input_columns="image", | |||
| operations=transforms_p_equalize()) | |||
| ds_p_equalize = ds_p_equalize.batch(512) | |||
| for idx, (image, _) in enumerate(ds_p_equalize): | |||
| if idx == 0: | |||
| images_p_equalize = image | |||
| else: | |||
| images_p_equalize = np.append(images_p_equalize, | |||
| image, | |||
| axis=0) | |||
| num_samples = images_c_equalize.shape[0] | |||
| mse = np.zeros(num_samples) | |||
| for i in range(num_samples): | |||
| mse[i] = diff_mse(images_p_equalize[i], images_c_equalize[i]) | |||
| logger.info("MSE= {}".format(str(np.mean(mse)))) | |||
| if plot: | |||
| visualize_list(images_c_equalize, images_p_equalize, visualize_mode=2) | |||
| def test_equalize_one_channel(): | |||
| """ | |||
| Test Equalize cpp op with one channel image | |||
| """ | |||
| logger.info("Test Equalize C Op With One Channel Images") | |||
| c_op = C.Equalize() | |||
| try: | |||
| ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||
| ds = ds.map(input_columns=["image"], | |||
| operations=[C.Decode(), | |||
| C.Resize((224, 224)), | |||
| lambda img: np.array(img[:, :, 0])]) | |||
| ds.map(input_columns="image", | |||
| operations=c_op) | |||
| except RuntimeError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "The shape" in str(e) | |||
| def test_equalize_md5_py(): | |||
| """ | |||
| Test Equalize py op with md5 check | |||
| """ | |||
| logger.info("Test Equalize") | |||
| @@ -101,6 +234,31 @@ def test_equalize_md5(): | |||
| save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN) | |||
| def test_equalize_md5_c(): | |||
| """ | |||
| Test Equalize cpp op with md5 check | |||
| """ | |||
| logger.info("Test Equalize cpp op with md5 check") | |||
| # Generate dataset | |||
| ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||
| transforms_equalize = [C.Decode(), | |||
| C.Resize(size=[224, 224]), | |||
| C.Equalize(), | |||
| F.ToTensor()] | |||
| data = ds.map(input_columns="image", operations=transforms_equalize) | |||
| # Compare with expected md5 from images | |||
| filename = "equalize_01_result_c.npz" | |||
| save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN) | |||
| if __name__ == "__main__": | |||
| test_equalize(plot=True) | |||
| test_equalize_md5() | |||
| test_equalize_py(plot=False) | |||
| test_equalize_c(plot=False) | |||
| test_equalize_py_c(plot=False) | |||
| test_equalize_one_channel() | |||
| test_equalize_md5_py() | |||
| test_equalize_md5_c() | |||