Merge pull request !4367 from MahdiRahmaniHanzaki/cutmixtags/v0.7.0-beta
| @@ -110,5 +110,12 @@ PYBIND_REGISTER(InterpolationMode, 0, ([](const py::module *m) { | |||
| .export_values(); | |||
| })); | |||
| PYBIND_REGISTER(ImageBatchFormat, 0, ([](const py::module *m) { | |||
| (void)py::enum_<ImageBatchFormat>(*m, "ImageBatchFormat", py::arithmetic()) | |||
| .value("DE_IMAGE_BATCH_FORMAT_NHWC", ImageBatchFormat::kNHWC) | |||
| .value("DE_IMAGE_BATCH_FORMAT_NCHW", ImageBatchFormat::kNCHW) | |||
| .export_values(); | |||
| })); | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -22,6 +22,7 @@ | |||
| #include "minddata/dataset/kernels/image/auto_contrast_op.h" | |||
| #include "minddata/dataset/kernels/image/bounding_box_augment_op.h" | |||
| #include "minddata/dataset/kernels/image/center_crop_op.h" | |||
| #include "minddata/dataset/kernels/image/cutmix_batch_op.h" | |||
| #include "minddata/dataset/kernels/image/cut_out_op.h" | |||
| #include "minddata/dataset/kernels/image/decode_op.h" | |||
| #include "minddata/dataset/kernels/image/equalize_op.h" | |||
| @@ -105,6 +106,13 @@ PYBIND_REGISTER(MixUpBatchOp, 1, ([](const py::module *m) { | |||
| .def(py::init<float>(), py::arg("alpha")); | |||
| })); | |||
| PYBIND_REGISTER(CutMixBatchOp, 1, ([](const py::module *m) { | |||
| (void)py::class_<CutMixBatchOp, TensorOp, std::shared_ptr<CutMixBatchOp>>( | |||
| *m, "CutMixBatchOp", "Tensor operation to cutmix a batch of images") | |||
| .def(py::init<ImageBatchFormat, float, float>(), py::arg("image_batch_format"), py::arg("alpha"), | |||
| py::arg("prob")); | |||
| })); | |||
| PYBIND_REGISTER(ResizeOp, 1, ([](const py::module *m) { | |||
| (void)py::class_<ResizeOp, TensorOp, std::shared_ptr<ResizeOp>>( | |||
| *m, "ResizeOp", "Tensor operation to resize an image. Takes height, width and mode") | |||
| @@ -19,6 +19,7 @@ | |||
| #include "minddata/dataset/kernels/image/center_crop_op.h" | |||
| #include "minddata/dataset/kernels/image/crop_op.h" | |||
| #include "minddata/dataset/kernels/image/cutmix_batch_op.h" | |||
| #include "minddata/dataset/kernels/image/cut_out_op.h" | |||
| #include "minddata/dataset/kernels/image/decode_op.h" | |||
| #include "minddata/dataset/kernels/image/hwc_to_chw_op.h" | |||
| @@ -70,6 +71,16 @@ std::shared_ptr<CropOperation> Crop(std::vector<int32_t> coordinates, std::vecto | |||
| return op; | |||
| } | |||
| // Function to create CutMixBatchOperation. | |||
| std::shared_ptr<CutMixBatchOperation> CutMixBatch(ImageBatchFormat image_batch_format, float alpha, float prob) { | |||
| auto op = std::make_shared<CutMixBatchOperation>(image_batch_format, alpha, prob); | |||
| // Input validation | |||
| if (!op->ValidateParams()) { | |||
| return nullptr; | |||
| } | |||
| return op; | |||
| } | |||
| // Function to create CutOutOp. | |||
| std::shared_ptr<CutOutOperation> CutOut(int32_t length, int32_t num_patches) { | |||
| auto op = std::make_shared<CutOutOperation>(length, num_patches); | |||
| @@ -355,6 +366,27 @@ std::shared_ptr<TensorOp> CropOperation::Build() { | |||
| return tensor_op; | |||
| } | |||
| // CutMixBatchOperation | |||
| CutMixBatchOperation::CutMixBatchOperation(ImageBatchFormat image_batch_format, float alpha, float prob) | |||
| : image_batch_format_(image_batch_format), alpha_(alpha), prob_(prob) {} | |||
| bool CutMixBatchOperation::ValidateParams() { | |||
| if (alpha_ < 0) { | |||
| MS_LOG(ERROR) << "CutMixBatch: alpha cannot be negative."; | |||
| return false; | |||
| } | |||
| if (prob_ < 0 || prob_ > 1) { | |||
| MS_LOG(ERROR) << "CutMixBatch: Probability has to be between 0 and 1."; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| std::shared_ptr<TensorOp> CutMixBatchOperation::Build() { | |||
| std::shared_ptr<CutMixBatchOp> tensor_op = std::make_shared<CutMixBatchOp>(image_batch_format_, alpha_, prob_); | |||
| return tensor_op; | |||
| } | |||
| // CutOutOperation | |||
| CutOutOperation::CutOutOperation(int32_t length, int32_t num_patches) : length_(length), num_patches_(num_patches) {} | |||
| @@ -41,6 +41,12 @@ enum class ShuffleMode { kFalse = 0, kFiles = 1, kGlobal = 2 }; | |||
| // Possible values for Border types | |||
| enum class BorderType { kConstant = 0, kEdge = 1, kReflect = 2, kSymmetric = 3 }; | |||
| // Possible values for Image format types in a batch | |||
| enum class ImageBatchFormat { kNHWC = 0, kNCHW = 1 }; | |||
| // Possible values for Image format types | |||
| enum class ImageFormat { HWC = 0, CHW = 1, HW = 2 }; | |||
| // Possible interpolation modes | |||
| enum class InterpolationMode { kLinear = 0, kNearestNeighbour = 1, kCubic = 2, kArea = 3 }; | |||
| @@ -49,6 +49,7 @@ namespace vision { | |||
| // Transform Op classes (in alphabetical order) | |||
| class CenterCropOperation; | |||
| class CropOperation; | |||
| class CutMixBatchOperation; | |||
| class CutOutOperation; | |||
| class DecodeOperation; | |||
| class HwcToChwOperation; | |||
| @@ -86,6 +87,16 @@ std::shared_ptr<CenterCropOperation> CenterCrop(std::vector<int32_t> size); | |||
| /// \return Shared pointer to the current TensorOp | |||
| std::shared_ptr<CropOperation> Crop(std::vector<int32_t> coordinates, std::vector<int32_t> size); | |||
| /// \brief Function to apply CutMix on a batch of images | |||
| /// \notes Masks a random section of each image with the corresponding part of another randomly selected image in | |||
| /// that batch | |||
| /// \param[in] image_batch_format The format of the batch | |||
| /// \param[in] alpha The hyperparameter of beta distribution (default = 1.0) | |||
| /// \param[in] prob The probability by which CutMix is applied to each image (default = 1.0) | |||
| /// \return Shared pointer to the current TensorOp | |||
| std::shared_ptr<CutMixBatchOperation> CutMixBatch(ImageBatchFormat image_batch_format, float alpha = 1.0, | |||
| float prob = 1.0); | |||
| /// \brief Function to create a CutOut TensorOp | |||
| /// \notes Randomly cut (mask) out a given number of square patches from the input image | |||
| /// \param[in] length Integer representing the side length of each square patch | |||
| @@ -305,6 +316,22 @@ class CropOperation : public TensorOperation { | |||
| std::vector<int32_t> size_; | |||
| }; | |||
| class CutMixBatchOperation : public TensorOperation { | |||
| public: | |||
| explicit CutMixBatchOperation(ImageBatchFormat image_batch_format, float alpha = 1.0, float prob = 1.0); | |||
| ~CutMixBatchOperation() = default; | |||
| std::shared_ptr<TensorOp> Build() override; | |||
| bool ValidateParams() override; | |||
| private: | |||
| float alpha_; | |||
| float prob_; | |||
| ImageBatchFormat image_batch_format_; | |||
| }; | |||
| class CutOutOperation : public TensorOperation { | |||
| public: | |||
| explicit CutOutOperation(int32_t length, int32_t num_patches = 1); | |||
| @@ -318,6 +345,7 @@ class CutOutOperation : public TensorOperation { | |||
| private: | |||
| int32_t length_; | |||
| int32_t num_patches_; | |||
| ImageBatchFormat image_batch_format_; | |||
| }; | |||
| class DecodeOperation : public TensorOperation { | |||
| @@ -655,7 +655,7 @@ Status BatchTensorToCVTensorVector(const std::shared_ptr<Tensor> &input, | |||
| TensorShape remaining({-1}); | |||
| std::vector<int64_t> index(tensor_shape.size(), 0); | |||
| if (tensor_shape.size() <= 1) { | |||
| RETURN_STATUS_UNEXPECTED("Tensor must be at least 2-D in order to unpack"); | |||
| RETURN_STATUS_UNEXPECTED("Tensor must be at least 2-D in order to unpack."); | |||
| } | |||
| TensorShape element_shape(std::vector<int64_t>(tensor_shape.begin() + 1, tensor_shape.end())); | |||
| @@ -664,15 +664,48 @@ Status BatchTensorToCVTensorVector(const std::shared_ptr<Tensor> &input, | |||
| std::shared_ptr<Tensor> out; | |||
| RETURN_IF_NOT_OK(input->StartAddrOfIndex(index, &start_addr_of_index, &remaining)); | |||
| RETURN_IF_NOT_OK(input->CreateFromMemory(element_shape, input->type(), start_addr_of_index, &out)); | |||
| RETURN_IF_NOT_OK(Tensor::CreateFromMemory(element_shape, input->type(), start_addr_of_index, &out)); | |||
| std::shared_ptr<CVTensor> cv_out = CVTensor::AsCVTensor(std::move(out)); | |||
| if (!cv_out->mat().data) { | |||
| RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); | |||
| RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor."); | |||
| } | |||
| output->push_back(cv_out); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status BatchTensorToTensorVector(const std::shared_ptr<Tensor> &input, std::vector<std::shared_ptr<Tensor>> *output) { | |||
| std::vector<int64_t> tensor_shape = input->shape().AsVector(); | |||
| TensorShape remaining({-1}); | |||
| std::vector<int64_t> index(tensor_shape.size(), 0); | |||
| if (tensor_shape.size() <= 1) { | |||
| RETURN_STATUS_UNEXPECTED("Tensor must be at least 2-D in order to unpack."); | |||
| } | |||
| TensorShape element_shape(std::vector<int64_t>(tensor_shape.begin() + 1, tensor_shape.end())); | |||
| for (; index[0] < tensor_shape[0]; index[0]++) { | |||
| uchar *start_addr_of_index = nullptr; | |||
| std::shared_ptr<Tensor> out; | |||
| RETURN_IF_NOT_OK(input->StartAddrOfIndex(index, &start_addr_of_index, &remaining)); | |||
| RETURN_IF_NOT_OK(Tensor::CreateFromMemory(element_shape, input->type(), start_addr_of_index, &out)); | |||
| output->push_back(out); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status TensorVectorToBatchTensor(const std::vector<std::shared_ptr<Tensor>> &input, std::shared_ptr<Tensor> *output) { | |||
| if (input.empty()) { | |||
| RETURN_STATUS_UNEXPECTED("TensorVectorToBatchTensor: Received an empty vector."); | |||
| } | |||
| std::vector<int64_t> tensor_shape = input.front()->shape().AsVector(); | |||
| tensor_shape.insert(tensor_shape.begin(), input.size()); | |||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape(tensor_shape), input.at(0)->type(), output)); | |||
| for (int i = 0; i < input.size(); i++) { | |||
| RETURN_IF_NOT_OK((*output)->InsertTensor({i}, input[i])); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -158,11 +158,24 @@ Status ConcatenateHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<T | |||
| std::shared_ptr<Tensor> append); | |||
| /// Convert an n-dimensional Tensor to a vector of (n-1)-dimensional CVTensors | |||
| /// @param input[in] input tensor | |||
| /// @param output[out] output tensor | |||
| /// @return Status ok/error | |||
| /// \param input[in] input tensor | |||
| /// \param output[out] output vector of CVTensors | |||
| /// \return Status ok/error | |||
| Status BatchTensorToCVTensorVector(const std::shared_ptr<Tensor> &input, | |||
| std::vector<std::shared_ptr<CVTensor>> *output); | |||
| /// Convert an n-dimensional Tensor to a vector of (n-1)-dimensional Tensors | |||
| /// \param input[in] input tensor | |||
| /// \param output[out] output vector of tensors | |||
| /// \return Status ok/error | |||
| Status BatchTensorToTensorVector(const std::shared_ptr<Tensor> &input, std::vector<std::shared_ptr<Tensor>> *output); | |||
| /// Convert a vector of (n-1)-dimensional Tensors to an n-dimensional Tensor | |||
| /// \param input[in] input vector of tensors | |||
| /// \param output[out] output tensor | |||
| /// \return Status ok/error | |||
| Status TensorVectorToBatchTensor(const std::vector<std::shared_ptr<Tensor>> &input, std::shared_ptr<Tensor> *output); | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -7,6 +7,7 @@ add_library(kernels-image OBJECT | |||
| center_crop_op.cc | |||
| crop_op.cc | |||
| cut_out_op.cc | |||
| cutmix_batch_op.cc | |||
| decode_op.cc | |||
| equalize_op.cc | |||
| hwc_to_chw_op.cc | |||
| @@ -0,0 +1,166 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <string> | |||
| #include <utility> | |||
| #include "minddata/dataset/core/cv_tensor.h" | |||
| #include "minddata/dataset/kernels/image/image_utils.h" | |||
| #include "minddata/dataset/kernels/image/cutmix_batch_op.h" | |||
| #include "minddata/dataset/kernels/data/data_utils.h" | |||
| #include "minddata/dataset/util/random.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| CutMixBatchOp::CutMixBatchOp(ImageBatchFormat image_batch_format, float alpha, float prob) | |||
| : image_batch_format_(image_batch_format), alpha_(alpha), prob_(prob) { | |||
| rnd_.seed(GetSeed()); | |||
| } | |||
| void CutMixBatchOp::GetCropBox(int height, int width, float lam, int *x, int *y, int *crop_width, int *crop_height) { | |||
| float cut_ratio = 1 - lam; | |||
| int cut_w = static_cast<int>(width * cut_ratio); | |||
| int cut_h = static_cast<int>(height * cut_ratio); | |||
| std::uniform_int_distribution<int> width_uniform_distribution(0, width); | |||
| std::uniform_int_distribution<int> height_uniform_distribution(0, height); | |||
| int cx = width_uniform_distribution(rnd_); | |||
| int x2, y2; | |||
| int cy = height_uniform_distribution(rnd_); | |||
| *x = std::clamp(cx - cut_w / 2, 0, width - 1); // horizontal coordinate of left side of crop box | |||
| *y = std::clamp(cy - cut_h / 2, 0, height - 1); // vertical coordinate of the top side of crop box | |||
| x2 = std::clamp(cx + cut_w / 2, 0, width - 1); // horizontal coordinate of right side of crop box | |||
| y2 = std::clamp(cy + cut_h / 2, 0, height - 1); // vertical coordinate of the bottom side of crop box | |||
| *crop_width = std::clamp(x2 - *x, 1, width - 1); | |||
| *crop_height = std::clamp(y2 - *y, 1, height - 1); | |||
| } | |||
| Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) { | |||
| if (input.size() < 2) { | |||
| RETURN_STATUS_UNEXPECTED("Both images and labels columns are required for this operation"); | |||
| } | |||
| std::vector<std::shared_ptr<Tensor>> images; | |||
| std::vector<int64_t> image_shape = input.at(0)->shape().AsVector(); | |||
| std::vector<int64_t> label_shape = input.at(1)->shape().AsVector(); | |||
| // Check inputs | |||
| if (image_shape.size() != 4 || image_shape[0] != label_shape[0]) { | |||
| RETURN_STATUS_UNEXPECTED("You must batch before calling CutMixBatch."); | |||
| } | |||
| if (label_shape.size() != 2) { | |||
| RETURN_STATUS_UNEXPECTED("CutMixBatch: Label's must be in one-hot format and in a batch"); | |||
| } | |||
| if ((image_shape[1] != 1 && image_shape[1] != 3) && image_batch_format_ == ImageBatchFormat::kNCHW) { | |||
| RETURN_STATUS_UNEXPECTED("CutMixBatch: Image doesn't match the given image format."); | |||
| } | |||
| if ((image_shape[3] != 1 && image_shape[3] != 3) && image_batch_format_ == ImageBatchFormat::kNHWC) { | |||
| RETURN_STATUS_UNEXPECTED("CutMixBatch: Image doesn't match the given image format."); | |||
| } | |||
| // Move images into a vector of Tensors | |||
| RETURN_IF_NOT_OK(BatchTensorToTensorVector(input.at(0), &images)); | |||
| // Calculate random labels | |||
| std::vector<int64_t> rand_indx; | |||
| for (int64_t i = 0; i < images.size(); i++) rand_indx.push_back(i); | |||
| std::shuffle(rand_indx.begin(), rand_indx.end(), rnd_); | |||
| std::gamma_distribution<float> gamma_distribution(alpha_, 1); | |||
| std::uniform_real_distribution<double> uniform_distribution(0.0, 1.0); | |||
| // Tensor holding the output labels | |||
| std::shared_ptr<Tensor> out_labels; | |||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape(label_shape), DataType(DataType::DE_FLOAT32), &out_labels)); | |||
| // Compute labels and images | |||
| for (int i = 0; i < image_shape[0]; i++) { | |||
| // Calculating lambda | |||
| // If x1 is a random variable from Gamma(a1, 1) and x2 is a random variable from Gamma(a2, 1) | |||
| // then x = x1 / (x1+x2) is a random variable from Beta(a1, a2) | |||
| float x1 = gamma_distribution(rnd_); | |||
| float x2 = gamma_distribution(rnd_); | |||
| float lam = x1 / (x1 + x2); | |||
| double random_number = uniform_distribution(rnd_); | |||
| if (random_number < prob_) { | |||
| int x, y, crop_width, crop_height; | |||
| float label_lam; // lambda used for labels | |||
| // Get a random image | |||
| TensorShape remaining({-1}); | |||
| uchar *start_addr_of_index = nullptr; | |||
| std::shared_ptr<Tensor> rand_image; | |||
| RETURN_IF_NOT_OK(input.at(0)->StartAddrOfIndex({rand_indx[i], 0, 0, 0}, &start_addr_of_index, &remaining)); | |||
| RETURN_IF_NOT_OK(Tensor::CreateFromMemory(TensorShape({image_shape[1], image_shape[2], image_shape[3]}), | |||
| input.at(0)->type(), start_addr_of_index, &rand_image)); | |||
| // Compute image | |||
| if (image_batch_format_ == ImageBatchFormat::kNHWC) { | |||
| // NHWC Format | |||
| GetCropBox(static_cast<int32_t>(image_shape[1]), static_cast<int32_t>(image_shape[2]), lam, &x, &y, &crop_width, | |||
| &crop_height); | |||
| std::shared_ptr<Tensor> cropped; | |||
| RETURN_IF_NOT_OK(Crop(rand_image, &cropped, x, y, crop_width, crop_height)); | |||
| RETURN_IF_NOT_OK(MaskWithTensor(cropped, &images[i], x, y, crop_width, crop_height, ImageFormat::HWC)); | |||
| label_lam = 1 - (crop_width * crop_height / static_cast<float>(image_shape[1] * image_shape[2])); | |||
| } else { | |||
| // NCHW Format | |||
| GetCropBox(static_cast<int32_t>(image_shape[2]), static_cast<int32_t>(image_shape[3]), lam, &x, &y, &crop_width, | |||
| &crop_height); | |||
| std::vector<std::shared_ptr<Tensor>> channels; // A vector holding channels of the CHW image | |||
| std::vector<std::shared_ptr<Tensor>> cropped_channels; // A vector holding the channels of the cropped CHW | |||
| RETURN_IF_NOT_OK(BatchTensorToTensorVector(rand_image, &channels)); | |||
| for (auto channel : channels) { | |||
| // Call crop for each single channel | |||
| std::shared_ptr<Tensor> cropped_channel; | |||
| RETURN_IF_NOT_OK(Crop(channel, &cropped_channel, x, y, crop_width, crop_height)); | |||
| cropped_channels.push_back(cropped_channel); | |||
| } | |||
| std::shared_ptr<Tensor> cropped; | |||
| // Merge channels to a single tensor | |||
| RETURN_IF_NOT_OK(TensorVectorToBatchTensor(cropped_channels, &cropped)); | |||
| RETURN_IF_NOT_OK(MaskWithTensor(cropped, &images[i], x, y, crop_width, crop_height, ImageFormat::CHW)); | |||
| label_lam = 1 - (crop_width * crop_height / static_cast<float>(image_shape[2] * image_shape[3])); | |||
| } | |||
| // Compute labels | |||
| for (int j = 0; j < label_shape[1]; j++) { | |||
| uint64_t first_value, second_value; | |||
| RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, {i, j})); | |||
| RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, {rand_indx[i] % label_shape[0], j})); | |||
| RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, label_lam * first_value + (1 - label_lam) * second_value)); | |||
| } | |||
| } | |||
| } | |||
| std::shared_ptr<Tensor> out_images; | |||
| RETURN_IF_NOT_OK(TensorVectorToBatchTensor(images, &out_images)); | |||
| // Move the output into a TensorRow | |||
| output->push_back(out_images); | |||
| output->push_back(out_labels); | |||
| return Status::OK(); | |||
| } | |||
| void CutMixBatchOp::Print(std::ostream &out) const { | |||
| out << "CutMixBatchOp: " | |||
| << "image_batch_format: " << image_batch_format_ << "alpha: " << alpha_ << ", probability: " << prob_ << "\n"; | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,52 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CUTMIXBATCH_OP_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CUTMIXBATCH_OP_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <random> | |||
| #include <string> | |||
| #include "minddata/dataset/core/tensor.h" | |||
| #include "minddata/dataset/kernels/tensor_op.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class CutMixBatchOp : public TensorOp { | |||
| public: | |||
| explicit CutMixBatchOp(ImageBatchFormat image_batch_format, float alpha, float prob); | |||
| ~CutMixBatchOp() override = default; | |||
| void Print(std::ostream &out) const override; | |||
| void GetCropBox(int width, int height, float lam, int *x, int *y, int *crop_width, int *crop_height); | |||
| Status Compute(const TensorRow &input, TensorRow *output) override; | |||
| std::string Name() const override { return kCutMixBatchOp; } | |||
| private: | |||
| float alpha_; | |||
| float prob_; | |||
| ImageBatchFormat image_batch_format_; | |||
| std::mt19937 rnd_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CUTMIXBATCH_OP_H_ | |||
| @@ -402,6 +402,62 @@ Status HwcToChw(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output) | |||
| } | |||
| } | |||
| Status MaskWithTensor(const std::shared_ptr<Tensor> &sub_mat, std::shared_ptr<Tensor> *input, int x, int y, | |||
| int crop_width, int crop_height, ImageFormat image_format) { | |||
| if (image_format == ImageFormat::HWC) { | |||
| if ((*input)->Rank() != 3 || ((*input)->shape()[2] != 1 && (*input)->shape()[2] != 3)) { | |||
| RETURN_STATUS_UNEXPECTED("MaskWithTensor: Image shape doesn't match the given image_format."); | |||
| } | |||
| if (sub_mat->Rank() != 3 || (sub_mat->shape()[2] != 1 && sub_mat->shape()[2] != 3)) { | |||
| RETURN_STATUS_UNEXPECTED("MaskWithTensor: sub_mat shape doesn't match the given image_format."); | |||
| } | |||
| int number_of_channels = (*input)->shape()[2]; | |||
| for (int i = 0; i < crop_width; i++) { | |||
| for (int j = 0; j < crop_height; j++) { | |||
| for (int c = 0; c < number_of_channels; c++) { | |||
| uint8_t pixel_value; | |||
| RETURN_IF_NOT_OK(sub_mat->GetItemAt(&pixel_value, {j, i, c})); | |||
| RETURN_IF_NOT_OK((*input)->SetItemAt({y + j, x + i, c}, pixel_value)); | |||
| } | |||
| } | |||
| } | |||
| } else if (image_format == ImageFormat::CHW) { | |||
| if ((*input)->Rank() != 3 || ((*input)->shape()[0] != 1 && (*input)->shape()[0] != 3)) { | |||
| RETURN_STATUS_UNEXPECTED("MaskWithTensor: Image shape doesn't match the given image_format."); | |||
| } | |||
| if (sub_mat->Rank() != 3 || (sub_mat->shape()[0] != 1 && sub_mat->shape()[0] != 3)) { | |||
| RETURN_STATUS_UNEXPECTED("MaskWithTensor: sub_mat shape doesn't match the given image_format."); | |||
| } | |||
| int number_of_channels = (*input)->shape()[0]; | |||
| for (int i = 0; i < crop_width; i++) { | |||
| for (int j = 0; j < crop_height; j++) { | |||
| for (int c = 0; c < number_of_channels; c++) { | |||
| uint8_t pixel_value; | |||
| RETURN_IF_NOT_OK(sub_mat->GetItemAt(&pixel_value, {c, j, i})); | |||
| RETURN_IF_NOT_OK((*input)->SetItemAt({c, y + j, x + i}, pixel_value)); | |||
| } | |||
| } | |||
| } | |||
| } else if (image_format == ImageFormat::HW) { | |||
| if ((*input)->Rank() != 2) { | |||
| RETURN_STATUS_UNEXPECTED("MaskWithTensor: Image shape doesn't match the given image_format."); | |||
| } | |||
| if (sub_mat->Rank() != 2) { | |||
| RETURN_STATUS_UNEXPECTED("MaskWithTensor: sub_mat shape doesn't match the given image_format."); | |||
| } | |||
| for (int i = 0; i < crop_width; i++) { | |||
| for (int j = 0; j < crop_height; j++) { | |||
| uint8_t pixel_value; | |||
| RETURN_IF_NOT_OK(sub_mat->GetItemAt(&pixel_value, {j, i})); | |||
| RETURN_IF_NOT_OK((*input)->SetItemAt({y + j, x + i}, pixel_value)); | |||
| } | |||
| } | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("MaskWithTensor: Image format must be CHW, HWC, or HW."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status SwapRedAndBlue(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output) { | |||
| try { | |||
| std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(std::move(input)); | |||
| @@ -120,6 +120,19 @@ Status Crop(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outpu | |||
| /// \param output: Tensor of shape <C,H,W> or <H,W> and same input type. | |||
| Status HwcToChw(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output); | |||
| /// \brief Masks the given part of the input image with a another image (sub_mat) | |||
| /// \param[in] sub_mat The image we want to mask with | |||
| /// \param[in] input The pointer to the image we want to mask | |||
| /// \param[in] x The horizontal coordinate of left side of crop box | |||
| /// \param[in] y The vertical coordinate of the top side of crop box | |||
| /// \param[in] width The width of the mask box | |||
| /// \param[in] height The height of the mask box | |||
| /// \param[in] image_format The format of the image (CHW or HWC) | |||
| /// \param[out] input Masks the input image in-place and returns it | |||
| /// @return Status ok/error | |||
| Status MaskWithTensor(const std::shared_ptr<Tensor> &sub_mat, std::shared_ptr<Tensor> *input, int x, int y, int width, | |||
| int height, ImageFormat image_format); | |||
| /// \brief Swap the red and blue pixels (RGB <-> BGR) | |||
| /// \param input: Tensor of shape <H,W,3> and any OpenCv compatible type, see CVTensor. | |||
| /// \param output: Swapped image of same shape and type | |||
| @@ -37,10 +37,12 @@ Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) { | |||
| std::vector<int64_t> label_shape = input.at(1)->shape().AsVector(); | |||
| // Check inputs | |||
| if (label_shape.size() != 2 || image_shape.size() != 4 || image_shape[0] != label_shape[0]) { | |||
| if (image_shape.size() != 4 || image_shape[0] != label_shape[0]) { | |||
| RETURN_STATUS_UNEXPECTED("You must batch before calling MixUpBatch"); | |||
| } | |||
| if (label_shape.size() != 2) { | |||
| RETURN_STATUS_UNEXPECTED("MixUpBatch: Label's must be in one-hot format and in a batch"); | |||
| } | |||
| if ((image_shape[1] != 1 && image_shape[1] != 3) && (image_shape[3] != 1 && image_shape[3] != 3)) { | |||
| RETURN_STATUS_UNEXPECTED("MixUpBatch: Images must be in the shape of HWC or CHW"); | |||
| } | |||
| @@ -94,6 +94,7 @@ constexpr char kAutoContrastOp[] = "AutoContrastOp"; | |||
| constexpr char kBoundingBoxAugmentOp[] = "BoundingBoxAugmentOp"; | |||
| constexpr char kDecodeOp[] = "DecodeOp"; | |||
| constexpr char kCenterCropOp[] = "CenterCropOp"; | |||
| constexpr char kCutMixBatchOp[] = "CutMixBatchOp"; | |||
| constexpr char kCutOutOp[] = "CutOutOp"; | |||
| constexpr char kCropOp[] = "CropOp"; | |||
| constexpr char kEqualizeOp[] = "EqualizeOp"; | |||
| @@ -43,13 +43,14 @@ Examples: | |||
| import numbers | |||
| import mindspore._c_dataengine as cde | |||
| from .utils import Inter, Border | |||
| from .utils import Inter, Border, ImageBatchFormat | |||
| from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \ | |||
| check_mix_up_batch_c, check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \ | |||
| check_range, check_resize, check_rescale, check_pad, check_cutout, \ | |||
| check_uniform_augment_cpp, \ | |||
| check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \ | |||
| check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER | |||
| check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER, \ | |||
| check_cut_mix_batch_c | |||
| DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR, | |||
| Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR, | |||
| @@ -60,6 +61,8 @@ DE_C_BORDER_TYPE = {Border.CONSTANT: cde.BorderType.DE_BORDER_CONSTANT, | |||
| Border.REFLECT: cde.BorderType.DE_BORDER_REFLECT, | |||
| Border.SYMMETRIC: cde.BorderType.DE_BORDER_SYMMETRIC} | |||
| DE_C_IMAGE_BATCH_FORMAT = {ImageBatchFormat.NHWC: cde.ImageBatchFormat.DE_IMAGE_BATCH_FORMAT_NHWC, | |||
| ImageBatchFormat.NCHW: cde.ImageBatchFormat.DE_IMAGE_BATCH_FORMAT_NCHW} | |||
| def parse_padding(padding): | |||
| if isinstance(padding, numbers.Number): | |||
| @@ -143,6 +146,33 @@ class Decode(cde.DecodeOp): | |||
| super().__init__(self.rgb) | |||
| class CutMixBatch(cde.CutMixBatchOp): | |||
| """ | |||
| Apply CutMix transformation on input batch of images and labels. | |||
| Note that you need to make labels into one-hot format and batch before calling this function. | |||
| Args: | |||
| image_batch_format (Image Batch Format): The method of padding. Can be any of | |||
| [ImageBatchFormat.NHWC, ImageBatchFormat.NCHW] | |||
| alpha (float): hyperparameter of beta distribution (default = 1.0). | |||
| prob (float): The probability by which CutMix is applied to each image (default = 1.0). | |||
| Examples: | |||
| >>> one_hot_op = data.OneHot(num_classes=10) | |||
| >>> data = data.map(input_columns=["label"], operations=one_hot_op) | |||
| >>> cutmix_batch_op = vision.CutMixBatch(ImageBatchFormat.NHWC, 1.0, 0.5) | |||
| >>> data = data.batch(5) | |||
| >>> data = data.map(input_columns=["image", "label"], operations=cutmix_batch_op) | |||
| """ | |||
| @check_cut_mix_batch_c | |||
| def __init__(self, image_batch_format, alpha=1.0, prob=1.0): | |||
| self.image_batch_format = image_batch_format.value | |||
| self.alpha = alpha | |||
| self.prob = prob | |||
| super().__init__(DE_C_IMAGE_BATCH_FORMAT[image_batch_format], alpha, prob) | |||
| class CutOut(cde.CutOutOp): | |||
| """ | |||
| Randomly cut (mask) out a given number of square patches from the input Numpy image array. | |||
| @@ -30,3 +30,9 @@ class Border(str, Enum): | |||
| EDGE: str = "edge" | |||
| REFLECT: str = "reflect" | |||
| SYMMETRIC: str = "symmetric" | |||
| # Image Batch Format | |||
| class ImageBatchFormat(IntEnum): | |||
| NHWC = 0 | |||
| NCHW = 1 | |||
| @@ -19,7 +19,7 @@ from functools import wraps | |||
| import numpy as np | |||
| from mindspore._c_dataengine import TensorOp | |||
| from .utils import Inter, Border | |||
| from .utils import Inter, Border, ImageBatchFormat | |||
| from ...core.validator_helpers import check_value, check_uint8, FLOAT_MAX_INTEGER, check_pos_float32, \ | |||
| check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list, \ | |||
| check_tensor_op, UINT8_MAX | |||
| @@ -37,6 +37,20 @@ def check_crop_size(size): | |||
| raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.") | |||
| def check_cut_mix_batch_c(method): | |||
| """Wrapper method to check the parameters of CutMixBatch.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| [image_batch_format, alpha, prob], _ = parse_user_args(method, *args, **kwargs) | |||
| type_check(image_batch_format, (ImageBatchFormat,), "image_batch_format") | |||
| check_pos_float32(alpha) | |||
| check_value(prob, [0, 1], "prob") | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| def check_resize_size(size): | |||
| """Wrapper method to check the parameters of resize.""" | |||
| if isinstance(size, int): | |||
| @@ -20,6 +20,7 @@ SET(DE_UT_SRCS | |||
| circular_pool_test.cc | |||
| client_config_test.cc | |||
| connector_test.cc | |||
| cutmix_batch_op_test.cc | |||
| cut_out_op_test.cc | |||
| datatype_test.cc | |||
| decode_op_test.cc | |||
| @@ -25,6 +25,177 @@ class MindDataTestPipeline : public UT::DatasetOpTesting { | |||
| protected: | |||
| }; | |||
| TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess1) { | |||
| // Testing CutMixBatch on a batch of CHW images | |||
| // Create a Cifar10 Dataset | |||
| std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | |||
| int number_of_classes = 10; | |||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10)); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create objects for the tensor ops | |||
| std::shared_ptr<TensorOperation> hwc_to_chw = vision::HWC2CHW(); | |||
| EXPECT_NE(hwc_to_chw, nullptr); | |||
| // Create a Map operation on ds | |||
| ds = ds->Map({hwc_to_chw},{"image"}); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create a Batch operation on ds | |||
| int32_t batch_size = 5; | |||
| ds = ds->Batch(batch_size); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create objects for the tensor ops | |||
| std::shared_ptr<TensorOperation> one_hot_op = vision::OneHot(number_of_classes); | |||
| EXPECT_NE(one_hot_op, nullptr); | |||
| // Create a Map operation on ds | |||
| ds = ds->Map({one_hot_op},{"label"}); | |||
| EXPECT_NE(ds, nullptr); | |||
| std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNCHW, 1.0, 1.0); | |||
| EXPECT_NE(cutmix_batch_op, nullptr); | |||
| // Create a Map operation on ds | |||
| ds = ds->Map({cutmix_batch_op}, {"image", "label"}); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row | |||
| std::unordered_map<std::string, std::shared_ptr<Tensor>> row; | |||
| iter->GetNextRow(&row); | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| i++; | |||
| auto image = row["image"]; | |||
| auto label = row["label"]; | |||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | |||
| MS_LOG(INFO) << "Label shape: " << label->shape(); | |||
| EXPECT_EQ(image->shape().AsVector().size() == 4 && batch_size == image->shape()[0] && 3 == image->shape()[1] | |||
| && 32 == image->shape()[2] && 32 == image->shape()[3], true); | |||
| EXPECT_EQ(label->shape().AsVector().size() == 2 && batch_size == label->shape()[0] && | |||
| number_of_classes == label->shape()[1], true); | |||
| iter->GetNextRow(&row); | |||
| } | |||
| EXPECT_EQ(i, 2); | |||
| // Manually terminate the pipeline | |||
| iter->Stop(); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess2) { | |||
| // Calling CutMixBatch on a batch of HWC images with default values of alpha and prob | |||
| // Create a Cifar10 Dataset | |||
| std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | |||
| int number_of_classes = 10; | |||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10)); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create a Batch operation on ds | |||
| int32_t batch_size = 5; | |||
| ds = ds->Batch(batch_size); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create objects for the tensor ops | |||
| std::shared_ptr<TensorOperation> one_hot_op = vision::OneHot(number_of_classes); | |||
| EXPECT_NE(one_hot_op, nullptr); | |||
| // Create a Map operation on ds | |||
| ds = ds->Map({one_hot_op},{"label"}); | |||
| EXPECT_NE(ds, nullptr); | |||
| std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC); | |||
| EXPECT_NE(cutmix_batch_op, nullptr); | |||
| // Create a Map operation on ds | |||
| ds = ds->Map({cutmix_batch_op}, {"image", "label"}); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row | |||
| std::unordered_map<std::string, std::shared_ptr<Tensor>> row; | |||
| iter->GetNextRow(&row); | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| i++; | |||
| auto image = row["image"]; | |||
| auto label = row["label"]; | |||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | |||
| MS_LOG(INFO) << "Label shape: " << label->shape(); | |||
| EXPECT_EQ(image->shape().AsVector().size() == 4 && batch_size == image->shape()[0] && 32 == image->shape()[1] | |||
| && 32 == image->shape()[2] && 3 == image->shape()[3], true); | |||
| EXPECT_EQ(label->shape().AsVector().size() == 2 && batch_size == label->shape()[0] && | |||
| number_of_classes == label->shape()[1], true); | |||
| iter->GetNextRow(&row); | |||
| } | |||
| EXPECT_EQ(i, 2); | |||
| // Manually terminate the pipeline | |||
| iter->Stop(); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestCutMixBatchFail1) { | |||
| // Must fail because alpha can't be negative | |||
| // Create a Cifar10 Dataset | |||
| std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | |||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10)); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create a Batch operation on ds | |||
| int32_t batch_size = 5; | |||
| ds = ds->Batch(batch_size); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create objects for the tensor ops | |||
| std::shared_ptr<TensorOperation> one_hot_op = vision::OneHot(10); | |||
| EXPECT_NE(one_hot_op, nullptr); | |||
| // Create a Map operation on ds | |||
| ds = ds->Map({one_hot_op},{"label"}); | |||
| EXPECT_NE(ds, nullptr); | |||
| std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, -1, 0.5); | |||
| EXPECT_EQ(cutmix_batch_op, nullptr); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestCutMixBatchFail2) { | |||
| // Must fail because prob can't be negative | |||
| // Create a Cifar10 Dataset | |||
| std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | |||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10)); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create a Batch operation on ds | |||
| int32_t batch_size = 5; | |||
| ds = ds->Batch(batch_size); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create objects for the tensor ops | |||
| std::shared_ptr<TensorOperation> one_hot_op = vision::OneHot(10); | |||
| EXPECT_NE(one_hot_op, nullptr); | |||
| // Create a Map operation on ds | |||
| ds = ds->Map({one_hot_op},{"label"}); | |||
| EXPECT_NE(ds, nullptr); | |||
| std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, 1, -0.5); | |||
| EXPECT_EQ(cutmix_batch_op, nullptr); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestCutOut) { | |||
| // Create an ImageFolder Dataset | |||
| std::string folder_path = datasets_root_path_ + "/testPK/data/"; | |||
| @@ -0,0 +1,115 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/common.h" | |||
| #include "common/cvop_common.h" | |||
| #include "minddata/dataset/kernels/image/cutmix_batch_op.h" | |||
| #include "utils/log_adapter.h" | |||
| using namespace mindspore::dataset; | |||
| using mindspore::LogStream; | |||
| using mindspore::ExceptionType::NoExceptionType; | |||
| using mindspore::MsLogLevel::INFO; | |||
| class MindDataTestCutMixBatchOp : public UT::CVOP::CVOpCommon { | |||
| protected: | |||
| MindDataTestCutMixBatchOp() : CVOpCommon() {} | |||
| }; | |||
| TEST_F(MindDataTestCutMixBatchOp, TestSuccess1) { | |||
| MS_LOG(INFO) << "Doing MindDataTestCutMixBatchOp success1 case"; | |||
| std::shared_ptr<Tensor> batched_tensor; | |||
| std::shared_ptr<Tensor> batched_labels; | |||
| Tensor::CreateEmpty(TensorShape({2, input_tensor_->shape()[0], input_tensor_->shape()[1], input_tensor_->shape()[2]}), | |||
| input_tensor_->type(), &batched_tensor); | |||
| for (int i = 0; i < 2; i++) { | |||
| batched_tensor->InsertTensor({i}, input_tensor_); | |||
| } | |||
| Tensor::CreateFromVector(std::vector<uint32_t>({0, 1, 1, 0}), TensorShape({2, 2}), &batched_labels); | |||
| std::shared_ptr<CutMixBatchOp> op = std::make_shared<CutMixBatchOp>(ImageBatchFormat::kNHWC, 1.0, 1.0); | |||
| TensorRow in; | |||
| in.push_back(batched_tensor); | |||
| in.push_back(batched_labels); | |||
| TensorRow out; | |||
| ASSERT_TRUE(op->Compute(in, &out).IsOk()); | |||
| EXPECT_EQ(in.at(0)->shape()[0], out.at(0)->shape()[0]); | |||
| EXPECT_EQ(in.at(0)->shape()[1], out.at(0)->shape()[1]); | |||
| EXPECT_EQ(in.at(0)->shape()[2], out.at(0)->shape()[2]); | |||
| EXPECT_EQ(in.at(0)->shape()[3], out.at(0)->shape()[3]); | |||
| EXPECT_EQ(in.at(1)->shape()[0], out.at(1)->shape()[0]); | |||
| EXPECT_EQ(in.at(1)->shape()[1], out.at(1)->shape()[1]); | |||
| } | |||
| TEST_F(MindDataTestCutMixBatchOp, TestSuccess2) { | |||
| MS_LOG(INFO) << "Doing MindDataTestCutMixBatchOp success2 case"; | |||
| std::shared_ptr<Tensor> batched_tensor; | |||
| std::shared_ptr<Tensor> batched_labels; | |||
| std::shared_ptr<Tensor> chw_tensor; | |||
| ASSERT_TRUE(HwcToChw(input_tensor_, &chw_tensor).IsOk()); | |||
| Tensor::CreateEmpty(TensorShape({2, chw_tensor->shape()[0], chw_tensor->shape()[1], chw_tensor->shape()[2]}), | |||
| chw_tensor->type(), &batched_tensor); | |||
| for (int i = 0; i < 2; i++) { | |||
| batched_tensor->InsertTensor({i}, chw_tensor); | |||
| } | |||
| Tensor::CreateFromVector(std::vector<uint32_t>({0, 1, 1, 0}), TensorShape({2, 2}), &batched_labels); | |||
| std::shared_ptr<CutMixBatchOp> op = std::make_shared<CutMixBatchOp>(ImageBatchFormat::kNCHW, 1.0, 0.5); | |||
| TensorRow in; | |||
| in.push_back(batched_tensor); | |||
| in.push_back(batched_labels); | |||
| TensorRow out; | |||
| ASSERT_TRUE(op->Compute(in, &out).IsOk()); | |||
| EXPECT_EQ(in.at(0)->shape()[0], out.at(0)->shape()[0]); | |||
| EXPECT_EQ(in.at(0)->shape()[1], out.at(0)->shape()[1]); | |||
| EXPECT_EQ(in.at(0)->shape()[2], out.at(0)->shape()[2]); | |||
| EXPECT_EQ(in.at(0)->shape()[3], out.at(0)->shape()[3]); | |||
| EXPECT_EQ(in.at(1)->shape()[0], out.at(1)->shape()[0]); | |||
| EXPECT_EQ(in.at(1)->shape()[1], out.at(1)->shape()[1]); | |||
| } | |||
| TEST_F(MindDataTestCutMixBatchOp, TestFail1) { | |||
| // This is a fail case because our labels are not batched and are 1-dimensional | |||
| MS_LOG(INFO) << "Doing MindDataTestCutMixBatchOp fail1 case"; | |||
| std::shared_ptr<Tensor> labels; | |||
| Tensor::CreateFromVector(std::vector<uint32_t>({0, 1, 1, 0}), TensorShape({4}), &labels); | |||
| std::shared_ptr<CutMixBatchOp> op = std::make_shared<CutMixBatchOp>(ImageBatchFormat::kNHWC, 1.0, 1.0); | |||
| TensorRow in; | |||
| in.push_back(input_tensor_); | |||
| in.push_back(labels); | |||
| TensorRow out; | |||
| ASSERT_FALSE(op->Compute(in, &out).IsOk()); | |||
| } | |||
| TEST_F(MindDataTestCutMixBatchOp, TestFail2) { | |||
| // This should fail because the image_batch_format provided is not the same as the actual format of the images | |||
| MS_LOG(INFO) << "Doing MindDataTestCutMixBatchOp fail2 case"; | |||
| std::shared_ptr<Tensor> batched_tensor; | |||
| std::shared_ptr<Tensor> batched_labels; | |||
| Tensor::CreateEmpty(TensorShape({2, input_tensor_->shape()[0], input_tensor_->shape()[1], input_tensor_->shape()[2]}), | |||
| input_tensor_->type(), &batched_tensor); | |||
| for (int i = 0; i < 2; i++) { | |||
| batched_tensor->InsertTensor({i}, input_tensor_); | |||
| } | |||
| Tensor::CreateFromVector(std::vector<uint32_t>({0, 1, 1, 0}), TensorShape({2, 2}), &batched_labels); | |||
| std::shared_ptr<CutMixBatchOp> op = std::make_shared<CutMixBatchOp>(ImageBatchFormat::kNCHW, 1.0, 1.0); | |||
| TensorRow in; | |||
| in.push_back(batched_tensor); | |||
| in.push_back(batched_labels); | |||
| TensorRow out; | |||
| ASSERT_FALSE(op->Compute(in, &out).IsOk()); | |||
| } | |||
| @@ -0,0 +1,336 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Testing the CutMixBatch op in DE | |||
| """ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as vision | |||
| import mindspore.dataset.transforms.c_transforms as data_trans | |||
| import mindspore.dataset.transforms.vision.utils as mode | |||
| from mindspore import log as logger | |||
| from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_seed, \ | |||
| config_get_set_num_parallel_workers | |||
| DATA_DIR = "../data/dataset/testCifar10Data" | |||
| GENERATE_GOLDEN = False | |||
| def test_cutmix_batch_success1(plot=False): | |||
| """ | |||
| Test CutMixBatch op with specified alpha and prob parameters on a batch of CHW images | |||
| """ | |||
| logger.info("test_cutmix_batch_success1") | |||
| # Original Images | |||
| ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) | |||
| ds_original = ds_original.batch(5, drop_remainder=True) | |||
| images_original = None | |||
| for idx, (image, _) in enumerate(ds_original): | |||
| if idx == 0: | |||
| images_original = image | |||
| else: | |||
| images_original = np.append(images_original, image, axis=0) | |||
| # CutMix Images | |||
| data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) | |||
| hwc2chw_op = vision.HWC2CHW() | |||
| data1 = data1.map(input_columns=["image"], operations=hwc2chw_op) | |||
| one_hot_op = data_trans.OneHot(num_classes=10) | |||
| data1 = data1.map(input_columns=["label"], operations=one_hot_op) | |||
| cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW, 2.0, 0.5) | |||
| data1 = data1.batch(5, drop_remainder=True) | |||
| data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op) | |||
| images_cutmix = None | |||
| for idx, (image, _) in enumerate(data1): | |||
| if idx == 0: | |||
| images_cutmix = image.transpose(0, 2, 3, 1) | |||
| else: | |||
| images_cutmix = np.append(images_cutmix, image.transpose(0, 2, 3, 1), axis=0) | |||
| if plot: | |||
| visualize_list(images_original, images_cutmix) | |||
| num_samples = images_original.shape[0] | |||
| mse = np.zeros(num_samples) | |||
| for i in range(num_samples): | |||
| mse[i] = diff_mse(images_cutmix[i], images_original[i]) | |||
| logger.info("MSE= {}".format(str(np.mean(mse)))) | |||
| def test_cutmix_batch_success2(plot=False): | |||
| """ | |||
| Test CutMixBatch op with default values for alpha and prob on a batch of HWC images | |||
| """ | |||
| logger.info("test_cutmix_batch_success2") | |||
| # Original Images | |||
| ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) | |||
| ds_original = ds_original.batch(5, drop_remainder=True) | |||
| images_original = None | |||
| for idx, (image, _) in enumerate(ds_original): | |||
| if idx == 0: | |||
| images_original = image | |||
| else: | |||
| images_original = np.append(images_original, image, axis=0) | |||
| # CutMix Images | |||
| data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) | |||
| one_hot_op = data_trans.OneHot(num_classes=10) | |||
| data1 = data1.map(input_columns=["label"], operations=one_hot_op) | |||
| cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC) | |||
| data1 = data1.batch(5, drop_remainder=True) | |||
| data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op) | |||
| images_cutmix = None | |||
| for idx, (image, _) in enumerate(data1): | |||
| if idx == 0: | |||
| images_cutmix = image | |||
| else: | |||
| images_cutmix = np.append(images_cutmix, image, axis=0) | |||
| if plot: | |||
| visualize_list(images_original, images_cutmix) | |||
| num_samples = images_original.shape[0] | |||
| mse = np.zeros(num_samples) | |||
| for i in range(num_samples): | |||
| mse[i] = diff_mse(images_cutmix[i], images_original[i]) | |||
| logger.info("MSE= {}".format(str(np.mean(mse)))) | |||
| def test_cutmix_batch_nhwc_md5(): | |||
| """ | |||
| Test CutMixBatch on a batch of HWC images with MD5: | |||
| """ | |||
| logger.info("test_cutmix_batch_nhwc_md5") | |||
| original_seed = config_get_set_seed(0) | |||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||
| # CutMixBatch Images | |||
| data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) | |||
| one_hot_op = data_trans.OneHot(num_classes=10) | |||
| data = data.map(input_columns=["label"], operations=one_hot_op) | |||
| cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC) | |||
| data = data.batch(5, drop_remainder=True) | |||
| data = data.map(input_columns=["image", "label"], operations=cutmix_batch_op) | |||
| filename = "cutmix_batch_c_nhwc_result.npz" | |||
| save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN) | |||
| # Restore config setting | |||
| ds.config.set_seed(original_seed) | |||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||
| def test_cutmix_batch_nchw_md5(): | |||
| """ | |||
| Test CutMixBatch on a batch of CHW images with MD5: | |||
| """ | |||
| logger.info("test_cutmix_batch_nchw_md5") | |||
| original_seed = config_get_set_seed(0) | |||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||
| # CutMixBatch Images | |||
| data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) | |||
| hwc2chw_op = vision.HWC2CHW() | |||
| data = data.map(input_columns=["image"], operations=hwc2chw_op) | |||
| one_hot_op = data_trans.OneHot(num_classes=10) | |||
| data = data.map(input_columns=["label"], operations=one_hot_op) | |||
| cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW) | |||
| data = data.batch(5, drop_remainder=True) | |||
| data = data.map(input_columns=["image", "label"], operations=cutmix_batch_op) | |||
| filename = "cutmix_batch_c_nchw_result.npz" | |||
| save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN) | |||
| # Restore config setting | |||
| ds.config.set_seed(original_seed) | |||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||
| def test_cutmix_batch_fail1(): | |||
| """ | |||
| Test CutMixBatch Fail 1 | |||
| We expect this to fail because the images and labels are not batched | |||
| """ | |||
| logger.info("test_cutmix_batch_fail1") | |||
| # CutMixBatch Images | |||
| data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) | |||
| one_hot_op = data_trans.OneHot(num_classes=10) | |||
| data1 = data1.map(input_columns=["label"], operations=one_hot_op) | |||
| cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC) | |||
| with pytest.raises(RuntimeError) as error: | |||
| data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op) | |||
| for idx, (image, _) in enumerate(data1): | |||
| if idx == 0: | |||
| images_cutmix = image | |||
| else: | |||
| images_cutmix = np.append(images_cutmix, image, axis=0) | |||
| error_message = "You must batch before calling CutMixBatch" | |||
| assert error_message in str(error.value) | |||
| def test_cutmix_batch_fail2(): | |||
| """ | |||
| Test CutMixBatch Fail 2 | |||
| We expect this to fail because alpha is negative | |||
| """ | |||
| logger.info("test_cutmix_batch_fail2") | |||
| # CutMixBatch Images | |||
| data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) | |||
| one_hot_op = data_trans.OneHot(num_classes=10) | |||
| data1 = data1.map(input_columns=["label"], operations=one_hot_op) | |||
| with pytest.raises(ValueError) as error: | |||
| vision.CutMixBatch(mode.ImageBatchFormat.NHWC, -1) | |||
| error_message = "Input is not within the required interval" | |||
| assert error_message in str(error.value) | |||
| def test_cutmix_batch_fail3(): | |||
| """ | |||
| Test CutMixBatch Fail 2 | |||
| We expect this to fail because prob is larger than 1 | |||
| """ | |||
| logger.info("test_cutmix_batch_fail3") | |||
| # CutMixBatch Images | |||
| data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) | |||
| one_hot_op = data_trans.OneHot(num_classes=10) | |||
| data1 = data1.map(input_columns=["label"], operations=one_hot_op) | |||
| with pytest.raises(ValueError) as error: | |||
| vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, 2) | |||
| error_message = "Input is not within the required interval" | |||
| assert error_message in str(error.value) | |||
| def test_cutmix_batch_fail4(): | |||
| """ | |||
| Test CutMixBatch Fail 2 | |||
| We expect this to fail because prob is negative | |||
| """ | |||
| logger.info("test_cutmix_batch_fail4") | |||
| # CutMixBatch Images | |||
| data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) | |||
| one_hot_op = data_trans.OneHot(num_classes=10) | |||
| data1 = data1.map(input_columns=["label"], operations=one_hot_op) | |||
| with pytest.raises(ValueError) as error: | |||
| vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, -1) | |||
| error_message = "Input is not within the required interval" | |||
| assert error_message in str(error.value) | |||
| def test_cutmix_batch_fail5(): | |||
| """ | |||
| Test CutMixBatch op | |||
| We expect this to fail because label column is not passed to cutmix_batch | |||
| """ | |||
| logger.info("test_cutmix_batch_fail5") | |||
| # CutMixBatch Images | |||
| data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) | |||
| one_hot_op = data_trans.OneHot(num_classes=10) | |||
| data1 = data1.map(input_columns=["label"], operations=one_hot_op) | |||
| cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC) | |||
| data1 = data1.batch(5, drop_remainder=True) | |||
| data1 = data1.map(input_columns=["image"], operations=cutmix_batch_op) | |||
| with pytest.raises(RuntimeError) as error: | |||
| images_cutmix = np.array([]) | |||
| for idx, (image, _) in enumerate(data1): | |||
| if idx == 0: | |||
| images_cutmix = image | |||
| else: | |||
| images_cutmix = np.append(images_cutmix, image, axis=0) | |||
| error_message = "Both images and labels columns are required" | |||
| assert error_message in str(error.value) | |||
| def test_cutmix_batch_fail6(): | |||
| """ | |||
| Test CutMixBatch op | |||
| We expect this to fail because image_batch_format passed to CutMixBatch doesn't match the format of the images | |||
| """ | |||
| logger.info("test_cutmix_batch_fail6") | |||
| # CutMixBatch Images | |||
| data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) | |||
| one_hot_op = data_trans.OneHot(num_classes=10) | |||
| data1 = data1.map(input_columns=["label"], operations=one_hot_op) | |||
| cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW) | |||
| data1 = data1.batch(5, drop_remainder=True) | |||
| data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op) | |||
| with pytest.raises(RuntimeError) as error: | |||
| images_cutmix = np.array([]) | |||
| for idx, (image, _) in enumerate(data1): | |||
| if idx == 0: | |||
| images_cutmix = image | |||
| else: | |||
| images_cutmix = np.append(images_cutmix, image, axis=0) | |||
| error_message = "CutMixBatch: Image doesn't match the given image format." | |||
| assert error_message in str(error.value) | |||
| def test_cutmix_batch_fail7(): | |||
| """ | |||
| Test CutMixBatch op | |||
| We expect this to fail because labels are not in one-hot format | |||
| """ | |||
| logger.info("test_cutmix_batch_fail7") | |||
| # CutMixBatch Images | |||
| data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) | |||
| cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC) | |||
| data1 = data1.batch(5, drop_remainder=True) | |||
| data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op) | |||
| with pytest.raises(RuntimeError) as error: | |||
| images_cutmix = np.array([]) | |||
| for idx, (image, _) in enumerate(data1): | |||
| if idx == 0: | |||
| images_cutmix = image | |||
| else: | |||
| images_cutmix = np.append(images_cutmix, image, axis=0) | |||
| error_message = "CutMixBatch: Label's must be in one-hot format and in a batch" | |||
| assert error_message in str(error.value) | |||
| if __name__ == "__main__": | |||
| test_cutmix_batch_success1(plot=True) | |||
| test_cutmix_batch_success2(plot=True) | |||
| test_cutmix_batch_nchw_md5() | |||
| test_cutmix_batch_nhwc_md5() | |||
| test_cutmix_batch_fail1() | |||
| test_cutmix_batch_fail2() | |||
| test_cutmix_batch_fail3() | |||
| test_cutmix_batch_fail4() | |||
| test_cutmix_batch_fail5() | |||
| test_cutmix_batch_fail6() | |||
| test_cutmix_batch_fail7() | |||