From: @xiefangqi Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -31,6 +31,7 @@ | |||||
| #include "minddata/dataset/kernels/image/invert_op.h" | #include "minddata/dataset/kernels/image/invert_op.h" | ||||
| #include "minddata/dataset/kernels/image/mixup_batch_op.h" | #include "minddata/dataset/kernels/image/mixup_batch_op.h" | ||||
| #include "minddata/dataset/kernels/image/normalize_op.h" | #include "minddata/dataset/kernels/image/normalize_op.h" | ||||
| #include "minddata/dataset/kernels/image/normalize_pad_op.h" | |||||
| #include "minddata/dataset/kernels/image/pad_op.h" | #include "minddata/dataset/kernels/image/pad_op.h" | ||||
| #include "minddata/dataset/kernels/image/random_affine_op.h" | #include "minddata/dataset/kernels/image/random_affine_op.h" | ||||
| #include "minddata/dataset/kernels/image/random_color_op.h" | #include "minddata/dataset/kernels/image/random_color_op.h" | ||||
| @@ -71,6 +72,11 @@ PYBIND_REGISTER(NormalizeOp, 1, ([](const py::module *m) { | |||||
| .def(py::init<float, float, float, float, float, float>()); | .def(py::init<float, float, float, float, float, float>()); | ||||
| })); | })); | ||||
| PYBIND_REGISTER(NormalizePadOp, 1, ([](const py::module *m) { | |||||
| (void)py::class_<NormalizePadOp, TensorOp, std::shared_ptr<NormalizePadOp>>(*m, "NormalizePadOp") | |||||
| .def(py::init<float, float, float, float, float, float, std::string>()); | |||||
| })); | |||||
| PYBIND_REGISTER( | PYBIND_REGISTER( | ||||
| EqualizeOp, 1, ([](const py::module *m) { | EqualizeOp, 1, ([](const py::module *m) { | ||||
| (void)py::class_<EqualizeOp, TensorOp, std::shared_ptr<EqualizeOp>>(*m, "EqualizeOp").def(py::init<>()); | (void)py::class_<EqualizeOp, TensorOp, std::shared_ptr<EqualizeOp>>(*m, "EqualizeOp").def(py::init<>()); | ||||
| @@ -38,6 +38,7 @@ | |||||
| #include "minddata/dataset/kernels/image/mixup_batch_op.h" | #include "minddata/dataset/kernels/image/mixup_batch_op.h" | ||||
| #endif | #endif | ||||
| #include "minddata/dataset/kernels/image/normalize_op.h" | #include "minddata/dataset/kernels/image/normalize_op.h" | ||||
| #include "minddata/dataset/kernels/image/normalize_pad_op.h" | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| #include "minddata/dataset/kernels/image/pad_op.h" | #include "minddata/dataset/kernels/image/pad_op.h" | ||||
| #include "minddata/dataset/kernels/image/random_affine_op.h" | #include "minddata/dataset/kernels/image/random_affine_op.h" | ||||
| @@ -169,6 +170,14 @@ std::shared_ptr<NormalizeOperation> Normalize(std::vector<float> mean, std::vect | |||||
| } | } | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| // Function to create NormalizePadOperation. | |||||
| std::shared_ptr<NormalizePadOperation> NormalizePad(const std::vector<float> &mean, const std::vector<float> &std, | |||||
| const std::string &dtype) { | |||||
| auto op = std::make_shared<NormalizePadOperation>(mean, std, dtype); | |||||
| // Input validation | |||||
| return op->ValidateParams() ? op : nullptr; | |||||
| } | |||||
| // Function to create PadOperation. | // Function to create PadOperation. | ||||
| std::shared_ptr<PadOperation> Pad(std::vector<int32_t> padding, std::vector<uint8_t> fill_value, | std::shared_ptr<PadOperation> Pad(std::vector<int32_t> padding, std::vector<uint8_t> fill_value, | ||||
| BorderType padding_mode) { | BorderType padding_mode) { | ||||
| @@ -668,7 +677,7 @@ Status NormalizeOperation::ValidateParams() { | |||||
| MS_LOG(ERROR) << err_msg; | MS_LOG(ERROR) << err_msg; | ||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | RETURN_STATUS_SYNTAX_ERROR(err_msg); | ||||
| } | } | ||||
| if (mean_[i] < 0.0f || mean_[i] > 255.0f || CmpFloat(mean_[i], 0.0f)) { | |||||
| if (mean_[i] < 0.0f || mean_[i] > 255.0f) { | |||||
| std::string err_msg = "Normalize: mean vector has incorrect value: " + std::to_string(mean_[i]); | std::string err_msg = "Normalize: mean vector has incorrect value: " + std::to_string(mean_[i]); | ||||
| MS_LOG(ERROR) << err_msg; | MS_LOG(ERROR) << err_msg; | ||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | RETURN_STATUS_SYNTAX_ERROR(err_msg); | ||||
| @@ -682,6 +691,47 @@ std::shared_ptr<TensorOp> NormalizeOperation::Build() { | |||||
| } | } | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| // NormalizePadOperation | |||||
| NormalizePadOperation::NormalizePadOperation(const std::vector<float> &mean, const std::vector<float> &std, | |||||
| const std::string &dtype) | |||||
| : mean_(mean), std_(std), dtype_(dtype) {} | |||||
| Status NormalizePadOperation::ValidateParams() { | |||||
| if (mean_.size() != 3) { | |||||
| std::string err_msg = "NormalizePad: mean vector has incorrect size: " + std::to_string(mean_.size()); | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| if (std_.size() != 3) { | |||||
| std::string err_msg = "NormalizePad: std vector has incorrect size: " + std::to_string(std_.size()); | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| // check std/mean value | |||||
| for (int32_t i = 0; i < std_.size(); ++i) { | |||||
| if (std_[i] < 0.0f || std_[i] > 255.0f || CmpFloat(std_[i], 0.0f)) { | |||||
| std::string err_msg = "NormalizePad: std vector has incorrect value: " + std::to_string(std_[i]); | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| if (mean_[i] < 0.0f || mean_[i] > 255.0f) { | |||||
| std::string err_msg = "NormalizePad: mean vector has incorrect value: " + std::to_string(mean_[i]); | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| } | |||||
| if (dtype_ != "float32" && dtype_ != "float16") { | |||||
| std::string err_msg = "NormalizePad: dtype must be float32 or float16, but got: " + dtype_; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| std::shared_ptr<TensorOp> NormalizePadOperation::Build() { | |||||
| return std::make_shared<NormalizePadOp>(mean_[0], mean_[1], mean_[2], std_[0], std_[1], std_[2], dtype_); | |||||
| } | |||||
| // PadOperation | // PadOperation | ||||
| PadOperation::PadOperation(std::vector<int32_t> padding, std::vector<uint8_t> fill_value, BorderType padding_mode) | PadOperation::PadOperation(std::vector<int32_t> padding, std::vector<uint8_t> fill_value, BorderType padding_mode) | ||||
| : padding_(padding), fill_value_(fill_value), padding_mode_(padding_mode) {} | : padding_(padding), fill_value_(fill_value), padding_mode_(padding_mode) {} | ||||
| @@ -42,6 +42,7 @@ constexpr char kEqualizeOperation[] = "Equalize"; | |||||
| constexpr char kHwcToChwOperation[] = "HwcToChw"; | constexpr char kHwcToChwOperation[] = "HwcToChw"; | ||||
| constexpr char kInvertOperation[] = "Invert"; | constexpr char kInvertOperation[] = "Invert"; | ||||
| constexpr char kMixUpBatchOperation[] = "MixUpBatch"; | constexpr char kMixUpBatchOperation[] = "MixUpBatch"; | ||||
| constexpr char kNormalizePadOperation[] = "NormalizePad"; | |||||
| constexpr char kPadOperation[] = "Pad"; | constexpr char kPadOperation[] = "Pad"; | ||||
| constexpr char kRandomAffineOperation[] = "RandomAffine"; | constexpr char kRandomAffineOperation[] = "RandomAffine"; | ||||
| constexpr char kRandomColorAdjustOperation[] = "RandomColorAdjust"; | constexpr char kRandomColorAdjustOperation[] = "RandomColorAdjust"; | ||||
| @@ -79,6 +80,7 @@ class EqualizeOperation; | |||||
| class HwcToChwOperation; | class HwcToChwOperation; | ||||
| class InvertOperation; | class InvertOperation; | ||||
| class MixUpBatchOperation; | class MixUpBatchOperation; | ||||
| class NormalizePadOperation; | |||||
| class PadOperation; | class PadOperation; | ||||
| class RandomAffineOperation; | class RandomAffineOperation; | ||||
| class RandomColorOperation; | class RandomColorOperation; | ||||
| @@ -162,6 +164,19 @@ std::shared_ptr<InvertOperation> Invert(); | |||||
| /// \return Shared pointer to the current TensorOperation. | /// \return Shared pointer to the current TensorOperation. | ||||
| std::shared_ptr<MixUpBatchOperation> MixUpBatch(float alpha = 1); | std::shared_ptr<MixUpBatchOperation> MixUpBatch(float alpha = 1); | ||||
| /// \brief Function to create a NormalizePad TensorOperation. | |||||
| /// \notes Normalize the input image with respect to mean and standard deviation and pad an extra | |||||
| /// channel with value zero. | |||||
| /// \param[in] mean A vector of mean values for each channel, w.r.t channel order. | |||||
| /// The mean values must be in range [0.0, 255.0]. | |||||
| /// \param[in] std A vector of standard deviations for each channel, w.r.t. channel order. | |||||
| /// The standard deviation values must be in range (0.0, 255.0] | |||||
| /// \param[in] dtype The output datatype of Tensor. | |||||
| /// The dtype must be "float32" or "float16" (default = "float32"). | |||||
| /// \return Shared pointer to the current TensorOperation. | |||||
| std::shared_ptr<NormalizePadOperation> NormalizePad(const std::vector<float> &mean, const std::vector<float> &std, | |||||
| const std::string &dtype = "float32"); | |||||
| /// \brief Function to create a Pad TensorOp | /// \brief Function to create a Pad TensorOp | ||||
| /// \notes Pads the image according to padding parameters | /// \notes Pads the image according to padding parameters | ||||
| /// \param[in] padding A vector representing the number of pixels to pad the image | /// \param[in] padding A vector representing the number of pixels to pad the image | ||||
| @@ -587,6 +602,25 @@ class MixUpBatchOperation : public TensorOperation { | |||||
| float alpha_; | float alpha_; | ||||
| }; | }; | ||||
| class NormalizePadOperation : public TensorOperation { | |||||
| public: | |||||
| NormalizePadOperation(const std::vector<float> &mean, const std::vector<float> &std, | |||||
| const std::string &dtype = "float32"); | |||||
| ~NormalizePadOperation() = default; | |||||
| std::shared_ptr<TensorOp> Build() override; | |||||
| Status ValidateParams() override; | |||||
| std::string Name() const override { return kNormalizePadOperation; } | |||||
| private: | |||||
| std::vector<float> mean_; | |||||
| std::vector<float> std_; | |||||
| std::string dtype_; | |||||
| }; | |||||
| class PadOperation : public TensorOperation { | class PadOperation : public TensorOperation { | ||||
| public: | public: | ||||
| PadOperation(std::vector<int32_t> padding, std::vector<uint8_t> fill_value = {0}, | PadOperation(std::vector<int32_t> padding, std::vector<uint8_t> fill_value = {0}, | ||||
| @@ -81,7 +81,7 @@ std::shared_ptr<DecodeOperation> Decode(bool rgb = true); | |||||
| /// \brief Function to create a Normalize TensorOperation. | /// \brief Function to create a Normalize TensorOperation. | ||||
| /// \notes Normalize the input image with respect to mean and standard deviation. | /// \notes Normalize the input image with respect to mean and standard deviation. | ||||
| /// \param[in] mean A vector of mean values for each channel, w.r.t channel order. | /// \param[in] mean A vector of mean values for each channel, w.r.t channel order. | ||||
| /// The mean values must be in range (0.0, 255.0]. | |||||
| /// The mean values must be in range [0.0, 255.0]. | |||||
| /// \param[in] std A vector of standard deviations for each channel, w.r.t. channel order. | /// \param[in] std A vector of standard deviations for each channel, w.r.t. channel order. | ||||
| /// The standard deviation values must be in range (0.0, 255.0] | /// The standard deviation values must be in range (0.0, 255.0] | ||||
| /// \return Shared pointer to the current TensorOperation. | /// \return Shared pointer to the current TensorOperation. | ||||
| @@ -18,6 +18,7 @@ add_library(kernels-image OBJECT | |||||
| math_utils.cc | math_utils.cc | ||||
| mixup_batch_op.cc | mixup_batch_op.cc | ||||
| normalize_op.cc | normalize_op.cc | ||||
| normalize_pad_op.cc | |||||
| pad_op.cc | pad_op.cc | ||||
| posterize_op.cc | posterize_op.cc | ||||
| random_affine_op.cc | random_affine_op.cc | ||||
| @@ -630,6 +630,57 @@ Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> * | |||||
| } | } | ||||
| } | } | ||||
| Status NormalizePad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, | |||||
| const std::shared_ptr<Tensor> &mean, const std::shared_ptr<Tensor> &std, const std::string &dtype) { | |||||
| std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input); | |||||
| if (!(input_cv->mat().data && input_cv->Rank() == 3)) { | |||||
| RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); | |||||
| } | |||||
| DataType tensor_type = DataType(DataType::DE_FLOAT32); | |||||
| int compute_type = CV_32F; | |||||
| int channel_type = CV_32FC1; | |||||
| if (dtype == "float16") { | |||||
| compute_type = CV_16F; | |||||
| channel_type = CV_16FC1; | |||||
| tensor_type = DataType(DataType::DE_FLOAT16); | |||||
| } | |||||
| cv::Mat in_image = input_cv->mat(); | |||||
| std::shared_ptr<CVTensor> output_cv; | |||||
| TensorShape new_shape({input_cv->shape()[0], input_cv->shape()[1], input_cv->shape()[2] + 1}); | |||||
| RETURN_IF_NOT_OK(CVTensor::CreateEmpty(new_shape, tensor_type, &output_cv)); | |||||
| mean->Squeeze(); | |||||
| if (mean->type() != DataType::DE_FLOAT32 || mean->Rank() != 1 || mean->shape()[0] != 3) { | |||||
| std::string err_msg = "Mean tensor should be of size 3 and type float."; | |||||
| return Status(StatusCode::kShapeMisMatch, err_msg); | |||||
| } | |||||
| std->Squeeze(); | |||||
| if (std->type() != DataType::DE_FLOAT32 || std->Rank() != 1 || std->shape()[0] != 3) { | |||||
| std::string err_msg = "Std tensor should be of size 3 and type float."; | |||||
| return Status(StatusCode::kShapeMisMatch, err_msg); | |||||
| } | |||||
| try { | |||||
| // NOTE: We are assuming the input image is in RGB and the mean | |||||
| // and std are in RGB | |||||
| std::vector<cv::Mat> rgb; | |||||
| cv::split(in_image, rgb); | |||||
| if (rgb.size() != 3) { | |||||
| RETURN_STATUS_UNEXPECTED("Input image is not in RGB."); | |||||
| } | |||||
| for (uint8_t i = 0; i < 3; i++) { | |||||
| float mean_c, std_c; | |||||
| RETURN_IF_NOT_OK(mean->GetItemAt<float>(&mean_c, {i})); | |||||
| RETURN_IF_NOT_OK(std->GetItemAt<float>(&std_c, {i})); | |||||
| rgb[i].convertTo(rgb[i], compute_type, 1.0 / std_c, (-mean_c / std_c)); | |||||
| } | |||||
| rgb.push_back(cv::Mat::zeros(in_image.rows, in_image.cols, channel_type)); | |||||
| cv::merge(rgb, output_cv->mat()); | |||||
| *output = std::static_pointer_cast<Tensor>(output_cv); | |||||
| return Status::OK(); | |||||
| } catch (const cv::Exception &e) { | |||||
| RETURN_STATUS_UNEXPECTED("Unexpected error in NormalizePad"); | |||||
| } | |||||
| } | |||||
| Status AdjustBrightness(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &alpha) { | Status AdjustBrightness(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &alpha) { | ||||
| try { | try { | ||||
| std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input); | std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input); | ||||
| @@ -185,6 +185,15 @@ Status Rotate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out | |||||
| Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, | Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, | ||||
| const std::shared_ptr<Tensor> &mean, const std::shared_ptr<Tensor> &std); | const std::shared_ptr<Tensor> &mean, const std::shared_ptr<Tensor> &std); | ||||
| /// \brief Returns normalized and padded image | |||||
| /// \param input: Tensor of shape <H,W,C> in RGB order and any OpenCv compatible type, see CVTensor. | |||||
| /// \param mean: Tensor of shape <3> and type DE_FLOAT32 which are mean of each channel in RGB order | |||||
| /// \param std: Tensor of shape <3> and type DE_FLOAT32 which are std of each channel in RGB order | |||||
| /// \param dtype: output dtype | |||||
| /// \param output: Normalized image Tensor and pad an extra channel, return a dtype Tensor | |||||
| Status NormalizePad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, | |||||
| const std::shared_ptr<Tensor> &mean, const std::shared_ptr<Tensor> &std, const std::string &dtype); | |||||
| /// \brief Returns image with adjusted brightness. | /// \brief Returns image with adjusted brightness. | ||||
| /// \param input: Tensor of shape <H,W,3> in RGB order and any OpenCv compatible type, see CVTensor. | /// \param input: Tensor of shape <H,W,3> in RGB order and any OpenCv compatible type, see CVTensor. | ||||
| /// \param alpha: Alpha value to adjust brightness by. Should be a positive number. | /// \param alpha: Alpha value to adjust brightness by. Should be a positive number. | ||||
| @@ -0,0 +1,48 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/kernels/image/normalize_pad_op.h" | |||||
| #include <random> | |||||
| #include "minddata/dataset/kernels/image/image_utils.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| NormalizePadOp::NormalizePadOp(float mean_r, float mean_g, float mean_b, float std_r, float std_g, float std_b, | |||||
| std::string dtype) { | |||||
| Status s = Tensor::CreateFromVector<float>({mean_r, mean_g, mean_b}, &mean_); | |||||
| if (s.IsError()) { | |||||
| MS_LOG(ERROR) << "Could not create mean tensor."; | |||||
| } | |||||
| s = Tensor::CreateFromVector<float>({std_r, std_g, std_b}, &std_); | |||||
| if (s.IsError()) { | |||||
| MS_LOG(ERROR) << "Could not create std tensor."; | |||||
| } | |||||
| dtype_ = dtype; | |||||
| } | |||||
| Status NormalizePadOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||||
| IO_CHECK(input, output); | |||||
| // Doing the normalization + pad | |||||
| return NormalizePad(input, output, mean_, std_, dtype_); | |||||
| } | |||||
| void NormalizePadOp::Print(std::ostream &out) const { | |||||
| out << "NormalizeOp, mean: " << *(mean_.get()) << std::endl << "std: " << *(std_.get()) << std::endl; | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,49 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_NORMALIZE_PAD_OP_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_NORMALIZE_PAD_OP_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "minddata/dataset/core/tensor.h" | |||||
| #include "minddata/dataset/kernels/tensor_op.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| class NormalizePadOp : public TensorOp { | |||||
| public: | |||||
| NormalizePadOp(float mean_r, float mean_g, float mean_b, float std_r, float std_g, float std_b, | |||||
| std::string dtype = "float32"); | |||||
| ~NormalizePadOp() override = default; | |||||
| void Print(std::ostream &out) const override; | |||||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | |||||
| std::string Name() const override { return kNormalizePadOp; } | |||||
| private: | |||||
| std::shared_ptr<Tensor> mean_; | |||||
| std::shared_ptr<Tensor> std_; | |||||
| std::string dtype_; | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_NORMALIZE_PAD_OP_H_ | |||||
| @@ -62,6 +62,7 @@ constexpr char kHwcToChwOp[] = "HwcToChwOp"; | |||||
| constexpr char kInvertOp[] = "InvertOp"; | constexpr char kInvertOp[] = "InvertOp"; | ||||
| constexpr char kMixUpBatchOp[] = "MixUpBatchOp"; | constexpr char kMixUpBatchOp[] = "MixUpBatchOp"; | ||||
| constexpr char kNormalizeOp[] = "NormalizeOp"; | constexpr char kNormalizeOp[] = "NormalizeOp"; | ||||
| constexpr char kNormalizePadOp[] = "NormalizePadOp"; | |||||
| constexpr char kPadOp[] = "PadOp"; | constexpr char kPadOp[] = "PadOp"; | ||||
| constexpr char kRandomColorAdjustOp[] = "RandomColorAdjustOp"; | constexpr char kRandomColorAdjustOp[] = "RandomColorAdjustOp"; | ||||
| constexpr char kRandomCropAndResizeOp[] = "RandomCropAndResizeOp"; | constexpr char kRandomCropAndResizeOp[] = "RandomCropAndResizeOp"; | ||||
| @@ -50,8 +50,8 @@ import mindspore._c_dataengine as cde | |||||
| from .utils import Inter, Border, ImageBatchFormat | from .utils import Inter, Border, ImageBatchFormat | ||||
| from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \ | from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \ | ||||
| check_mix_up_batch_c, check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \ | |||||
| check_range, check_resize, check_rescale, check_pad, check_cutout, \ | |||||
| check_mix_up_batch_c, check_normalize_c, check_normalizepad_c, check_random_crop, check_random_color_adjust, \ | |||||
| check_random_rotation, check_range, check_resize, check_rescale, check_pad, check_cutout, \ | |||||
| check_uniform_augment_cpp, \ | check_uniform_augment_cpp, \ | ||||
| check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \ | check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \ | ||||
| check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER, \ | check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER, \ | ||||
| @@ -319,6 +319,50 @@ class Normalize(cde.NormalizeOp): | |||||
| return img.as_array() | return img.as_array() | ||||
class NormalizePad(cde.NormalizePadOp):
    """
    Normalize the input image with respect to mean and standard deviation then pad an extra channel with value zero.

    Args:
        mean (sequence): List or tuple of mean values for each channel, with respect to channel order.
            The mean values must be in range (0.0, 255.0].
        std (sequence): List or tuple of standard deviations for each channel, with respect to channel order.
            The standard deviation values must be in range (0.0, 255.0].
        dtype (str): Set the output data type of normalized image (default is "float32").

    Examples:
        >>> import mindspore.dataset.vision.c_transforms as c_vision
        >>>
        >>> decode_op = c_vision.Decode()
        >>> normalize_pad_op = c_vision.NormalizePad(mean=[121.0, 115.0, 100.0], std=[70.0, 68.0, 71.0],
        ...                                          dtype="float32")
        >>> transforms_list = [decode_op, normalize_pad_op]
        >>> data1 = data1.map(operations=transforms_list, input_columns=["image"])
    """

    @check_normalizepad_c
    def __init__(self, mean, std, dtype="float32"):
        # Keep the Python-side values so __call__ can build a fresh C++ op for eager execution.
        self.mean = mean
        self.std = std
        self.dtype = dtype
        super().__init__(*mean, *std, dtype)

    def __call__(self, img):
        """
        Call method.

        Args:
            img (NumPy or PIL image): Image array to be normalized and padded.

        Returns:
            img (NumPy), normalized and padded image array.

        Raises:
            TypeError: If img is neither a NumPy array nor a PIL image.
        """
        if not isinstance(img, (np.ndarray, Image.Image)):
            raise TypeError("Input should be NumPy or PIL image, got {}.".format(type(img)))
        normalize_pad = cde.Execute(cde.NormalizePadOp(*self.mean, *self.std, self.dtype))
        img = normalize_pad(cde.Tensor(np.asarray(img)))
        return img.as_array()
| class RandomAffine(cde.RandomAffineOp): | class RandomAffine(cde.RandomAffineOp): | ||||
| """ | """ | ||||
| Apply Random affine transformation to the input image. | Apply Random affine transformation to the input image. | ||||
| @@ -28,7 +28,7 @@ from PIL import Image | |||||
| from . import py_transforms_util as util | from . import py_transforms_util as util | ||||
| from .c_transforms import parse_padding | from .c_transforms import parse_padding | ||||
| from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \ | from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \ | ||||
| check_normalize_py, check_random_crop, check_random_color_adjust, check_random_rotation, \ | |||||
| check_normalize_py, check_normalizepad_py, check_random_crop, check_random_color_adjust, check_random_rotation, \ | |||||
| check_ten_crop, check_num_channels, check_pad, \ | check_ten_crop, check_num_channels, check_pad, \ | ||||
| check_random_perspective, check_random_erasing, check_cutout, check_linear_transform, check_random_affine, \ | check_random_perspective, check_random_erasing, check_cutout, check_linear_transform, check_random_affine, \ | ||||
| check_mix_up, check_positive_degrees, check_uniform_augment_py, check_auto_contrast | check_mix_up, check_positive_degrees, check_uniform_augment_py, check_auto_contrast | ||||
| @@ -231,6 +231,49 @@ class Normalize: | |||||
| return util.normalize(img, self.mean, self.std) | return util.normalize(img, self.mean, self.std) | ||||
class NormalizePad:
    """
    Normalize a NumPy image array of shape (C, H, W) using the given mean and standard deviation,
    then append one extra channel filled with zeros.

    The values of the array need to be in the range (0.0, 1.0].

    Args:
        mean (sequence): List or tuple of mean values for each channel, with respect to channel order.
            The mean values must be in the range (0.0, 1.0].
        std (sequence): List or tuple of standard deviations for each channel, w.r.t. channel order.
            The standard deviation values must be in the range (0.0, 1.0].
        dtype (str): Set the output data type of image (default is "float32").

    Examples:
        >>> import mindspore.dataset.vision.py_transforms as py_vision
        >>> from mindspore.dataset.transforms.py_transforms import Compose
        >>>
        >>> Compose([py_vision.Decode(),
        >>>          py_vision.RandomHorizontalFlip(0.5),
        >>>          py_vision.ToTensor(),
        >>>          py_vision.NormalizePad((0.491, 0.482, 0.447), (0.247, 0.243, 0.262), "float32")])
    """

    @check_normalizepad_py
    def __init__(self, mean, std, dtype="float32"):
        self.dtype = dtype
        self.std = std
        self.mean = mean

    def __call__(self, img):
        """
        Call method.

        Args:
            img (numpy.ndarray): Image array to be normalized and padded.

        Returns:
            img (numpy.ndarray), normalized and padded image array.
        """
        padded = util.normalize(img, self.mean, self.std, pad_channel=True, dtype=self.dtype)
        return padded
| class RandomCrop: | class RandomCrop: | ||||
| """ | """ | ||||
| Crop the input PIL image at a random location. | Crop the input PIL image at a random location. | ||||
| @@ -42,7 +42,7 @@ def is_pil(img): | |||||
| return isinstance(img, Image.Image) | return isinstance(img, Image.Image) | ||||
| def normalize(img, mean, std): | |||||
| def normalize(img, mean, std, pad_channel=False, dtype="float32"): | |||||
| """ | """ | ||||
| Normalize the image between [0, 1] with respect to mean and standard deviation. | Normalize the image between [0, 1] with respect to mean and standard deviation. | ||||
| @@ -50,6 +50,8 @@ def normalize(img, mean, std): | |||||
| img (numpy.ndarray): Image array of shape CHW to be normalized. | img (numpy.ndarray): Image array of shape CHW to be normalized. | ||||
| mean (list): List of mean values for each channel, w.r.t channel order. | mean (list): List of mean values for each channel, w.r.t channel order. | ||||
| std (list): List of standard deviations for each channel, w.r.t. channel order. | std (list): List of standard deviations for each channel, w.r.t. channel order. | ||||
| pad_channel (bool): Whether to pad an extra channel with value zero. | |||||
| dtype (str): Output datatype of normalize, only takes effect when pad_channel is True (default is "float32"). | |||||
| Returns: | Returns: | ||||
| img (numpy.ndarray), Normalized image. | img (numpy.ndarray), Normalized image. | ||||
| @@ -72,7 +74,13 @@ def normalize(img, mean, std): | |||||
| mean = np.array(mean, dtype=img.dtype) | mean = np.array(mean, dtype=img.dtype) | ||||
| std = np.array(std, dtype=img.dtype) | std = np.array(std, dtype=img.dtype) | ||||
| return (img - mean[:, None, None]) / std[:, None, None] | |||||
| image = (img - mean[:, None, None]) / std[:, None, None] | |||||
| if pad_channel: | |||||
| zeros = np.zeros([1, image.shape[1], image.shape[2]], dtype=np.float32) | |||||
| image = np.concatenate((image, zeros), axis=0) | |||||
| if dtype == "float16": | |||||
| image = image.astype(np.float16) | |||||
| return image | |||||
| def decode(img): | def decode(img): | ||||
| @@ -294,6 +294,40 @@ def check_normalize_py(method): | |||||
| return new_method | return new_method | ||||
def check_normalizepad_c(method):
    """Decorator that validates the (mean, std, dtype) arguments of the C++-backed normalizepad operation."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        mean, std, dtype = parsed
        check_normalize_c_param(mean, std)
        # dtype must be a string naming one of the two supported output types.
        if not isinstance(dtype, str):
            raise TypeError("dtype should be string.")
        if dtype not in ("float32", "float16"):
            raise ValueError("dtype only support float32 or float16.")
        return method(self, *args, **kwargs)

    return new_method
def check_normalizepad_py(method):
    """Decorator that validates the (mean, std, dtype) arguments of the Python normalizepad operation."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        mean, std, dtype = parsed
        check_normalize_py_param(mean, std)
        # dtype must be a string naming one of the two supported output types.
        if not isinstance(dtype, str):
            raise TypeError("dtype should be string.")
        if dtype not in ("float32", "float16"):
            raise ValueError("dtype only support float32 or float16.")
        return method(self, *args, **kwargs)

    return new_method
| def check_random_crop(method): | def check_random_crop(method): | ||||
| """Wrapper method to check the parameters of random crop.""" | """Wrapper method to check the parameters of random crop.""" | ||||
| @@ -58,11 +58,6 @@ class MyTimeMonitor(Callback): | |||||
| fps = self.batch_size / step_mseconds *1000 * self.size | fps = self.batch_size / step_mseconds *1000 * self.size | ||||
| print("Epoch time: {:5.3f} ms, fps: {:d} img/sec.".format(step_mseconds, int(fps)), flush=True, end=" ") | print("Epoch time: {:5.3f} ms, fps: {:d} img/sec.".format(step_mseconds, int(fps)), flush=True, end=" ") | ||||
| def pad(image): | |||||
| zeros = np.zeros([224, 224, 1], dtype=np.uint8) | |||||
| output = np.concatenate((image, zeros), axis=2) | |||||
| return output | |||||
| def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="GPU", dtype="fp16"): | def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="GPU", dtype="fp16"): | ||||
| ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=4, shuffle=True) | ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=4, shuffle=True) | ||||
| @@ -71,24 +66,25 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target=" | |||||
| std = [0.229 * 255, 0.224 * 255, 0.225 * 255] | std = [0.229 * 255, 0.224 * 255, 0.225 * 255] | ||||
| # define map operations | # define map operations | ||||
| normalize_op = C.Normalize(mean=mean, std=std) | |||||
| if dtype == "float16": | |||||
| normalize_op = C.NormalizePad(mean=mean, std=std, dtype="float16") | |||||
| if do_train: | if do_train: | ||||
| trans = [ | trans = [ | ||||
| C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)), | C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)), | ||||
| C.RandomHorizontalFlip(prob=0.5), | C.RandomHorizontalFlip(prob=0.5), | ||||
| C.Normalize(mean=mean, std=std), | |||||
| normalize_op, | |||||
| ] | ] | ||||
| else: | else: | ||||
| trans = [ | trans = [ | ||||
| C.Decode(), | C.Decode(), | ||||
| C.Resize(256), | C.Resize(256), | ||||
| C.CenterCrop(image_size), | C.CenterCrop(image_size), | ||||
| C.Normalize(mean=mean, std=std), | |||||
| normalize_op, | |||||
| ] | ] | ||||
| if dtype == "fp32": | if dtype == "fp32": | ||||
| trans.append(C.HWC2CHW()) | trans.append(C.HWC2CHW()) | ||||
| ds = ds.map(operations=trans, input_columns="image", num_parallel_workers=4) | |||||
| if dtype == "fp16": | |||||
| ds = ds.map(operations=pad, input_columns="image", num_parallel_workers=4) | |||||
| ds = ds.map(operations=trans, input_columns="image", num_parallel_workers=8) | |||||
| # apply batch operations | # apply batch operations | ||||
| ds = ds.batch(batch_size, drop_remainder=True) | ds = ds.batch(batch_size, drop_remainder=True) | ||||
| # apply dataset repeat operation | # apply dataset repeat operation | ||||
| @@ -932,6 +932,70 @@ TEST_F(MindDataTestPipeline, TestNormalizeFail) { | |||||
| EXPECT_EQ(normalize, nullptr); | EXPECT_EQ(normalize, nullptr); | ||||
| } | } | ||||
| TEST_F(MindDataTestPipeline, TestNormalizePad) { | |||||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestNormalizePad."; | |||||
| // Create an ImageFolder Dataset | |||||
| std::string folder_path = datasets_root_path_ + "/testPK/data/"; | |||||
| std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); | |||||
| EXPECT_NE(ds, nullptr); | |||||
| // Create a Repeat operation on ds | |||||
| int32_t repeat_num = 2; | |||||
| ds = ds->Repeat(repeat_num); | |||||
| EXPECT_NE(ds, nullptr); | |||||
| // Create objects for the tensor ops | |||||
| std::shared_ptr<TensorOperation> normalizepad = vision::NormalizePad({121.0, 115.0, 100.0}, {70.0, 68.0, 71.0}, | |||||
| "float32"); | |||||
| EXPECT_NE(normalizepad, nullptr); | |||||
| // Create a Map operation on ds | |||||
| ds = ds->Map({normalizepad}); | |||||
| EXPECT_NE(ds, nullptr); | |||||
| // Create an iterator over the result of the above dataset | |||||
| // This will trigger the creation of the Execution Tree and launch it. | |||||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||||
| EXPECT_NE(iter, nullptr); | |||||
| // Iterate the dataset and get each row | |||||
| std::unordered_map<std::string, std::shared_ptr<Tensor>> row; | |||||
| iter->GetNextRow(&row); | |||||
| uint64_t i = 0; | |||||
| while (row.size() != 0) { | |||||
| i++; | |||||
| auto image = row["image"]; | |||||
| EXPECT_EQ(image->shape()[2], 4); | |||||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | |||||
| iter->GetNextRow(&row); | |||||
| } | |||||
| EXPECT_EQ(i, 20); | |||||
| // Manually terminate the pipeline | |||||
| iter->Stop(); | |||||
| } | |||||
| TEST_F(MindDataTestPipeline, TestNormalizePadFail) { | |||||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestNormalizePadFail with invalid parameters."; | |||||
| // std value at 0.0 | |||||
| std::shared_ptr<TensorOperation> normalizepad = | |||||
| mindspore::dataset::vision::NormalizePad({121.0, 115.0, 100.0}, {0.0, 68.0, 71.0}); | |||||
| EXPECT_EQ(normalizepad, nullptr); | |||||
| // normalizepad with 2 values (not 3 values) for mean | |||||
| normalizepad = mindspore::dataset::vision::NormalizePad({121.0, 115.0}, {70.0, 68.0, 71.0}); | |||||
| EXPECT_EQ(normalizepad, nullptr); | |||||
| // normalizepad with 2 values (not 3 values) for standard deviation | |||||
| normalizepad = mindspore::dataset::vision::NormalizePad({121.0, 115.0, 100.0}, {68.0, 71.0}); | |||||
| EXPECT_EQ(normalizepad, nullptr); | |||||
| // normalizepad with invalid dtype | |||||
| normalizepad = mindspore::dataset::vision::NormalizePad({121.0, 115.0, 100.0}, {68.0, 71.0, 71.0}, "123"); | |||||
| EXPECT_EQ(normalizepad, nullptr); | |||||
| } | |||||
| TEST_F(MindDataTestPipeline, TestPad) { | TEST_F(MindDataTestPipeline, TestPad) { | ||||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPad."; | MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPad."; | ||||
| @@ -0,0 +1,61 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/common.h" | |||||
| #include "common/cvop_common.h" | |||||
| #include "minddata/dataset/kernels/image/normalize_pad_op.h" | |||||
| #include "minddata/dataset/core/cv_tensor.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include <opencv2/opencv.hpp> | |||||
| using namespace mindspore::dataset; | |||||
| using mindspore::MsLogLevel::INFO; | |||||
| using mindspore::ExceptionType::NoExceptionType; | |||||
| using mindspore::LogStream; | |||||
| class MindDataTestNormalizePadOP : public UT::CVOP::CVOpCommon { | |||||
| public: | |||||
| MindDataTestNormalizePadOP() : CVOpCommon() {} | |||||
| }; | |||||
| TEST_F(MindDataTestNormalizePadOP, TestFloat32) { | |||||
| MS_LOG(INFO) << "Doing TestNormalizePadOp::TestFloat32."; | |||||
| std::shared_ptr<Tensor> output_tensor; | |||||
| // Numbers are from the resnet50 model implementation | |||||
| float mean[3] = {121.0, 115.0, 100.0}; | |||||
| float std[3] = {70.0, 68.0, 71.0}; | |||||
| // NormalizePad Op | |||||
| std::unique_ptr<NormalizePadOp> op(new NormalizePadOp(mean[0], mean[1], mean[2], std[0], std[1], std[2], "float32")); | |||||
| EXPECT_TRUE(op->OneToOne()); | |||||
| Status s = op->Compute(input_tensor_, &output_tensor); | |||||
| EXPECT_TRUE(s.IsOk()); | |||||
| } | |||||
| TEST_F(MindDataTestNormalizePadOP, TestFloat16) { | |||||
| MS_LOG(INFO) << "Doing TestNormalizePadOp::TestFloat16."; | |||||
| std::shared_ptr<Tensor> output_tensor; | |||||
| // Numbers are from the resnet50 model implementation | |||||
| float mean[3] = {121.0, 115.0, 100.0}; | |||||
| float std[3] = {70.0, 68.0, 71.0}; | |||||
| // NormalizePad Op | |||||
| std::unique_ptr<NormalizePadOp> op(new NormalizePadOp(mean[0], mean[1], mean[2], std[0], std[1], std[2], "float16")); | |||||
| EXPECT_TRUE(op->OneToOne()); | |||||
| Status s = op->Compute(input_tensor_, &output_tensor); | |||||
| EXPECT_TRUE(s.IsOk()); | |||||
| } | |||||
| @@ -0,0 +1,201 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """ | |||||
| Testing NormalizePad op in DE | |||||
| """ | |||||
| import numpy as np | |||||
| import mindspore.dataset as ds | |||||
| import mindspore.dataset.transforms.py_transforms | |||||
| import mindspore.dataset.vision.c_transforms as c_vision | |||||
| import mindspore.dataset.vision.py_transforms as py_vision | |||||
| from mindspore import log as logger | |||||
| from util import diff_mse, visualize_image | |||||
| DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | |||||
| SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | |||||
| GENERATE_GOLDEN = False | |||||
| def normalizepad_np(image, mean, std): | |||||
| """ | |||||
| Apply the normalize+pad | |||||
| """ | |||||
| # DE decodes the image in RGB by default, hence | |||||
| # the values here are in RGB | |||||
| image = np.array(image, np.float32) | |||||
| image = image - np.array(mean) | |||||
| image = image * (1.0 / np.array(std)) | |||||
| zeros = np.zeros([image.shape[0], image.shape[1], 1], dtype=np.float32) | |||||
| output = np.concatenate((image, zeros), axis=2) | |||||
| return output | |||||
| def test_normalizepad_op_c(plot=False): | |||||
| """ | |||||
| Test NormalizePad in cpp transformations | |||||
| """ | |||||
| logger.info("Test Normalize in cpp") | |||||
| mean = [121.0, 115.0, 100.0] | |||||
| std = [70.0, 68.0, 71.0] | |||||
| # define map operations | |||||
| decode_op = c_vision.Decode() | |||||
| normalizepad_op = c_vision.NormalizePad(mean, std) | |||||
| # First dataset | |||||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||||
| data1 = data1.map(operations=decode_op, input_columns=["image"]) | |||||
| data1 = data1.map(operations=normalizepad_op, input_columns=["image"]) | |||||
| # Second dataset | |||||
| data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||||
| data2 = data2.map(operations=decode_op, input_columns=["image"]) | |||||
| num_iter = 0 | |||||
| for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True), | |||||
| data2.create_dict_iterator(num_epochs=1, output_numpy=True)): | |||||
| image_de_normalized = item1["image"] | |||||
| image_original = item2["image"] | |||||
| image_np_normalized = normalizepad_np(image_original, mean, std) | |||||
| mse = diff_mse(image_de_normalized, image_np_normalized) | |||||
| logger.info("image_{}, mse: {}".format(num_iter + 1, mse)) | |||||
| assert mse < 0.01 | |||||
| if plot: | |||||
| visualize_image(image_original, image_de_normalized, mse, image_np_normalized) | |||||
| num_iter += 1 | |||||
| def test_normalizepad_op_py(plot=False): | |||||
| """ | |||||
| Test NormalizePad in python transformations | |||||
| """ | |||||
| logger.info("Test Normalize in python") | |||||
| mean = [0.475, 0.45, 0.392] | |||||
| std = [0.275, 0.267, 0.278] | |||||
| # define map operations | |||||
| transforms = [ | |||||
| py_vision.Decode(), | |||||
| py_vision.ToTensor() | |||||
| ] | |||||
| transform = mindspore.dataset.transforms.py_transforms.Compose(transforms) | |||||
| normalizepad_op = py_vision.NormalizePad(mean, std) | |||||
| # First dataset | |||||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||||
| data1 = data1.map(operations=transform, input_columns=["image"]) | |||||
| data1 = data1.map(operations=normalizepad_op, input_columns=["image"]) | |||||
| # Second dataset | |||||
| data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||||
| data2 = data2.map(operations=transform, input_columns=["image"]) | |||||
| num_iter = 0 | |||||
| for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True), | |||||
| data2.create_dict_iterator(num_epochs=1, output_numpy=True)): | |||||
| image_de_normalized = (item1["image"].transpose(1, 2, 0) * 255).astype(np.uint8) | |||||
| image_np_normalized = (normalizepad_np(item2["image"].transpose(1, 2, 0), mean, std) * 255).astype(np.uint8) | |||||
| image_original = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8) | |||||
| mse = diff_mse(image_de_normalized, image_np_normalized) | |||||
| logger.info("image_{}, mse: {}".format(num_iter + 1, mse)) | |||||
| assert mse < 0.01 | |||||
| if plot: | |||||
| visualize_image(image_original, image_de_normalized, mse, image_np_normalized) | |||||
| num_iter += 1 | |||||
| def test_decode_normalizepad_op(): | |||||
| """ | |||||
| Test Decode op followed by NormalizePad op | |||||
| """ | |||||
| logger.info("Test [Decode, Normalize] in one Map") | |||||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image", "label"], num_parallel_workers=1, | |||||
| shuffle=False) | |||||
| # define map operations | |||||
| decode_op = c_vision.Decode() | |||||
| normalizepad_op = c_vision.NormalizePad([121.0, 115.0, 100.0], [70.0, 68.0, 71.0], "float16") | |||||
| # apply map operations on images | |||||
| data1 = data1.map(operations=[decode_op, normalizepad_op], input_columns=["image"]) | |||||
| num_iter = 0 | |||||
| for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): | |||||
| logger.info("Looping inside iterator {}".format(num_iter)) | |||||
| assert item["image"].dtype == np.float16 | |||||
| num_iter += 1 | |||||
| def test_normalizepad_exception_unequal_size_c(): | |||||
| """ | |||||
| Test NormalizePad in c transformation: len(mean) != len(std) | |||||
| expected to raise ValueError | |||||
| """ | |||||
| logger.info("test_normalize_exception_unequal_size_c") | |||||
| try: | |||||
| _ = c_vision.NormalizePad([100, 250, 125], [50, 50, 75, 75]) | |||||
| except ValueError as e: | |||||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||||
| assert str(e) == "Length of mean and std must be equal." | |||||
| try: | |||||
| _ = c_vision.NormalizePad([100, 250, 125], [50, 50, 75], 1) | |||||
| except TypeError as e: | |||||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||||
| assert str(e) == "dtype should be string." | |||||
| try: | |||||
| _ = c_vision.NormalizePad([100, 250, 125], [50, 50, 75], "") | |||||
| except ValueError as e: | |||||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||||
| assert str(e) == "dtype only support float32 or float16." | |||||
| def test_normalizepad_exception_unequal_size_py(): | |||||
| """ | |||||
| Test NormalizePad in python transformation: len(mean) != len(std) | |||||
| expected to raise ValueError | |||||
| """ | |||||
| logger.info("test_normalizepad_exception_unequal_size_py") | |||||
| try: | |||||
| _ = py_vision.NormalizePad([0.50, 0.30, 0.75], [0.18, 0.32, 0.71, 0.72]) | |||||
| except ValueError as e: | |||||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||||
| assert str(e) == "Length of mean and std must be equal." | |||||
| try: | |||||
| _ = py_vision.NormalizePad([0.50, 0.30, 0.75], [0.18, 0.32, 0.71], 1) | |||||
| except TypeError as e: | |||||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||||
| assert str(e) == "dtype should be string." | |||||
| try: | |||||
| _ = py_vision.NormalizePad([0.50, 0.30, 0.75], [0.18, 0.32, 0.71], "") | |||||
| except ValueError as e: | |||||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||||
| assert str(e) == "dtype only support float32 or float16." | |||||
| def test_normalizepad_exception_invalid_range_py(): | |||||
| """ | |||||
| Test NormalizePad in python transformation: value is not in range [0,1] | |||||
| expected to raise ValueError | |||||
| """ | |||||
| logger.info("test_normalizepad_exception_invalid_range_py") | |||||
| try: | |||||
| _ = py_vision.NormalizePad([0.75, 1.25, 0.5], [0.1, 0.18, 1.32]) | |||||
| except ValueError as e: | |||||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||||
| assert "Input mean_value is not within the required interval of (0.0 to 1.0)." in str(e) | |||||