| @@ -24,7 +24,6 @@ | |||
| #endif | |||
| #include "dataset/kernels/image/cut_out_op.h" | |||
| #include "dataset/kernels/image/decode_op.h" | |||
| #include "dataset/kernels/image/distort_bounding_box_crop_op.h" | |||
| #include "dataset/kernels/image/hwc_to_chw_op.h" | |||
| #include "dataset/kernels/image/image_utils.h" | |||
| #include "dataset/kernels/image/normalize_op.h" | |||
| @@ -369,18 +368,6 @@ void bindTensorOps3(py::module *m) { | |||
| } | |||
| void bindTensorOps4(py::module *m) { | |||
| (void)py::class_<DistortBoundingBoxCropOp, TensorOp, std::shared_ptr<DistortBoundingBoxCropOp>>( | |||
| *m, "DistortBoundingBoxCropOp", | |||
| "Tensor operator to crop an image randomly as long as the cropped image has sufficient " | |||
| "overlap with any one bounding box associated with original image" | |||
| "Takes aspect ratio of the generated crop box, the intersection ratio of crop box and bounding box," | |||
| "crop ratio lower and upper bounds" | |||
| "Optional parameters: number of attempts for crop, number of attempts of crop box generation") | |||
| .def(py::init<float, float, float, float, int32_t, int32_t>(), py::arg("aspect_ratio"), py::arg("intersect_ratio"), | |||
| py::arg("crop_ratio_lower_bound"), py::arg("crop_ratio_upper_bound"), | |||
| py::arg("max_attempts") = DistortBoundingBoxCropOp::kDefMaxAttempts, | |||
| py::arg("box_gen_attempts") = DistortBoundingBoxCropOp::kDefBoxGenAttempts); | |||
| (void)py::class_<TypeCastOp, TensorOp, std::shared_ptr<TypeCastOp>>( | |||
| *m, "TypeCastOp", "Tensor operator to type cast data to a specified type.") | |||
| .def(py::init<DataType>(), py::arg("data_type")) | |||
| @@ -3,7 +3,6 @@ if (WIN32) | |||
| center_crop_op.cc | |||
| cut_out_op.cc | |||
| decode_op.cc | |||
| distort_bounding_box_crop_op.cc | |||
| hwc_to_chw_op.cc | |||
| image_utils.cc | |||
| normalize_op.cc | |||
| @@ -27,7 +26,6 @@ else() | |||
| change_mode_op.cc | |||
| cut_out_op.cc | |||
| decode_op.cc | |||
| distort_bounding_box_crop_op.cc | |||
| hwc_to_chw_op.cc | |||
| image_utils.cc | |||
| normalize_op.cc | |||
| @@ -45,4 +43,4 @@ else() | |||
| resize_op.cc | |||
| uniform_aug_op.cc | |||
| ) | |||
| endif() | |||
| endif() | |||
| @@ -33,7 +33,8 @@ const uint8_t CutOutOp::kDefFillB = 0; | |||
| // constructor | |||
| CutOutOp::CutOutOp(int32_t box_height, int32_t box_width, int32_t num_patches, bool random_color, uint8_t fill_r, | |||
| uint8_t fill_g, uint8_t fill_b) | |||
| : box_height_(box_height), | |||
| : rnd_(GetSeed()), | |||
| box_height_(box_height), | |||
| box_width_(box_width), | |||
| num_patches_(num_patches), | |||
| random_color_(random_color), | |||
| @@ -46,8 +47,8 @@ Status CutOutOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<T | |||
| IO_CHECK(input, output); | |||
| std::shared_ptr<CVTensor> inputCV = CVTensor::AsCVTensor(input); | |||
| // cut out will clip the erasing area if the box is near the edge of the image and the boxes are black | |||
| RETURN_IF_NOT_OK( | |||
| Erase(inputCV, output, box_height_, box_width_, num_patches_, false, random_color_, fill_r_, fill_g_, fill_b_)); | |||
| RETURN_IF_NOT_OK(Erase(inputCV, output, box_height_, box_width_, num_patches_, false, random_color_, &rnd_, fill_r_, | |||
| fill_g_, fill_b_)); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| @@ -62,6 +62,7 @@ class CutOutOp : public TensorOp { | |||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | |||
| private: | |||
| std::mt19937 rnd_; | |||
| int32_t box_height_; | |||
| int32_t box_width_; | |||
| int32_t num_patches_; | |||
| @@ -1,117 +0,0 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/kernels/image/distort_bounding_box_crop_op.h" | |||
| #include <random> | |||
| #include "dataset/core/cv_tensor.h" | |||
| #include "dataset/kernels/image/image_utils.h" | |||
| #include "dataset/util/random.h" | |||
| #include "dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| const int32_t DistortBoundingBoxCropOp::kDefMaxAttempts = 100; | |||
| const int32_t DistortBoundingBoxCropOp::kDefBoxGenAttempts = 10; | |||
| // Stores the crop constraints and attempt limits, then seeds the random engine with the value from GetSeed(). | |||
| DistortBoundingBoxCropOp::DistortBoundingBoxCropOp(float aspect_ratio, float intersect_ratio, float crop_ratio_lb, | |||
|                                                    float crop_ratio_ub, int32_t max_attempts, int32_t box_gen_attempts) | |||
|     : max_attempts_(max_attempts), | |||
|       box_gen_attempts_(box_gen_attempts), | |||
|       aspect_ratio_(aspect_ratio), | |||
|       intersect_ratio_(intersect_ratio), | |||
|       crop_ratio_lb_(crop_ratio_lb), | |||
|       crop_ratio_ub_(crop_ratio_ub), | |||
|       seed_(GetSeed()) { | |||
|   // the engine is seeded in the body so that seed_ is guaranteed to be initialized first | |||
|   rnd_.seed(seed_); | |||
| } | |||
| // Crops input[0] to a randomly generated box that overlaps sufficiently with at least one bounding box. | |||
| // The input vector is assumed to be ordered as [img, bbox_y_min, bbox_y_max, bbox_x_min, bbox_x_max], | |||
| // where the bbox coordinates are floats relative to the image height and width. | |||
| // @param input: the image tensor followed by the four bounding box coordinate tensors | |||
| // @param output: receives the cropped image, or the original image when none of the max_attempts_ | |||
| //                generated boxes satisfied the overlap constraint | |||
| // @return Status: an error if the inputs are malformed or a helper fails, otherwise Status::OK() | |||
| Status DistortBoundingBoxCropOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input, | |||
|                                          std::vector<std::shared_ptr<Tensor>> *output) { | |||
|   IO_CHECK_VECTOR(input, output); | |||
|   if (input.size() != NumInput()) { | |||
|     return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Number of inputs is not 5"); | |||
|   } | |||
|   CHECK_FAIL_RETURN_UNEXPECTED(input[1]->shape().Size() >= 1, "The shape of the second tensor is abnormal"); | |||
|   const int64_t num_boxes = input[1]->shape()[0]; | |||
|   for (uint64_t i = 1; i < input.size(); i++) { | |||
|     // every coordinate tensor must have a first dimension before shape()[0] is read, and it must agree | |||
|     // with the box count taken from the first coordinate tensor | |||
|     if (input[i]->shape().Size() < 1 || num_boxes != input[i]->shape()[0]) { | |||
|       return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Numbers of boxes do not match"); | |||
|     } | |||
|     if (input[i]->type() != DataType::DE_FLOAT32) { | |||
|       return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "boxes' type is not DE_FLOAT32"); | |||
|     } | |||
|   } | |||
|   CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Size() >= 2, "The shape of the first tensor is abnormal"); | |||
|   const int h_in = static_cast<int>(input[0]->shape()[0]); | |||
|   const int w_in = static_cast<int>(input[0]->shape()[1]); | |||
|   // convert the relative coordinates into pixel rectangles (x, y, width, height) | |||
|   std::vector<cv::Rect> bounding_boxes; | |||
|   for (int64_t i = 0; i < num_boxes; ++i) { | |||
|     float y_min, y_max, x_min, x_max; | |||
|     RETURN_IF_NOT_OK(input[1]->GetItemAt<float>(&y_min, {i})); | |||
|     RETURN_IF_NOT_OK(input[2]->GetItemAt<float>(&y_max, {i})); | |||
|     RETURN_IF_NOT_OK(input[3]->GetItemAt<float>(&x_min, {i})); | |||
|     RETURN_IF_NOT_OK(input[4]->GetItemAt<float>(&x_max, {i})); | |||
|     bounding_boxes.emplace_back(static_cast<int>(x_min * w_in), static_cast<int>(y_min * h_in), | |||
|                                 static_cast<int>((x_max - x_min) * w_in), static_cast<int>((y_max - y_min) * h_in)); | |||
|   } | |||
|   cv::Rect output_box; | |||
|   bool should_crop = false; | |||
|   // go over iterations, if no satisfying box found we return the original image | |||
|   for (int32_t t = 0; t < max_attempts_ && !should_crop; ++t) { | |||
|     // try to generate random box | |||
|     RETURN_IF_NOT_OK(GenerateRandomCropBox(h_in, w_in, aspect_ratio_, crop_ratio_lb_, crop_ratio_ub_, | |||
|                                            box_gen_attempts_,  // int maxIter, should not be needed here | |||
|                                            &output_box, seed_)); | |||
|     RETURN_IF_NOT_OK(CheckOverlapConstraint(output_box, | |||
|                                             bounding_boxes,  // have to change, should take tensor or add bbox logic | |||
|                                             intersect_ratio_, &should_crop)); | |||
|   } | |||
|   if (!should_crop) { | |||
|     // no acceptable crop box was found, pass the original image through | |||
|     output->push_back(input[0]); | |||
|     return Status::OK(); | |||
|   } | |||
|   std::shared_ptr<Tensor> out; | |||
|   RETURN_IF_NOT_OK(Crop(input[0], &out, output_box.x, output_box.y, output_box.width, output_box.height)); | |||
|   output->push_back(out); | |||
|   return Status::OK(); | |||
| } | |||
| // Reports the output shape: the two leading dimensions are unknown (-1) and, for rank-3 input, the | |||
| // third dimension of the input is carried over. Inputs of any other rank are rejected. | |||
| Status DistortBoundingBoxCropOp::OutputShape(const std::vector<TensorShape> &inputs, | |||
|                                              std::vector<TensorShape> &outputs) { | |||
|   RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); | |||
|   outputs.clear(); | |||
|   TensorShape out = TensorShape{-1, -1}; | |||
|   const auto rank = inputs[0].Rank(); | |||
|   if (rank == 2) { | |||
|     outputs.emplace_back(out); | |||
|   } else if (rank == 3) { | |||
|     outputs.emplace_back(out.AppendDim(inputs[0][2])); | |||
|   } | |||
|   if (outputs.empty()) { | |||
|     return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); | |||
|   } | |||
|   return Status::OK(); | |||
| } | |||
| // Reports the output type: after the base-class handling, the first output takes the type of the first input. | |||
| Status DistortBoundingBoxCropOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) { | |||
|   RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); | |||
|   const DataType &image_type = inputs[0]; | |||
|   outputs[0] = image_type; | |||
|   return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -1,72 +0,0 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_KERNELS_IMAGE_DISTORT_BOUNDING_BOX_CROP_OP_H_ | |||
| #define DATASET_KERNELS_IMAGE_DISTORT_BOUNDING_BOX_CROP_OP_H_ | |||
| #include <memory> | |||
| #include <random> | |||
| #include <vector> | |||
| #include "dataset/core/tensor.h" | |||
| #include "dataset/kernels/tensor_op.h" | |||
| #include "dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class DistortBoundingBoxCropOp : public TensorOp { | |||
| public: | |||
| // Default values, also used by python_bindings.cc | |||
| static const int32_t kDefMaxAttempts; | |||
| static const int32_t kDefBoxGenAttempts; | |||
| // Constructor for DistortBoundingBoxCropOp | |||
| // @param aspect_ratio aspect ratio of the generated crop box | |||
| // @param intersect_ratio minimum area overlap ratio; the crop only happens if the generated crop box | |||
| // overlaps sufficiently with at least one bounding box | |||
| // @param crop_ratio_lb the crop ratio lower bound | |||
| // @param crop_ratio_ub the crop ratio upper bound | |||
| // @param max_attempts number of crop boxes tried before the original image is returned uncropped | |||
| // @param box_gen_attempts crop box generation attempts per try | |||
| // @note the seed is not a constructor parameter; the constructor obtains it from GetSeed() | |||
| DistortBoundingBoxCropOp(float aspect_ratio, float intersect_ratio, float crop_ratio_lb, float crop_ratio_ub, | |||
| int32_t max_attempts = kDefMaxAttempts, int32_t box_gen_attempts = kDefBoxGenAttempts); | |||
| ~DistortBoundingBoxCropOp() override = default; | |||
| void Print(std::ostream &out) const override { | |||
| out << "DistortBoundingBoxCropOp: " << max_attempts_ << " " << intersect_ratio_; | |||
| } | |||
| Status Compute(const std::vector<std::shared_ptr<Tensor>> &input, | |||
| std::vector<std::shared_ptr<Tensor>> *output) override; | |||
| uint32_t NumInput() override { return 5; } | |||
| Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override; | |||
| Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override; | |||
| private: | |||
| int32_t max_attempts_; | |||
| int32_t box_gen_attempts_; | |||
| float aspect_ratio_; | |||
| float intersect_ratio_; | |||
| float crop_ratio_lb_; | |||
| float crop_ratio_ub_; | |||
| std::mt19937 rnd_; | |||
| uint32_t seed_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_KERNELS_IMAGE_DISTORT_BOUNDING_BOX_CROP_OP_H_ | |||
| @@ -636,76 +636,10 @@ Status AdjustHue(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> * | |||
| return Status::OK(); | |||
| } | |||
| // Generates a random crop box that lies inside an image of the given size. | |||
| // @param input_height: image height, must be positive | |||
| // @param input_width: image width, must be positive | |||
| // @param ratio: height / width ratio of the crop box (crop_height = crop_width * ratio), must be positive | |||
| // @param lb: lower bound of the crop area as a fraction of the image area, must be positive | |||
| // @param ub: upper bound of the crop area as a fraction of the image area, must not be below lb | |||
| // @param max_itr: number of attempts to draw a box with the requested ratio that fits in the image; | |||
| //                 afterwards a box with the image's own ratio is drawn and clamped to the image | |||
| // @param crop_box: receives the generated box, must not be null | |||
| // @param seed: currently unused, the generator is seeded from GetSeed() on every call | |||
| // @return Status: an error for invalid arguments or an OpenCV exception, otherwise Status::OK() | |||
| Status GenerateRandomCropBox(int input_height, int input_width, float ratio, float lb, float ub, int max_itr, | |||
|                              cv::Rect *crop_box, uint32_t seed) { | |||
|   try { | |||
|     // the negated comparisons also reject NaN, which would otherwise slip through and reach the casts below | |||
|     if (crop_box == nullptr || input_height <= 0 || input_width <= 0 || !(ratio > 0.0) || !(lb > 0.0) || | |||
|         !(lb <= ub)) { | |||
|       RETURN_STATUS_UNEXPECTED("Invalid inputs GenerateRandomCropBox"); | |||
|     } | |||
|     std::mt19937 rnd; | |||
|     rnd.seed(GetSeed()); | |||
|     std::uniform_real_distribution<float> rd_crop_ratio(lb, ub); | |||
|     float crop_ratio = 0.0f; | |||
|     int crop_width = 0; | |||
|     int crop_height = 0; | |||
|     bool crop_success = false; | |||
|     // widen before multiplying so that large images cannot overflow int | |||
|     const int64_t input_area = static_cast<int64_t>(input_height) * input_width; | |||
|     for (auto i = 0; i < max_itr; i++) { | |||
|       crop_ratio = rd_crop_ratio(rnd); | |||
|       crop_width = static_cast<int32_t>(std::round(std::sqrt(input_area * static_cast<double>(crop_ratio) / ratio))); | |||
|       crop_height = static_cast<int32_t>(std::round(crop_width * ratio)); | |||
|       if (crop_width <= input_width && crop_height <= input_height) { | |||
|         crop_success = true; | |||
|         break; | |||
|       } | |||
|     } | |||
|     if (!crop_success) { | |||
|       // fall back to the image's own aspect ratio and clamp the box to the image bounds | |||
|       ratio = static_cast<float>(input_height) / input_width; | |||
|       crop_ratio = rd_crop_ratio(rnd); | |||
|       crop_width = static_cast<int>(std::lround(std::sqrt(input_area * static_cast<double>(crop_ratio) / ratio))); | |||
|       crop_height = static_cast<int>(std::lround(crop_width * ratio)); | |||
|       crop_height = (crop_height > input_height) ? input_height : crop_height; | |||
|       crop_width = (crop_width > input_width) ? input_width : crop_width; | |||
|     } | |||
|     // pick the top-left corner so that the whole box stays inside the image | |||
|     std::uniform_int_distribution<> rd_x(0, input_width - crop_width); | |||
|     std::uniform_int_distribution<> rd_y(0, input_height - crop_height); | |||
|     *crop_box = cv::Rect(rd_x(rnd), rd_y(rnd), crop_width, crop_height); | |||
|     return Status::OK(); | |||
|   } catch (const cv::Exception &e) { | |||
|     RETURN_STATUS_UNEXPECTED("error in GenerateRandomCropBox."); | |||
|   } | |||
| } | |||
| // Checks whether the crop box covers a large enough fraction of at least one bounding box. | |||
| // @param crop_box: the candidate crop box | |||
| // @param bounding_boxes: the bounding boxes to test against; boxes without any pixel are ignored | |||
| // @param min_intersect_ratio: required value of (intersection area / bounding box area) | |||
| // @param is_satisfied: always written; true if some bounding box meets the ratio, false otherwise | |||
| //                      (including when the crop box contains no pixel). Must not be null. | |||
| // @return Status: an error for a null output pointer or an OpenCV exception, otherwise Status::OK() | |||
| Status CheckOverlapConstraint(const cv::Rect &crop_box, const std::vector<cv::Rect> &bounding_boxes, | |||
|                               float min_intersect_ratio, bool *is_satisfied) { | |||
|   if (is_satisfied == nullptr) { | |||
|     RETURN_STATUS_UNEXPECTED("error in CheckOverlapConstraint."); | |||
|   } | |||
|   // start from "not satisfied" so the result never depends on the value the caller passed in | |||
|   *is_satisfied = false; | |||
|   try { | |||
|     // not satisfied if the crop box contains no pixel | |||
|     if (crop_box.area() < 1.0) { | |||
|       return Status::OK(); | |||
|     } | |||
|     for (const auto &b_box : bounding_boxes) { | |||
|       const float b_box_area = b_box.area(); | |||
|       // a bounding box that contains no pixel cannot satisfy the constraint | |||
|       if (b_box_area < 1.0) { | |||
|         continue; | |||
|       } | |||
|       const float intersect_ratio = (crop_box & b_box).area() / b_box_area; | |||
|       if (intersect_ratio >= min_intersect_ratio) { | |||
|         *is_satisfied = true; | |||
|         break; | |||
|       } | |||
|     } | |||
|     return Status::OK(); | |||
|   } catch (const cv::Exception &e) { | |||
|     RETURN_STATUS_UNEXPECTED("error in CheckOverlapConstraint."); | |||
|   } | |||
| } | |||
| Status Erase(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t box_height, | |||
| int32_t box_width, int32_t num_patches, bool bounded, bool random_color, uint8_t fill_r, uint8_t fill_g, | |||
| uint8_t fill_b) { | |||
| int32_t box_width, int32_t num_patches, bool bounded, bool random_color, std::mt19937 *rnd, uint8_t fill_r, | |||
| uint8_t fill_g, uint8_t fill_b) { | |||
| try { | |||
| std::mt19937 rnd; | |||
| rnd.seed(GetSeed()); | |||
| std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input); | |||
| if (input_cv->mat().data == nullptr || (input_cv->Rank() != 3 && input_cv->shape()[2] != 3)) { | |||
| RETURN_STATUS_UNEXPECTED("bad CV Tensor input for erase"); | |||
| @@ -731,8 +665,8 @@ Status Erase(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outp | |||
| // rows in cv mat refers to the height of the cropped box | |||
| // we determine h_start and w_start using two different distributions as erasing is used by two different | |||
| // image augmentations. The bounds are also different in each case. | |||
| int32_t h_start = (bounded) ? height_distribution_bound(rnd) : (height_distribution_unbound(rnd) - box_height); | |||
| int32_t w_start = (bounded) ? width_distribution_bound(rnd) : (width_distribution_unbound(rnd) - box_width); | |||
| int32_t h_start = (bounded) ? height_distribution_bound(*rnd) : (height_distribution_unbound(*rnd) - box_height); | |||
| int32_t w_start = (bounded) ? width_distribution_bound(*rnd) : (width_distribution_unbound(*rnd) - box_width); | |||
| int32_t max_width = (w_start + box_width > image_w) ? image_w : w_start + box_width; | |||
| int32_t max_height = (h_start + box_height > image_h) ? image_h : h_start + box_height; | |||
| @@ -744,9 +678,9 @@ Status Erase(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outp | |||
| for (int x = h_start; x < max_height; x++) { | |||
| if (random_color) { | |||
| // fill each box with a random value | |||
| input_img.at<cv::Vec3b>(cv::Point(y, x))[0] = static_cast<int32_t>(normal_distribution(rnd)); | |||
| input_img.at<cv::Vec3b>(cv::Point(y, x))[1] = static_cast<int32_t>(normal_distribution(rnd)); | |||
| input_img.at<cv::Vec3b>(cv::Point(y, x))[2] = static_cast<int32_t>(normal_distribution(rnd)); | |||
| input_img.at<cv::Vec3b>(cv::Point(y, x))[0] = static_cast<int32_t>(normal_distribution(*rnd)); | |||
| input_img.at<cv::Vec3b>(cv::Point(y, x))[1] = static_cast<int32_t>(normal_distribution(*rnd)); | |||
| input_img.at<cv::Vec3b>(cv::Point(y, x))[2] = static_cast<int32_t>(normal_distribution(*rnd)); | |||
| } else { | |||
| input_img.at<cv::Vec3b>(cv::Point(y, x))[0] = fill_r; | |||
| input_img.at<cv::Vec3b>(cv::Point(y, x))[1] = fill_g; | |||
| @@ -196,12 +196,6 @@ Status AdjustSaturation(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te | |||
| // @param output: Adjusted image of same shape and type. | |||
| Status AdjustHue(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &hue); | |||
| Status GenerateRandomCropBox(int input_height, int input_width, float ratio, float lb, float ub, int max_itr, | |||
| cv::Rect *crop_box, uint32_t seed = std::mt19937::default_seed); | |||
| Status CheckOverlapConstraint(const cv::Rect &crop_box, const std::vector<cv::Rect> &bounding_boxes, | |||
| float min_intersect_ratio, bool *is_satisfied); | |||
| // Masks out a random section from the image with set dimension | |||
| // @param input: input Tensor | |||
| // @param output: cutOut Tensor | |||
| @@ -214,8 +208,8 @@ Status CheckOverlapConstraint(const cv::Rect &crop_box, const std::vector<cv::Re | |||
| // @param fill_g: green fill value for erase | |||
| // @param fill_b: blue fill value for erase. | |||
| Status Erase(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t box_height, | |||
| int32_t box_width, int32_t num_patches, bool bounded, bool random_color, uint8_t fill_r = 0, | |||
| uint8_t fill_g = 0, uint8_t fill_b = 0); | |||
| int32_t box_width, int32_t num_patches, bool bounded, bool random_color, std::mt19937 *rnd, | |||
| uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0); | |||
| // Pads the input image and puts the padded image in the output | |||
| // @param input: input Tensor | |||