Merge pull request !3077 from ava/inverttags/v0.6.0-beta
| @@ -54,6 +54,7 @@ | |||||
| #include "minddata/dataset/kernels/image/decode_op.h" | #include "minddata/dataset/kernels/image/decode_op.h" | ||||
| #include "minddata/dataset/kernels/image/hwc_to_chw_op.h" | #include "minddata/dataset/kernels/image/hwc_to_chw_op.h" | ||||
| #include "minddata/dataset/kernels/image/image_utils.h" | #include "minddata/dataset/kernels/image/image_utils.h" | ||||
| #include "minddata/dataset/kernels/image/invert_op.h" | |||||
| #include "minddata/dataset/kernels/image/normalize_op.h" | #include "minddata/dataset/kernels/image/normalize_op.h" | ||||
| #include "minddata/dataset/kernels/image/pad_op.h" | #include "minddata/dataset/kernels/image/pad_op.h" | ||||
| #include "minddata/dataset/kernels/image/random_color_adjust_op.h" | #include "minddata/dataset/kernels/image/random_color_adjust_op.h" | ||||
| @@ -366,6 +367,10 @@ void bindTensorOps1(py::module *m) { | |||||
| .def(py::init<float, float, float, float, float, float>(), py::arg("meanR"), py::arg("meanG"), py::arg("meanB"), | .def(py::init<float, float, float, float, float, float>(), py::arg("meanR"), py::arg("meanG"), py::arg("meanB"), | ||||
| py::arg("stdR"), py::arg("stdG"), py::arg("stdB")); | py::arg("stdR"), py::arg("stdG"), py::arg("stdB")); | ||||
| (void)py::class_<InvertOp, TensorOp, std::shared_ptr<InvertOp>>(*m, "InvertOp", | |||||
| "Tensor operation to apply invert on RGB images.") | |||||
| .def(py::init<>()); | |||||
| (void)py::class_<RescaleOp, TensorOp, std::shared_ptr<RescaleOp>>( | (void)py::class_<RescaleOp, TensorOp, std::shared_ptr<RescaleOp>>( | ||||
| *m, "RescaleOp", "Tensor operation to rescale an image. Takes scale and shift.") | *m, "RescaleOp", "Tensor operation to rescale an image. Takes scale and shift.") | ||||
| .def(py::init<float, float>(), py::arg("rescale"), py::arg("shift")); | .def(py::init<float, float>(), py::arg("rescale"), py::arg("shift")); | ||||
| @@ -6,6 +6,7 @@ add_library(kernels-image OBJECT | |||||
| decode_op.cc | decode_op.cc | ||||
| hwc_to_chw_op.cc | hwc_to_chw_op.cc | ||||
| image_utils.cc | image_utils.cc | ||||
| invert_op.cc | |||||
| normalize_op.cc | normalize_op.cc | ||||
| pad_op.cc | pad_op.cc | ||||
| random_color_adjust_op.cc | random_color_adjust_op.cc | ||||
| @@ -0,0 +1,57 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/kernels/image/invert_op.h" | |||||
| #include "minddata/dataset/kernels/image/image_utils.h" | |||||
| #include "minddata/dataset/core/cv_tensor.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| // only supports RGB images | |||||
| Status InvertOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||||
| IO_CHECK(input, output); | |||||
| try { | |||||
| std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input); | |||||
| cv::Mat input_img = input_cv->mat(); | |||||
| if (!input_cv->mat().data) { | |||||
| RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); | |||||
| } | |||||
| if (input_cv->Rank() != 3) { | |||||
| RETURN_STATUS_UNEXPECTED("Shape not <H,W,C>"); | |||||
| } | |||||
| int num_channels = input_cv->shape()[2]; | |||||
| if (num_channels != 3) { | |||||
| RETURN_STATUS_UNEXPECTED("The shape is incorrect: num of channels != 3"); | |||||
| } | |||||
| auto output_cv = std::make_shared<CVTensor>(input_cv->shape(), input_cv->type()); | |||||
| RETURN_UNEXPECTED_IF_NULL(output_cv); | |||||
| output_cv->mat() = cv::Scalar::all(255) - input_img; | |||||
| *output = std::static_pointer_cast<Tensor>(output_cv); | |||||
| } | |||||
| catch (const cv::Exception &e) { | |||||
| RETURN_STATUS_UNEXPECTED("Error in invert"); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,44 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef DATASET_KERNELS_IMAGE_INVERT_OP_H | |||||
| #define DATASET_KERNELS_IMAGE_INVERT_OP_H | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "minddata/dataset/core/tensor.h" | |||||
| #include "minddata/dataset/kernels/tensor_op.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| class InvertOp : public TensorOp { | |||||
| public: | |||||
| InvertOp() {} | |||||
| ~InvertOp() = default; | |||||
| // Description: A function that prints info about the node | |||||
| void Print(std::ostream &out) const override { out << Name(); } | |||||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | |||||
| std::string Name() const override { return kInvertOp; } | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // DATASET_KERNELS_IMAGE_INVERT_OP_H | |||||
| @@ -92,6 +92,7 @@ constexpr char kDecodeOp[] = "DecodeOp"; | |||||
| constexpr char kCenterCropOp[] = "CenterCropOp"; | constexpr char kCenterCropOp[] = "CenterCropOp"; | ||||
| constexpr char kCutOutOp[] = "CutOutOp"; | constexpr char kCutOutOp[] = "CutOutOp"; | ||||
| constexpr char kHwcToChwOp[] = "HwcToChwOp"; | constexpr char kHwcToChwOp[] = "HwcToChwOp"; | ||||
| constexpr char kInvertOp[] = "InvertOp"; | |||||
| constexpr char kNormalizeOp[] = "NormalizeOp"; | constexpr char kNormalizeOp[] = "NormalizeOp"; | ||||
| constexpr char kPadOp[] = "PadOp"; | constexpr char kPadOp[] = "PadOp"; | ||||
| constexpr char kRandomColorAdjustOp[] = "RandomColorAdjustOp"; | constexpr char kRandomColorAdjustOp[] = "RandomColorAdjustOp"; | ||||
| @@ -71,6 +71,13 @@ def parse_padding(padding): | |||||
| return padding | return padding | ||||
class Invert(cde.InvertOp):
    """
    Invert the colors of the input image, which must be in RGB mode.

    This operation takes no arguments.
    """
| class Decode(cde.DecodeOp): | class Decode(cde.DecodeOp): | ||||
| """ | """ | ||||
| Decode the input image in RGB mode. | Decode the input image in RGB mode. | ||||
| @@ -19,18 +19,20 @@ import numpy as np | |||||
| import mindspore.dataset.engine as de | import mindspore.dataset.engine as de | ||||
| import mindspore.dataset.transforms.vision.py_transforms as F | import mindspore.dataset.transforms.vision.py_transforms as F | ||||
| import mindspore.dataset.transforms.vision.c_transforms as C | |||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| from util import visualize_list, save_and_check_md5 | |||||
| from util import visualize_list, save_and_check_md5, diff_mse | |||||
| DATA_DIR = "../data/dataset/testImageNetData/train/" | DATA_DIR = "../data/dataset/testImageNetData/train/" | ||||
| GENERATE_GOLDEN = False | GENERATE_GOLDEN = False | ||||
| def test_invert(plot=False): | |||||
| def test_invert_py(plot=False): | |||||
| """ | """ | ||||
| Test Invert | |||||
| Test Invert python op | |||||
| """ | """ | ||||
| logger.info("Test Invert") | |||||
| logger.info("Test Invert Python op") | |||||
| # Original Images | # Original Images | ||||
| ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | ||||
| @@ -52,7 +54,7 @@ def test_invert(plot=False): | |||||
| np.transpose(image, (0, 2, 3, 1)), | np.transpose(image, (0, 2, 3, 1)), | ||||
| axis=0) | axis=0) | ||||
| # Color Inverted Images | |||||
| # Color Inverted Images | |||||
| ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | ||||
| transforms_invert = F.ComposeOp([F.Decode(), | transforms_invert = F.ComposeOp([F.Decode(), | ||||
| @@ -83,11 +85,143 @@ def test_invert(plot=False): | |||||
| visualize_list(images_original, images_invert) | visualize_list(images_original, images_invert) | ||||
| def test_invert_md5(): | |||||
def test_invert_c(plot=False):
    """
    Test Invert Cpp op.

    Runs the same Decode + Resize pipeline twice, the second time with
    C.Invert appended, and checks that every inverted value equals 255 minus
    the corresponding original value.

    Args:
        plot (bool): When True, pass both image sets to visualize_list.
    """
    logger.info("Test Invert cpp op")

    def collect_images(dataset):
        # Gather every batch first and join them once; appending to a growing
        # array inside the loop re-copies the accumulated images each time.
        # An empty dataset fails here with a clear error from np.concatenate.
        batches = [image for image, _ in dataset]
        return np.concatenate(batches, axis=0)

    # Original Images
    ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)

    transforms_original = [C.Decode(), C.Resize(size=[224, 224])]

    ds_original = ds.map(input_columns="image",
                         operations=transforms_original)
    ds_original = ds_original.batch(512)

    images_original = collect_images(ds_original)

    # Invert Images
    ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)

    transform_invert = [C.Decode(), C.Resize(size=[224, 224]),
                        C.Invert()]

    ds_invert = ds.map(input_columns="image",
                       operations=transform_invert)
    ds_invert = ds_invert.batch(512)

    images_invert = collect_images(ds_invert)

    if plot:
        visualize_list(images_original, images_invert)

    num_samples = images_original.shape[0]
    mse = np.zeros(num_samples)
    for i in range(num_samples):
        mse[i] = diff_mse(images_invert[i], images_original[i])
    logger.info("MSE= {}".format(str(np.mean(mse))))

    # The logged MSE alone cannot fail the test, so pin the actual contract:
    # the op computes 255 - value for every element.
    assert np.array_equal(images_invert, 255 - images_original)
def test_invert_py_c(plot=False):
    """
    Test Invert Cpp op and python op.

    Feeds identically decoded and resized images through C.Invert and through
    the Python F.Invert, then checks that both produce the same images.

    Args:
        plot (bool): When True, pass both image sets to visualize_list.
    """
    logger.info("Test Invert cpp and python op")

    def collect_images(dataset):
        # Gather every batch first and join them once; appending to a growing
        # array inside the loop re-copies the accumulated images each time.
        # An empty dataset fails here with a clear error from np.concatenate.
        batches = [image for image, _ in dataset]
        return np.concatenate(batches, axis=0)

    # Invert Images in cpp
    ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
    ds = ds.map(input_columns=["image"],
                operations=[C.Decode(), C.Resize((224, 224))])

    ds_c_invert = ds.map(input_columns="image",
                         operations=C.Invert())
    ds_c_invert = ds_c_invert.batch(512)

    images_c_invert = collect_images(ds_c_invert)

    # invert images in python
    ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
    ds = ds.map(input_columns=["image"],
                operations=[C.Decode(), C.Resize((224, 224))])

    # The Python op works on PIL images, so convert to PIL and back to numpy.
    transforms_p_invert = F.ComposeOp([lambda img: img.astype(np.uint8),
                                       F.ToPIL(),
                                       F.Invert(),
                                       np.array])

    ds_p_invert = ds.map(input_columns="image",
                         operations=transforms_p_invert())
    ds_p_invert = ds_p_invert.batch(512)

    images_p_invert = collect_images(ds_p_invert)

    num_samples = images_c_invert.shape[0]
    mse = np.zeros(num_samples)
    for i in range(num_samples):
        mse[i] = diff_mse(images_p_invert[i], images_c_invert[i])
    logger.info("MSE= {}".format(str(np.mean(mse))))

    if plot:
        visualize_list(images_c_invert, images_p_invert, visualize_mode=2)

    # Logging the MSE cannot fail the test; both implementations are expected
    # to compute the same inversion, so require identical output.
    assert np.array_equal(images_c_invert, images_p_invert)
def test_invert_one_channel():
    """
    Test Invert cpp op with one channel image.

    The C++ op rejects images whose channel count is not 3, so running it on a
    single-channel image is expected to raise a RuntimeError whose message
    reports the incorrect shape.
    """
    logger.info("Test Invert C Op With One Channel Images")

    c_op = C.Invert()

    try:
        ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
        # Slice (rather than index) the channel axis so the image stays
        # <H, W, 1>: a rank-2 array would be rejected by the rank check
        # instead of the channel-count check this test is about.
        ds = ds.map(input_columns=["image"],
                    operations=[C.Decode(),
                                C.Resize((224, 224)),
                                lambda img: np.array(img[:, :, :1])])

        ds = ds.map(input_columns="image",
                    operations=c_op)

        # NOTE(review): map() appears to only build the pipeline; the ops run
        # when the dataset is iterated, so drain it to actually execute Invert.
        for _ in ds:
            pass

    except RuntimeError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "The shape" in str(e)
    else:
        # Without this branch the test would pass even if nothing was raised.
        raise AssertionError("C.Invert accepted a one channel image without raising RuntimeError")
| def test_invert_md5_py(): | |||||
| """ | """ | ||||
| logger.info("Test Invert with md5 check") | |||||
| Test Invert python op with md5 check | |||||
| """ | |||||
| logger.info("Test Invert python op with md5 check") | |||||
| # Generate dataset | # Generate dataset | ||||
| ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | ||||
| @@ -98,10 +232,34 @@ def test_invert_md5(): | |||||
| data = ds.map(input_columns="image", operations=transforms_invert()) | data = ds.map(input_columns="image", operations=transforms_invert()) | ||||
| # Compare with expected md5 from images | # Compare with expected md5 from images | ||||
| filename = "invert_01_result.npz" | |||||
| filename = "invert_01_result_py.npz" | |||||
| save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN) | |||||
def test_invert_md5_c():
    """
    Test Invert cpp op with md5 check
    """
    logger.info("Test Invert cpp op with md5 check")

    # Pipeline: decode, resize, invert, then convert with F.ToTensor
    pipeline_ops = [C.Decode(), C.Resize(size=[224, 224]), C.Invert(), F.ToTensor()]

    source = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
    inverted = source.map(input_columns="image", operations=pipeline_ops)

    # Check the pipeline output against the stored golden file
    golden_file = "invert_01_result_c.npz"
    save_and_check_md5(inverted, golden_file, generate_golden=GENERATE_GOLDEN)
if __name__ == "__main__":
    # Run every Invert test in order; the plotting tests run with plots off.
    for plotting_test in (test_invert_py, test_invert_c, test_invert_py_c):
        plotting_test(plot=False)
    for plain_test in (test_invert_one_channel, test_invert_md5_py, test_invert_md5_c):
        plain_test()