From: @xiefangqi Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -31,6 +31,7 @@ | |||
| #include "minddata/dataset/kernels/image/invert_op.h" | |||
| #include "minddata/dataset/kernels/image/mixup_batch_op.h" | |||
| #include "minddata/dataset/kernels/image/normalize_op.h" | |||
| #include "minddata/dataset/kernels/image/normalize_pad_op.h" | |||
| #include "minddata/dataset/kernels/image/pad_op.h" | |||
| #include "minddata/dataset/kernels/image/random_affine_op.h" | |||
| #include "minddata/dataset/kernels/image/random_color_op.h" | |||
| @@ -71,6 +72,11 @@ PYBIND_REGISTER(NormalizeOp, 1, ([](const py::module *m) { | |||
| .def(py::init<float, float, float, float, float, float>()); | |||
| })); | |||
| PYBIND_REGISTER(NormalizePadOp, 1, ([](const py::module *m) { | |||
| (void)py::class_<NormalizePadOp, TensorOp, std::shared_ptr<NormalizePadOp>>(*m, "NormalizePadOp") | |||
| .def(py::init<float, float, float, float, float, float, std::string>()); | |||
| })); | |||
| PYBIND_REGISTER( | |||
| EqualizeOp, 1, ([](const py::module *m) { | |||
| (void)py::class_<EqualizeOp, TensorOp, std::shared_ptr<EqualizeOp>>(*m, "EqualizeOp").def(py::init<>()); | |||
| @@ -38,6 +38,7 @@ | |||
| #include "minddata/dataset/kernels/image/mixup_batch_op.h" | |||
| #endif | |||
| #include "minddata/dataset/kernels/image/normalize_op.h" | |||
| #include "minddata/dataset/kernels/image/normalize_pad_op.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/kernels/image/pad_op.h" | |||
| #include "minddata/dataset/kernels/image/random_affine_op.h" | |||
| @@ -169,6 +170,14 @@ std::shared_ptr<NormalizeOperation> Normalize(std::vector<float> mean, std::vect | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| // Function to create NormalizePadOperation. | |||
| std::shared_ptr<NormalizePadOperation> NormalizePad(const std::vector<float> &mean, const std::vector<float> &std, | |||
| const std::string &dtype) { | |||
| auto op = std::make_shared<NormalizePadOperation>(mean, std, dtype); | |||
| // Input validation | |||
| return op->ValidateParams() ? op : nullptr; | |||
| } | |||
| // Function to create PadOperation. | |||
| std::shared_ptr<PadOperation> Pad(std::vector<int32_t> padding, std::vector<uint8_t> fill_value, | |||
| BorderType padding_mode) { | |||
| @@ -668,7 +677,7 @@ Status NormalizeOperation::ValidateParams() { | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| if (mean_[i] < 0.0f || mean_[i] > 255.0f || CmpFloat(mean_[i], 0.0f)) { | |||
| if (mean_[i] < 0.0f || mean_[i] > 255.0f) { | |||
| std::string err_msg = "Normalize: mean vector has incorrect value: " + std::to_string(mean_[i]); | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| @@ -682,6 +691,47 @@ std::shared_ptr<TensorOp> NormalizeOperation::Build() { | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| // NormalizePadOperation | |||
| NormalizePadOperation::NormalizePadOperation(const std::vector<float> &mean, const std::vector<float> &std, | |||
| const std::string &dtype) | |||
| : mean_(mean), std_(std), dtype_(dtype) {} | |||
| Status NormalizePadOperation::ValidateParams() { | |||
| if (mean_.size() != 3) { | |||
| std::string err_msg = "NormalizePad: mean vector has incorrect size: " + std::to_string(mean_.size()); | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| if (std_.size() != 3) { | |||
| std::string err_msg = "NormalizePad: std vector has incorrect size: " + std::to_string(std_.size()); | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| // check std/mean value | |||
| for (int32_t i = 0; i < std_.size(); ++i) { | |||
| if (std_[i] < 0.0f || std_[i] > 255.0f || CmpFloat(std_[i], 0.0f)) { | |||
| std::string err_msg = "NormalizePad: std vector has incorrect value: " + std::to_string(std_[i]); | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| if (mean_[i] < 0.0f || mean_[i] > 255.0f) { | |||
| std::string err_msg = "NormalizePad: mean vector has incorrect value: " + std::to_string(mean_[i]); | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| } | |||
| if (dtype_ != "float32" && dtype_ != "float16") { | |||
| std::string err_msg = "NormalizePad: dtype must be float32 or float16, but got: " + dtype_; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<TensorOp> NormalizePadOperation::Build() { | |||
| return std::make_shared<NormalizePadOp>(mean_[0], mean_[1], mean_[2], std_[0], std_[1], std_[2], dtype_); | |||
| } | |||
| // PadOperation | |||
| PadOperation::PadOperation(std::vector<int32_t> padding, std::vector<uint8_t> fill_value, BorderType padding_mode) | |||
| : padding_(padding), fill_value_(fill_value), padding_mode_(padding_mode) {} | |||
| @@ -42,6 +42,7 @@ constexpr char kEqualizeOperation[] = "Equalize"; | |||
| constexpr char kHwcToChwOperation[] = "HwcToChw"; | |||
| constexpr char kInvertOperation[] = "Invert"; | |||
| constexpr char kMixUpBatchOperation[] = "MixUpBatch"; | |||
| constexpr char kNormalizePadOperation[] = "NormalizePad"; | |||
| constexpr char kPadOperation[] = "Pad"; | |||
| constexpr char kRandomAffineOperation[] = "RandomAffine"; | |||
| constexpr char kRandomColorAdjustOperation[] = "RandomColorAdjust"; | |||
| @@ -79,6 +80,7 @@ class EqualizeOperation; | |||
| class HwcToChwOperation; | |||
| class InvertOperation; | |||
| class MixUpBatchOperation; | |||
| class NormalizePadOperation; | |||
| class PadOperation; | |||
| class RandomAffineOperation; | |||
| class RandomColorOperation; | |||
| @@ -162,6 +164,19 @@ std::shared_ptr<InvertOperation> Invert(); | |||
| /// \return Shared pointer to the current TensorOperation. | |||
| std::shared_ptr<MixUpBatchOperation> MixUpBatch(float alpha = 1); | |||
| /// \brief Function to create a NormalizePad TensorOperation. | |||
| /// \notes Normalize the input image with respect to mean and standard deviation and pad an extra | |||
| /// channel with value zero. | |||
| /// \param[in] mean A vector of mean values for each channel, w.r.t channel order. | |||
| /// The mean values must be in range [0.0, 255.0]. | |||
| /// \param[in] std A vector of standard deviations for each channel, w.r.t. channel order. | |||
| /// The standard deviation values must be in range (0.0, 255.0] | |||
| /// \param[in] dtype The output datatype of Tensor. | |||
| /// The standard deviation values must be "float32" or "float16"(default = "float32") | |||
| /// \return Shared pointer to the current TensorOperation. | |||
| std::shared_ptr<NormalizePadOperation> NormalizePad(const std::vector<float> &mean, const std::vector<float> &std, | |||
| const std::string &dtype = "float32"); | |||
| /// \brief Function to create a Pad TensorOp | |||
| /// \notes Pads the image according to padding parameters | |||
| /// \param[in] padding A vector representing the number of pixels to pad the image | |||
| @@ -587,6 +602,25 @@ class MixUpBatchOperation : public TensorOperation { | |||
| float alpha_; | |||
| }; | |||
| class NormalizePadOperation : public TensorOperation { | |||
| public: | |||
| NormalizePadOperation(const std::vector<float> &mean, const std::vector<float> &std, | |||
| const std::string &dtype = "float32"); | |||
| ~NormalizePadOperation() = default; | |||
| std::shared_ptr<TensorOp> Build() override; | |||
| Status ValidateParams() override; | |||
| std::string Name() const override { return kNormalizePadOperation; } | |||
| private: | |||
| std::vector<float> mean_; | |||
| std::vector<float> std_; | |||
| std::string dtype_; | |||
| }; | |||
| class PadOperation : public TensorOperation { | |||
| public: | |||
| PadOperation(std::vector<int32_t> padding, std::vector<uint8_t> fill_value = {0}, | |||
| @@ -81,7 +81,7 @@ std::shared_ptr<DecodeOperation> Decode(bool rgb = true); | |||
| /// \brief Function to create a Normalize TensorOperation. | |||
| /// \notes Normalize the input image with respect to mean and standard deviation. | |||
| /// \param[in] mean A vector of mean values for each channel, w.r.t channel order. | |||
| /// The mean values must be in range (0.0, 255.0]. | |||
| /// The mean values must be in range [0.0, 255.0]. | |||
| /// \param[in] std A vector of standard deviations for each channel, w.r.t. channel order. | |||
| /// The standard deviation values must be in range (0.0, 255.0] | |||
| /// \return Shared pointer to the current TensorOperation. | |||
| @@ -18,6 +18,7 @@ add_library(kernels-image OBJECT | |||
| math_utils.cc | |||
| mixup_batch_op.cc | |||
| normalize_op.cc | |||
| normalize_pad_op.cc | |||
| pad_op.cc | |||
| posterize_op.cc | |||
| random_affine_op.cc | |||
| @@ -630,6 +630,57 @@ Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> * | |||
| } | |||
| } | |||
| Status NormalizePad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, | |||
| const std::shared_ptr<Tensor> &mean, const std::shared_ptr<Tensor> &std, const std::string &dtype) { | |||
| std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input); | |||
| if (!(input_cv->mat().data && input_cv->Rank() == 3)) { | |||
| RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); | |||
| } | |||
| DataType tensor_type = DataType(DataType::DE_FLOAT32); | |||
| int compute_type = CV_32F; | |||
| int channel_type = CV_32FC1; | |||
| if (dtype == "float16") { | |||
| compute_type = CV_16F; | |||
| channel_type = CV_16FC1; | |||
| tensor_type = DataType(DataType::DE_FLOAT16); | |||
| } | |||
| cv::Mat in_image = input_cv->mat(); | |||
| std::shared_ptr<CVTensor> output_cv; | |||
| TensorShape new_shape({input_cv->shape()[0], input_cv->shape()[1], input_cv->shape()[2] + 1}); | |||
| RETURN_IF_NOT_OK(CVTensor::CreateEmpty(new_shape, tensor_type, &output_cv)); | |||
| mean->Squeeze(); | |||
| if (mean->type() != DataType::DE_FLOAT32 || mean->Rank() != 1 || mean->shape()[0] != 3) { | |||
| std::string err_msg = "Mean tensor should be of size 3 and type float."; | |||
| return Status(StatusCode::kShapeMisMatch, err_msg); | |||
| } | |||
| std->Squeeze(); | |||
| if (std->type() != DataType::DE_FLOAT32 || std->Rank() != 1 || std->shape()[0] != 3) { | |||
| std::string err_msg = "Std tensor should be of size 3 and type float."; | |||
| return Status(StatusCode::kShapeMisMatch, err_msg); | |||
| } | |||
| try { | |||
| // NOTE: We are assuming the input image is in RGB and the mean | |||
| // and std are in RGB | |||
| std::vector<cv::Mat> rgb; | |||
| cv::split(in_image, rgb); | |||
| if (rgb.size() != 3) { | |||
| RETURN_STATUS_UNEXPECTED("Input image is not in RGB."); | |||
| } | |||
| for (uint8_t i = 0; i < 3; i++) { | |||
| float mean_c, std_c; | |||
| RETURN_IF_NOT_OK(mean->GetItemAt<float>(&mean_c, {i})); | |||
| RETURN_IF_NOT_OK(std->GetItemAt<float>(&std_c, {i})); | |||
| rgb[i].convertTo(rgb[i], compute_type, 1.0 / std_c, (-mean_c / std_c)); | |||
| } | |||
| rgb.push_back(cv::Mat::zeros(in_image.rows, in_image.cols, channel_type)); | |||
| cv::merge(rgb, output_cv->mat()); | |||
| *output = std::static_pointer_cast<Tensor>(output_cv); | |||
| return Status::OK(); | |||
| } catch (const cv::Exception &e) { | |||
| RETURN_STATUS_UNEXPECTED("Unexpected error in NormalizePad"); | |||
| } | |||
| } | |||
| Status AdjustBrightness(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &alpha) { | |||
| try { | |||
| std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input); | |||
| @@ -185,6 +185,15 @@ Status Rotate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out | |||
| Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, | |||
| const std::shared_ptr<Tensor> &mean, const std::shared_ptr<Tensor> &std); | |||
| /// \brief Returns Normalized and paded image | |||
| /// \param input: Tensor of shape <H,W,C> in RGB order and any OpenCv compatible type, see CVTensor. | |||
| /// \param mean: Tensor of shape <3> and type DE_FLOAT32 which are mean of each channel in RGB order | |||
| /// \param std: Tensor of shape <3> and type DE_FLOAT32 which are std of each channel in RGB order | |||
| /// \param dtype: output dtype | |||
| /// \param output: Normalized image Tensor and pad an extra channel, return a dtype Tensor | |||
| Status NormalizePad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, | |||
| const std::shared_ptr<Tensor> &mean, const std::shared_ptr<Tensor> &std, const std::string &dtype); | |||
| /// \brief Returns image with adjusted brightness. | |||
| /// \param input: Tensor of shape <H,W,3> in RGB order and any OpenCv compatible type, see CVTensor. | |||
| /// \param alpha: Alpha value to adjust brightness by. Should be a positive number. | |||
| @@ -0,0 +1,48 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/kernels/image/normalize_pad_op.h" | |||
| #include <random> | |||
| #include "minddata/dataset/kernels/image/image_utils.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| NormalizePadOp::NormalizePadOp(float mean_r, float mean_g, float mean_b, float std_r, float std_g, float std_b, | |||
| std::string dtype) { | |||
| Status s = Tensor::CreateFromVector<float>({mean_r, mean_g, mean_b}, &mean_); | |||
| if (s.IsError()) { | |||
| MS_LOG(ERROR) << "Could not create mean tensor."; | |||
| } | |||
| s = Tensor::CreateFromVector<float>({std_r, std_g, std_b}, &std_); | |||
| if (s.IsError()) { | |||
| MS_LOG(ERROR) << "Could not create std tensor."; | |||
| } | |||
| dtype_ = dtype; | |||
| } | |||
| Status NormalizePadOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||
| IO_CHECK(input, output); | |||
| // Doing the normalization + pad | |||
| return NormalizePad(input, output, mean_, std_, dtype_); | |||
| } | |||
| void NormalizePadOp::Print(std::ostream &out) const { | |||
| out << "NormalizeOp, mean: " << *(mean_.get()) << std::endl << "std: " << *(std_.get()) << std::endl; | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_NORMALIZE_PAD_OP_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_NORMALIZE_PAD_OP_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include "minddata/dataset/core/tensor.h" | |||
| #include "minddata/dataset/kernels/tensor_op.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class NormalizePadOp : public TensorOp { | |||
| public: | |||
| NormalizePadOp(float mean_r, float mean_g, float mean_b, float std_r, float std_g, float std_b, | |||
| std::string dtype = "float32"); | |||
| ~NormalizePadOp() override = default; | |||
| void Print(std::ostream &out) const override; | |||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | |||
| std::string Name() const override { return kNormalizePadOp; } | |||
| private: | |||
| std::shared_ptr<Tensor> mean_; | |||
| std::shared_ptr<Tensor> std_; | |||
| std::string dtype_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_NORMALIZE_OP_H_ | |||
| @@ -62,6 +62,7 @@ constexpr char kHwcToChwOp[] = "HwcToChwOp"; | |||
| constexpr char kInvertOp[] = "InvertOp"; | |||
| constexpr char kMixUpBatchOp[] = "MixUpBatchOp"; | |||
| constexpr char kNormalizeOp[] = "NormalizeOp"; | |||
| constexpr char kNormalizePadOp[] = "NormalizePadOp"; | |||
| constexpr char kPadOp[] = "PadOp"; | |||
| constexpr char kRandomColorAdjustOp[] = "RandomColorAdjustOp"; | |||
| constexpr char kRandomCropAndResizeOp[] = "RandomCropAndResizeOp"; | |||
| @@ -50,8 +50,8 @@ import mindspore._c_dataengine as cde | |||
| from .utils import Inter, Border, ImageBatchFormat | |||
| from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \ | |||
| check_mix_up_batch_c, check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \ | |||
| check_range, check_resize, check_rescale, check_pad, check_cutout, \ | |||
| check_mix_up_batch_c, check_normalize_c, check_normalizepad_c, check_random_crop, check_random_color_adjust, \ | |||
| check_random_rotation, check_range, check_resize, check_rescale, check_pad, check_cutout, \ | |||
| check_uniform_augment_cpp, \ | |||
| check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \ | |||
| check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER, \ | |||
| @@ -319,6 +319,50 @@ class Normalize(cde.NormalizeOp): | |||
| return img.as_array() | |||
| class NormalizePad(cde.NormalizePadOp): | |||
| """ | |||
| Normalize the input image with respect to mean and standard deviation then pad an extra channel with value zero. | |||
| Args: | |||
| mean (sequence): List or tuple of mean values for each channel, with respect to channel order. | |||
| The mean values must be in range (0.0, 255.0]. | |||
| std (sequence): List or tuple of standard deviations for each channel, with respect to channel order. | |||
| The standard deviation values must be in range (0.0, 255.0]. | |||
| dtype (str): Set the output data type of normalized image (default is "float32"). | |||
| Examples: | |||
| >>> import mindspore.dataset.vision.c_transforms as c_vision | |||
| >>> | |||
| >>> decode_op = c_vision.Decode() | |||
| >>> normalize_op = c_vision.NormalizePad(mean=[121.0, 115.0, 100.0], std=[70.0, 68.0, 71.0], dtype="float32") | |||
| >>> transforms_list = [decode_op, normalize_pad_op] | |||
| >>> data1 = data1.map(operations=transforms_list, input_columns=["image"]) | |||
| """ | |||
| @check_normalizepad_c | |||
| def __init__(self, mean, std, dtype="float32"): | |||
| self.mean = mean | |||
| self.std = std | |||
| self.dtype = dtype | |||
| super().__init__(*mean, *std, dtype) | |||
| def __call__(self, img): | |||
| """ | |||
| Call method. | |||
| Args: | |||
| img (NumPy or PIL image): Image array to be normalizepad. | |||
| Returns: | |||
| img (NumPy), NormalizePaded Image array. | |||
| """ | |||
| if not isinstance(img, (np.ndarray, Image.Image)): | |||
| raise TypeError("Input should be NumPy or PIL image, got {}.".format(type(img))) | |||
| normalize_pad = cde.Execute(cde.NormalizePadOp(*self.mean, *self.std, self.dtype)) | |||
| img = normalize_pad(cde.Tensor(np.asarray(img))) | |||
| return img.as_array() | |||
| class RandomAffine(cde.RandomAffineOp): | |||
| """ | |||
| Apply Random affine transformation to the input image. | |||
| @@ -28,7 +28,7 @@ from PIL import Image | |||
| from . import py_transforms_util as util | |||
| from .c_transforms import parse_padding | |||
| from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \ | |||
| check_normalize_py, check_random_crop, check_random_color_adjust, check_random_rotation, \ | |||
| check_normalize_py, check_normalizepad_py, check_random_crop, check_random_color_adjust, check_random_rotation, \ | |||
| check_ten_crop, check_num_channels, check_pad, \ | |||
| check_random_perspective, check_random_erasing, check_cutout, check_linear_transform, check_random_affine, \ | |||
| check_mix_up, check_positive_degrees, check_uniform_augment_py, check_auto_contrast | |||
| @@ -231,6 +231,49 @@ class Normalize: | |||
| return util.normalize(img, self.mean, self.std) | |||
| class NormalizePad: | |||
| """ | |||
| Normalize the input NumPy image array of shape (C, H, W) with the given mean and standard deviation | |||
| then pad an extra channel with value zero. | |||
| The values of the array need to be in the range (0.0, 1.0]. | |||
| Args: | |||
| mean (sequence): List or tuple of mean values for each channel, with respect to channel order. | |||
| The mean values must be in the range (0.0, 1.0]. | |||
| std (sequence): List or tuple of standard deviations for each channel, w.r.t. channel order. | |||
| The standard deviation values must be in the range (0.0, 1.0]. | |||
| dtype (str): Set the output data type of image (default is "float32"). | |||
| Examples: | |||
| >>> import mindspore.dataset.vision.py_transforms as py_vision | |||
| >>> from mindspore.dataset.transforms.py_transforms import Compose | |||
| >>> | |||
| >>> Compose([py_vision.Decode(), | |||
| >>> py_vision.RandomHorizontalFlip(0.5), | |||
| >>> py_vision.ToTensor(), | |||
| >>> py_vision.NormalizePad((0.491, 0.482, 0.447), (0.247, 0.243, 0.262), "float32")]) | |||
| """ | |||
| @check_normalizepad_py | |||
| def __init__(self, mean, std, dtype="float32"): | |||
| self.mean = mean | |||
| self.std = std | |||
| self.dtype = dtype | |||
| def __call__(self, img): | |||
| """ | |||
| Call method. | |||
| Args: | |||
| img (numpy.ndarray): Image array to be normalizepad. | |||
| Returns: | |||
| img (numpy.ndarray), NormalizePaded Image array. | |||
| """ | |||
| return util.normalize(img, self.mean, self.std, pad_channel=True, dtype=self.dtype) | |||
| class RandomCrop: | |||
| """ | |||
| Crop the input PIL image at a random location. | |||
| @@ -42,7 +42,7 @@ def is_pil(img): | |||
| return isinstance(img, Image.Image) | |||
| def normalize(img, mean, std): | |||
| def normalize(img, mean, std, pad_channel=False, dtype="float32"): | |||
| """ | |||
| Normalize the image between [0, 1] with respect to mean and standard deviation. | |||
| @@ -50,6 +50,8 @@ def normalize(img, mean, std): | |||
| img (numpy.ndarray): Image array of shape CHW to be normalized. | |||
| mean (list): List of mean values for each channel, w.r.t channel order. | |||
| std (list): List of standard deviations for each channel, w.r.t. channel order. | |||
| pad_channel (bool): Whether to pad a extra channel with value zero. | |||
| dtype (str): Output datatype of normalize, only worked when pad_channel is True. (default is "float32") | |||
| Returns: | |||
| img (numpy.ndarray), Normalized image. | |||
| @@ -72,7 +74,13 @@ def normalize(img, mean, std): | |||
| mean = np.array(mean, dtype=img.dtype) | |||
| std = np.array(std, dtype=img.dtype) | |||
| return (img - mean[:, None, None]) / std[:, None, None] | |||
| image = (img - mean[:, None, None]) / std[:, None, None] | |||
| if pad_channel: | |||
| zeros = np.zeros([1, image.shape[1], image.shape[2]], dtype=np.float32) | |||
| image = np.concatenate((image, zeros), axis=0) | |||
| if dtype == "float16": | |||
| image = image.astype(np.float16) | |||
| return image | |||
| def decode(img): | |||
| @@ -294,6 +294,40 @@ def check_normalize_py(method): | |||
| return new_method | |||
| def check_normalizepad_c(method): | |||
| """A wrapper that wraps a parameter checker around the original function(normalizepad operation written in C++).""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| [mean, std, dtype], _ = parse_user_args(method, *args, **kwargs) | |||
| check_normalize_c_param(mean, std) | |||
| if not isinstance(dtype, str): | |||
| raise TypeError("dtype should be string.") | |||
| if dtype not in ["float32", "float16"]: | |||
| raise ValueError("dtype only support float32 or float16.") | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| def check_normalizepad_py(method): | |||
| """A wrapper that wraps a parameter checker around the original function(normalizepad operation written in Python).""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| [mean, std, dtype], _ = parse_user_args(method, *args, **kwargs) | |||
| check_normalize_py_param(mean, std) | |||
| if not isinstance(dtype, str): | |||
| raise TypeError("dtype should be string.") | |||
| if dtype not in ["float32", "float16"]: | |||
| raise ValueError("dtype only support float32 or float16.") | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| def check_random_crop(method): | |||
| """Wrapper method to check the parameters of random crop.""" | |||
| @@ -58,11 +58,6 @@ class MyTimeMonitor(Callback): | |||
| fps = self.batch_size / step_mseconds *1000 * self.size | |||
| print("Epoch time: {:5.3f} ms, fps: {:d} img/sec.".format(step_mseconds, int(fps)), flush=True, end=" ") | |||
| def pad(image): | |||
| zeros = np.zeros([224, 224, 1], dtype=np.uint8) | |||
| output = np.concatenate((image, zeros), axis=2) | |||
| return output | |||
| def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="GPU", dtype="fp16"): | |||
| ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=4, shuffle=True) | |||
| @@ -71,24 +66,25 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target=" | |||
| std = [0.229 * 255, 0.224 * 255, 0.225 * 255] | |||
| # define map operations | |||
| normalize_op = C.Normalize(mean=mean, std=std) | |||
| if dtype == "float16": | |||
| normalize_op = C.NormalizePad(mean=mean, std=std, dtype="float16") | |||
| if do_train: | |||
| trans = [ | |||
| C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)), | |||
| C.RandomHorizontalFlip(prob=0.5), | |||
| C.Normalize(mean=mean, std=std), | |||
| normalize_op, | |||
| ] | |||
| else: | |||
| trans = [ | |||
| C.Decode(), | |||
| C.Resize(256), | |||
| C.CenterCrop(image_size), | |||
| C.Normalize(mean=mean, std=std), | |||
| normalize_op, | |||
| ] | |||
| if dtype == "fp32": | |||
| trans.append(C.HWC2CHW()) | |||
| ds = ds.map(operations=trans, input_columns="image", num_parallel_workers=4) | |||
| if dtype == "fp16": | |||
| ds = ds.map(operations=pad, input_columns="image", num_parallel_workers=4) | |||
| ds = ds.map(operations=trans, input_columns="image", num_parallel_workers=8) | |||
| # apply batch operations | |||
| ds = ds.batch(batch_size, drop_remainder=True) | |||
| # apply dataset repeat operation | |||
| @@ -932,6 +932,70 @@ TEST_F(MindDataTestPipeline, TestNormalizeFail) { | |||
| EXPECT_EQ(normalize, nullptr); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestNormalizePad) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestNormalizePad."; | |||
| // Create an ImageFolder Dataset | |||
| std::string folder_path = datasets_root_path_ + "/testPK/data/"; | |||
| std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create a Repeat operation on ds | |||
| int32_t repeat_num = 2; | |||
| ds = ds->Repeat(repeat_num); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create objects for the tensor ops | |||
| std::shared_ptr<TensorOperation> normalizepad = vision::NormalizePad({121.0, 115.0, 100.0}, {70.0, 68.0, 71.0}, | |||
| "float32"); | |||
| EXPECT_NE(normalizepad, nullptr); | |||
| // Create a Map operation on ds | |||
| ds = ds->Map({normalizepad}); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row | |||
| std::unordered_map<std::string, std::shared_ptr<Tensor>> row; | |||
| iter->GetNextRow(&row); | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| i++; | |||
| auto image = row["image"]; | |||
| EXPECT_EQ(image->shape()[2], 4); | |||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | |||
| iter->GetNextRow(&row); | |||
| } | |||
| EXPECT_EQ(i, 20); | |||
| // Manually terminate the pipeline | |||
| iter->Stop(); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestNormalizePadFail) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestNormalizePadFail with invalid parameters."; | |||
| // std value at 0.0 | |||
| std::shared_ptr<TensorOperation> normalizepad = | |||
| mindspore::dataset::vision::NormalizePad({121.0, 115.0, 100.0}, {0.0, 68.0, 71.0}); | |||
| EXPECT_EQ(normalizepad, nullptr); | |||
| // normalizepad with 2 values (not 3 values) for mean | |||
| normalizepad = mindspore::dataset::vision::NormalizePad({121.0, 115.0}, {70.0, 68.0, 71.0}); | |||
| EXPECT_EQ(normalizepad, nullptr); | |||
| // normalizepad with 2 values (not 3 values) for standard deviation | |||
| normalizepad = mindspore::dataset::vision::NormalizePad({121.0, 115.0, 100.0}, {68.0, 71.0}); | |||
| EXPECT_EQ(normalizepad, nullptr); | |||
| // normalizepad with invalid dtype | |||
| normalizepad = mindspore::dataset::vision::NormalizePad({121.0, 115.0, 100.0}, {68.0, 71.0, 71.0}, "123"); | |||
| EXPECT_EQ(normalizepad, nullptr); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestPad) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPad."; | |||
| @@ -0,0 +1,61 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/common.h" | |||
| #include "common/cvop_common.h" | |||
| #include "minddata/dataset/kernels/image/normalize_pad_op.h" | |||
| #include "minddata/dataset/core/cv_tensor.h" | |||
| #include "utils/log_adapter.h" | |||
| #include <opencv2/opencv.hpp> | |||
| using namespace mindspore::dataset; | |||
| using mindspore::MsLogLevel::INFO; | |||
| using mindspore::ExceptionType::NoExceptionType; | |||
| using mindspore::LogStream; | |||
| class MindDataTestNormalizePadOP : public UT::CVOP::CVOpCommon { | |||
| public: | |||
| MindDataTestNormalizePadOP() : CVOpCommon() {} | |||
| }; | |||
| TEST_F(MindDataTestNormalizePadOP, TestFloat32) { | |||
| MS_LOG(INFO) << "Doing TestNormalizePadOp::TestFloat32."; | |||
| std::shared_ptr<Tensor> output_tensor; | |||
| // Numbers are from the resnet50 model implementation | |||
| float mean[3] = {121.0, 115.0, 100.0}; | |||
| float std[3] = {70.0, 68.0, 71.0}; | |||
| // NormalizePad Op | |||
| std::unique_ptr<NormalizePadOp> op(new NormalizePadOp(mean[0], mean[1], mean[2], std[0], std[1], std[2], "float32")); | |||
| EXPECT_TRUE(op->OneToOne()); | |||
| Status s = op->Compute(input_tensor_, &output_tensor); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| } | |||
| TEST_F(MindDataTestNormalizePadOP, TestFloat16) { | |||
| MS_LOG(INFO) << "Doing TestNormalizePadOp::TestFloat16."; | |||
| std::shared_ptr<Tensor> output_tensor; | |||
| // Numbers are from the resnet50 model implementation | |||
| float mean[3] = {121.0, 115.0, 100.0}; | |||
| float std[3] = {70.0, 68.0, 71.0}; | |||
| // NormalizePad Op | |||
| std::unique_ptr<NormalizePadOp> op(new NormalizePadOp(mean[0], mean[1], mean[2], std[0], std[1], std[2], "float16")); | |||
| EXPECT_TRUE(op->OneToOne()); | |||
| Status s = op->Compute(input_tensor_, &output_tensor); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| } | |||
| @@ -0,0 +1,201 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Testing Normalize op in DE | |||
| """ | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.py_transforms | |||
| import mindspore.dataset.vision.c_transforms as c_vision | |||
| import mindspore.dataset.vision.py_transforms as py_vision | |||
| from mindspore import log as logger | |||
| from util import diff_mse, visualize_image | |||
| DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | |||
| SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | |||
| GENERATE_GOLDEN = False | |||
| def normalizepad_np(image, mean, std): | |||
| """ | |||
| Apply the normalize+pad | |||
| """ | |||
| # DE decodes the image in RGB by deafult, hence | |||
| # the values here are in RGB | |||
| image = np.array(image, np.float32) | |||
| image = image - np.array(mean) | |||
| image = image * (1.0 / np.array(std)) | |||
| zeros = np.zeros([image.shape[0], image.shape[1], 1], dtype=np.float32) | |||
| output = np.concatenate((image, zeros), axis=2) | |||
| return output | |||
| def test_normalizepad_op_c(plot=False): | |||
| """ | |||
| Test NormalizePad in cpp transformations | |||
| """ | |||
| logger.info("Test Normalize in cpp") | |||
| mean = [121.0, 115.0, 100.0] | |||
| std = [70.0, 68.0, 71.0] | |||
| # define map operations | |||
| decode_op = c_vision.Decode() | |||
| normalizepad_op = c_vision.NormalizePad(mean, std) | |||
| # First dataset | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data1 = data1.map(operations=decode_op, input_columns=["image"]) | |||
| data1 = data1.map(operations=normalizepad_op, input_columns=["image"]) | |||
| # Second dataset | |||
| data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data2 = data2.map(operations=decode_op, input_columns=["image"]) | |||
| num_iter = 0 | |||
| for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True), | |||
| data2.create_dict_iterator(num_epochs=1, output_numpy=True)): | |||
| image_de_normalized = item1["image"] | |||
| image_original = item2["image"] | |||
| image_np_normalized = normalizepad_np(image_original, mean, std) | |||
| mse = diff_mse(image_de_normalized, image_np_normalized) | |||
| logger.info("image_{}, mse: {}".format(num_iter + 1, mse)) | |||
| assert mse < 0.01 | |||
| if plot: | |||
| visualize_image(image_original, image_de_normalized, mse, image_np_normalized) | |||
| num_iter += 1 | |||
| def test_normalizepad_op_py(plot=False): | |||
| """ | |||
| Test NormalizePad in python transformations | |||
| """ | |||
| logger.info("Test Normalize in python") | |||
| mean = [0.475, 0.45, 0.392] | |||
| std = [0.275, 0.267, 0.278] | |||
| # define map operations | |||
| transforms = [ | |||
| py_vision.Decode(), | |||
| py_vision.ToTensor() | |||
| ] | |||
| transform = mindspore.dataset.transforms.py_transforms.Compose(transforms) | |||
| normalizepad_op = py_vision.NormalizePad(mean, std) | |||
| # First dataset | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data1 = data1.map(operations=transform, input_columns=["image"]) | |||
| data1 = data1.map(operations=normalizepad_op, input_columns=["image"]) | |||
| # Second dataset | |||
| data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data2 = data2.map(operations=transform, input_columns=["image"]) | |||
| num_iter = 0 | |||
| for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True), | |||
| data2.create_dict_iterator(num_epochs=1, output_numpy=True)): | |||
| image_de_normalized = (item1["image"].transpose(1, 2, 0) * 255).astype(np.uint8) | |||
| image_np_normalized = (normalizepad_np(item2["image"].transpose(1, 2, 0), mean, std) * 255).astype(np.uint8) | |||
| image_original = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8) | |||
| mse = diff_mse(image_de_normalized, image_np_normalized) | |||
| logger.info("image_{}, mse: {}".format(num_iter + 1, mse)) | |||
| assert mse < 0.01 | |||
| if plot: | |||
| visualize_image(image_original, image_de_normalized, mse, image_np_normalized) | |||
| num_iter += 1 | |||
| def test_decode_normalizepad_op(): | |||
| """ | |||
| Test Decode op followed by NormalizePad op | |||
| """ | |||
| logger.info("Test [Decode, Normalize] in one Map") | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image", "label"], num_parallel_workers=1, | |||
| shuffle=False) | |||
| # define map operations | |||
| decode_op = c_vision.Decode() | |||
| normalizepad_op = c_vision.NormalizePad([121.0, 115.0, 100.0], [70.0, 68.0, 71.0], "float16") | |||
| # apply map operations on images | |||
| data1 = data1.map(operations=[decode_op, normalizepad_op], input_columns=["image"]) | |||
| num_iter = 0 | |||
| for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| logger.info("Looping inside iterator {}".format(num_iter)) | |||
| assert item["image"].dtype == np.float16 | |||
| num_iter += 1 | |||
| def test_normalizepad_exception_unequal_size_c(): | |||
| """ | |||
| Test NormalizePad in c transformation: len(mean) != len(std) | |||
| expected to raise ValueError | |||
| """ | |||
| logger.info("test_normalize_exception_unequal_size_c") | |||
| try: | |||
| _ = c_vision.NormalizePad([100, 250, 125], [50, 50, 75, 75]) | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert str(e) == "Length of mean and std must be equal." | |||
| try: | |||
| _ = c_vision.NormalizePad([100, 250, 125], [50, 50, 75], 1) | |||
| except TypeError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert str(e) == "dtype should be string." | |||
| try: | |||
| _ = c_vision.NormalizePad([100, 250, 125], [50, 50, 75], "") | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert str(e) == "dtype only support float32 or float16." | |||
| def test_normalizepad_exception_unequal_size_py(): | |||
| """ | |||
| Test NormalizePad in python transformation: len(mean) != len(std) | |||
| expected to raise ValueError | |||
| """ | |||
| logger.info("test_normalizepad_exception_unequal_size_py") | |||
| try: | |||
| _ = py_vision.NormalizePad([0.50, 0.30, 0.75], [0.18, 0.32, 0.71, 0.72]) | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert str(e) == "Length of mean and std must be equal." | |||
| try: | |||
| _ = py_vision.NormalizePad([0.50, 0.30, 0.75], [0.18, 0.32, 0.71], 1) | |||
| except TypeError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert str(e) == "dtype should be string." | |||
| try: | |||
| _ = py_vision.NormalizePad([0.50, 0.30, 0.75], [0.18, 0.32, 0.71], "") | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert str(e) == "dtype only support float32 or float16." | |||
| def test_normalizepad_exception_invalid_range_py(): | |||
| """ | |||
| Test NormalizePad in python transformation: value is not in range [0,1] | |||
| expected to raise ValueError | |||
| """ | |||
| logger.info("test_normalizepad_exception_invalid_range_py") | |||
| try: | |||
| _ = py_vision.NormalizePad([0.75, 1.25, 0.5], [0.1, 0.18, 1.32]) | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Input mean_value is not within the required interval of (0.0 to 1.0)." in str(e) | |||