diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/core/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/core/bindings.cc index 64616b8d84..c719e93de7 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/core/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/core/bindings.cc @@ -117,6 +117,7 @@ PYBIND_REGISTER(InterpolationMode, 0, ([](const py::module *m) { .value("DE_INTER_CUBIC", InterpolationMode::kCubic) .value("DE_INTER_AREA", InterpolationMode::kArea) .value("DE_INTER_NEAREST_NEIGHBOUR", InterpolationMode::kNearestNeighbour) + .value("DE_INTER_PILCUBIC", InterpolationMode::kCubicPil) .export_values(); })); diff --git a/mindspore/ccsrc/minddata/dataset/include/constants.h b/mindspore/ccsrc/minddata/dataset/include/constants.h index 9c6752e087..b8a1a51622 100644 --- a/mindspore/ccsrc/minddata/dataset/include/constants.h +++ b/mindspore/ccsrc/minddata/dataset/include/constants.h @@ -48,7 +48,7 @@ enum class ImageBatchFormat { kNHWC = 0, kNCHW = 1 }; enum class ImageFormat { HWC = 0, CHW = 1, HW = 2 }; // Possible interpolation modes -enum class InterpolationMode { kLinear = 0, kNearestNeighbour = 1, kCubic = 2, kArea = 3 }; +enum class InterpolationMode { kLinear = 0, kNearestNeighbour = 1, kCubic = 2, kArea = 3, kCubicPil = 4 }; // Possible JiebaMode modes enum class JiebaMode { kMix = 0, kMp = 1, kHmm = 2 }; diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt index a16066a818..ffe12d4e16 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt @@ -56,6 +56,7 @@ add_library(kernels-image OBJECT random_resize_with_bbox_op.cc random_color_op.cc rotate_op.cc + resize_cubic_op.cc ) if(ENABLE_ACL) add_dependencies(kernels-image kernels-soft-dvpp-image kernels-dvpp-image) diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc index c2249f47fb..ec32b21173 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc @@ -21,11 +21,12 @@ #include #include #include "utils/ms_utils.h" -#include "minddata/dataset/kernels/image/math_utils.h" -#include "minddata/dataset/include/constants.h" #include "minddata/dataset/core/cv_tensor.h" #include "minddata/dataset/core/tensor.h" #include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/include/constants.h" +#include "minddata/dataset/kernels/image/math_utils.h" +#include "minddata/dataset/kernels/image/resize_cubic_op.h" #include "minddata/dataset/util/random.h" #define MAX_INT_PRECISION 16777216 // float int precision is 16777216 @@ -110,6 +111,19 @@ Status Resize(const std::shared_ptr &input, std::shared_ptr *out RETURN_STATUS_UNEXPECTED("Resize: input tensor is not in shape of or "); } + if (mode == InterpolationMode::kCubicPil) { + LiteMat imIn, imOut; + std::shared_ptr output_tensor; + TensorShape new_shape = TensorShape({output_height, output_width, 3}); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(new_shape, input_cv->type(), &output_tensor)); + uint8_t *buffer = reinterpret_cast(&(*output_tensor->begin())); + imOut.Init(output_width, output_height, input_cv->shape()[2], reinterpret_cast(buffer), LDataType::UINT8); + imIn.Init(input_cv->shape()[1], input_cv->shape()[0], input_cv->shape()[2], input_cv->mat().data, LDataType::UINT8); + ResizeCubic(imIn, imOut, output_width, output_height); + *output = output_tensor; + return Status::OK(); + } + cv::Mat in_image = input_cv->mat(); // resize image too large or too small if (output_height > in_image.rows * 1000 || output_width > in_image.cols * 1000) { @@ -569,6 +583,24 @@ Status CropAndResize(const std::shared_ptr &input, std::shared_ptrmat(); + + if (mode == InterpolationMode::kCubicPil) { + cv::Mat input_roi = cv_in(roi); + std::shared_ptr input_image; + RETURN_IF_NOT_OK(CVTensor::CreateFromMat(input_roi, &input_image)); + LiteMat imIn, imOut; + std::shared_ptr output_tensor; + TensorShape new_shape = TensorShape({target_height, target_width, 3}); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(new_shape, input_cv->type(), &output_tensor)); + uint8_t *buffer = reinterpret_cast(&(*output_tensor->begin())); + imOut.Init(target_width, target_height, input_cv->shape()[2], reinterpret_cast(buffer), LDataType::UINT8); + imIn.Init(input_image->shape()[1], input_image->shape()[0], input_image->shape()[2], input_image->mat().data, + LDataType::UINT8); + ResizeCubic(imIn, imOut, target_width, target_height); + *output = output_tensor; + return Status::OK(); + } + TensorShape shape{target_height, target_width}; int num_channels = input_cv->shape()[2]; if (input_cv->Rank() == 3) shape = shape.AppendDim(num_channels); diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/resize_cubic_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_cubic_op.cc new file mode 100644 index 0000000000..4c590aaa8a --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_cubic_op.cc @@ -0,0 +1,272 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/image/resize_cubic_op.h" + +namespace mindspore { +namespace dataset { + +// using 8 bits for result +constexpr uint8_t PrecisionBits = 22; + +// construct lookup table +static const std::vector _clip8_table = []() { + std::vector v1(896, 0); + std::vector v2(384, 255); + for (int i = 0; i < 256; i++) { + v1[i + 640] = i; + } + v1.insert(v1.end(), v2.begin(), v2.end()); + return v1; +}(); + +static const uint8_t *clip8_table = &_clip8_table[640]; + +static inline uint8_t clip8(int input) { return clip8_table[input >> PrecisionBits]; } + +static inline double cubic_interp(double x) { + double a = -0.5; + if (x < 0.0) { + x = -x; + } + if (x < 1.0) { + return ((a + 2.0) * x - (a + 3.0)) * x * x + 1; + } + if (x < 2.0) { + return (((x - 5) * x + 8) * x - 4) * a; + } + return 0.0; +} + +struct interpolation { + double (*interpolation)(double x); + double threshold; +}; + +int calc_coeff(int input_size, int out_size, int input0, int input1, struct interpolation *interp, + std::vector ®ions, std::vector &coeffs_interp) { + double threshold, scale, interp_scale; + int kernel_size; + + scale = static_cast((input1 - input0)) / out_size; + if (scale < 1.0) { + interp_scale = 1.0; + } else { + interp_scale = scale; + } + + // obtain size + threshold = interp->threshold * interp_scale; + + // coefficients number + kernel_size = static_cast(ceil(threshold)) * 2 + 1; + if (out_size > INT_MAX / (kernel_size * static_cast(sizeof(double)))) { + MS_LOG(WARNING) << "Unable to allocator memory as output Image size is so large."; + return 0; + } + + // coefficient array + std::vector coeffs(out_size * kernel_size, 0.0); + std::vector region(out_size * 2, 0); + + for (int xx = 0; xx < out_size; xx++) { + double center = input0 + (xx + 0.5) * scale; + double mm = 0.0, ss = 1.0 / interp_scale; + int x; + // Round for x_min + int x_min = static_cast((center - threshold + 0.5)); + if (x_min < 0) { + x_min = 0; + } + // Round for x_max + int x_max = static_cast((center + threshold + 0.5)); + if (x_max > input_size) { + x_max = input_size; + } + x_max -= x_min; + double *coeff = &coeffs[xx * kernel_size]; + for (x = 0; x < x_max; x++) { + double m = interp->interpolation((x + x_min - center + 0.5) * ss); + coeff[x] = m; + mm += m; + } + for (x = 0; x < x_max; x++) { + if (mm != 0.0) { + coeff[x] /= mm; + } + } + // Remaining values should stay empty if they are used despite of x_max. + for (; x < kernel_size; x++) { + coeff[x] = 0; + } + region[xx * 2 + 0] = x_min; + region[xx * 2 + 1] = x_max; + } + + regions = std::move(region); + coeffs_interp = std::move(coeffs); + return kernel_size; +} + +void normalize_coeff(int out_size, int kernel_size, const std::vector &prekk, std::vector &kk) { + for (int x = 0; x < out_size * kernel_size; x++) { + if (prekk[x] < 0) { + kk[x] = static_cast((-0.5 + prekk[x] * (1 << PrecisionBits))); + } else { + kk[x] = static_cast((0.5 + prekk[x] * (1 << PrecisionBits))); + } + } +} + +Status ImagingHorizontalInterp(LiteMat &output, LiteMat input, int offset, int kernel_size, + const std::vector ®ions, const std::vector &prekk) { + int ss0, ss1, ss2; + int32_t *k; + + // normalize previous calculated coefficients + std::vector kk(prekk.begin(), prekk.end()); + normalize_coeff(output.width_, kernel_size, prekk, kk); + uint8_t *input_ptr = input; + uint8_t *output_ptr = output; + int32_t input_width = input.width_ * 3; + int32_t output_width = output.width_ * 3; + + for (int yy = 0; yy < output.height_; yy++) { + // obtain the ptr of output, and put calculated value into it + uint8_t *bgr_buf = output_ptr; + for (int xx = 0; xx < output.width_; xx++) { + int x_min = regions[xx * 2 + 0]; + int x_max = regions[xx * 2 + 1]; + k = &kk[xx * kernel_size]; + ss0 = ss1 = ss2 = 1 << (PrecisionBits - 1); + for (int x = 0; x < x_max; x++) { + ss0 += (input_ptr[(yy + offset) * input_width + (x + x_min) * 3 + 0]) * k[x]; + ss1 += (input_ptr[(yy + offset) * input_width + (x + x_min) * 3 + 1]) * k[x]; + ss2 += (input_ptr[(yy + offset) * input_width + (x + x_min) * 3 + 2]) * k[x]; + } + bgr_buf[0] = clip8(ss0); + bgr_buf[1] = clip8(ss1); + bgr_buf[2] = clip8(ss2); + bgr_buf += 3; + } + output_ptr += output_width; + } + return Status::OK(); +} + +Status ImagingVerticalInterp(LiteMat &output, LiteMat input, int offset, int kernel_size, + const std::vector ®ions, const std::vector &prekk) { + int ss0, ss1, ss2; + + // normalize previous calculated coefficients + std::vector kk(prekk.begin(), prekk.end()); + normalize_coeff(output.height_, kernel_size, prekk, kk); + uint8_t *input_ptr = input; + uint8_t *output_ptr = output; + const int32_t input_width = input.width_ * 3; + const int32_t output_width = output.width_ * 3; + + for (int yy = 0; yy < output.height_; yy++) { + // obtain the ptr of output, and put calculated value into it + uint8_t *bgr_buf = output_ptr; + int32_t *k = &kk[yy * kernel_size]; + int y_min = regions[yy * 2 + 0]; + int y_max = regions[yy * 2 + 1]; + for (int xx = 0; xx < output.width_; xx++) { + ss0 = ss1 = ss2 = 1 << (PrecisionBits - 1); + for (int y = 0; y < y_max; y++) { + ss0 += (input_ptr[(y + y_min) * input_width + xx * 3 + 0]) * k[y]; + ss1 += (input_ptr[(y + y_min) * input_width + xx * 3 + 1]) * k[y]; + ss2 += (input_ptr[(y + y_min) * input_width + xx * 3 + 2]) * k[y]; + } + bgr_buf[0] = clip8(ss0); + bgr_buf[1] = clip8(ss1); + bgr_buf[2] = clip8(ss2); + bgr_buf += 3; + } + output_ptr += output_width; + } + return Status::OK(); +} + +bool ImageInterpolation(LiteMat input, LiteMat &output, int x_size, int y_size, struct interpolation *interp, + int rect[4]) { + int horizontal_interp, vertical_interp, horiz_kernel, vert_kernel, rect_y0, rect_y1; + std::vector horiz_region, vert_region; + std::vector horiz_coeff, vert_coeff; + LiteMat temp; + + horizontal_interp = x_size != input.width_ || rect[2] != x_size || rect[0]; + vertical_interp = y_size != input.height_ || rect[3] != y_size || rect[1]; + + horiz_kernel = calc_coeff(input.width_, x_size, rect[0], rect[2], interp, horiz_region, horiz_coeff); + if (!horiz_kernel) { + return false; + } + + vert_kernel = calc_coeff(input.height_, y_size, rect[1], rect[3], interp, vert_region, vert_coeff); + if (!vert_kernel) { + return false; + } + + // first and last used row in the input image + rect_y0 = vert_region[0]; + rect_y1 = vert_region[y_size * 2 - 1] + vert_region[y_size * 2 - 2]; + + // two-direction resize, horizontal resize + if (horizontal_interp) { + // Shift region for vertical resize + for (int i = 0; i < y_size; i++) { + vert_region[i * 2] -= rect_y0; + } + temp.Init(x_size, rect_y1 - rect_y0, 3); + + ImagingHorizontalInterp(temp, input, rect_y0, horiz_kernel, horiz_region, horiz_coeff); + if (temp.IsEmpty()) { + return false; + } + output = input = temp; + } + + /* vertical resize */ + if (vertical_interp) { + output.Init(input.width_, y_size, 3); + if (!output.IsEmpty()) { + ImagingVerticalInterp(output, input, 0, vert_kernel, vert_region, vert_coeff); + } + if (output.IsEmpty()) { + return false; + } + } + return true; +} + +bool ResizeCubic(const LiteMat &input, LiteMat &dst, int dst_w, int dst_h) { + if (input.data_type_ != LDataType::UINT8 || input.channel_ != 3) { + MS_LOG(ERROR) << "Unsupported data type, only support input image of uint8 dtype and 3 channel."; + return false; + } + int x_size = dst_w, y_size = dst_h; + int rect[4] = {0, 0, input.width_, input.height_}; + LiteMat output; + + struct interpolation interp = {cubic_interp, 2.0}; + bool res = ImageInterpolation(input, output, x_size, y_size, &interp, rect); + + memcpy_s(dst.data_ptr_, output.size_, output.data_ptr_, output.size_); + return res; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/resize_cubic_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_cubic_op.h new file mode 100644 index 0000000000..86b65074e9 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_cubic_op.h @@ -0,0 +1,61 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RESIZE_CUBIC_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RESIZE_CUBIC_OP_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "lite_cv/lite_mat.h" +#include "minddata/dataset/util/status.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +/// \brief Calculate the coefficient for interpolation firstly +int calc_coeff(int input_size, int out_size, int input0, int input1, struct interpolation *interp, + std::vector ®ions, std::vector &coeffs_interp); + +/// \brief Normalize the coefficient for interpolation +void normalize_coeff(int out_size, int kernel_size, const std::vector &prekk, std::vector &kk); + +/// \brief Apply horizontal interpolation on input image +Status ImagingHorizontalInterp(LiteMat &output, LiteMat input, int offset, int kernel_size, + const std::vector ®ions, const std::vector &prekk); + +/// \brief Apply Vertical interpolation on input image +Status ImagingVerticalInterp(LiteMat &output, LiteMat input, int offset, int kernel_size, + const std::vector ®ions, const std::vector &prekk); + +/// \brief Mainly logic of Cubic interpolation +bool ImageInterpolation(LiteMat input, LiteMat &output, int x_size, int y_size, struct interpolation *interp, + int rect[4]); + +/// \brief Apply cubic interpolation on input image and obtain the output image +/// \param[in] input Input image +/// \param[out] dst Output image +/// \param[in] dst_w expected Output image width +/// \param[in] dst_h expected Output image height +bool ResizeCubic(const LiteMat &input, LiteMat &dst, int dst_w, int dst_h); +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RESIZE_CUBIC_OP_H_ diff --git a/mindspore/dataset/vision/c_transforms.py b/mindspore/dataset/vision/c_transforms.py index 4bcb7a68ef..bd04d873cb 100644 --- a/mindspore/dataset/vision/c_transforms.py +++ b/mindspore/dataset/vision/c_transforms.py @@ -84,7 +84,8 @@ DE_C_IMAGE_BATCH_FORMAT = {ImageBatchFormat.NHWC: cde.ImageBatchFormat.DE_IMAGE_ DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR, Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR, Inter.CUBIC: cde.InterpolationMode.DE_INTER_CUBIC, - Inter.AREA: cde.InterpolationMode.DE_INTER_AREA} + Inter.AREA: cde.InterpolationMode.DE_INTER_AREA, + Inter.PILCUBIC: cde.InterpolationMode.DE_INTER_PILCUBIC} def parse_padding(padding): @@ -930,6 +931,10 @@ class RandomResizedCrop(ImageTensorOperation): - Inter.BICUBIC, means interpolation method is bicubic interpolation. + - Inter.AREA, means interpolation method is pixel area interpolation. + + - Inter.PILCUBIC, means interpolation method is bicubic interpolation like implemented in pillow. + max_attempts (int, optional): The maximum number of attempts to propose a valid crop_area (default=10). If exceeded, fall back to use center_crop instead. @@ -1314,6 +1319,8 @@ class Resize(ImageTensorOperation): - Inter.AREA, means interpolation method is pixel area interpolation. + - Inter.PILCUBIC, means interpolation method is bicubic interpolation like implemented in pillow. + Examples: >>> from mindspore.dataset.vision import Inter >>> decode_op = c_vision.Decode() diff --git a/mindspore/dataset/vision/utils.py b/mindspore/dataset/vision/utils.py index 08d0fbb690..a8fc5e2c35 100644 --- a/mindspore/dataset/vision/utils.py +++ b/mindspore/dataset/vision/utils.py @@ -23,6 +23,7 @@ class Inter(IntEnum): BILINEAR = LINEAR = 2 BICUBIC = CUBIC = 3 AREA = 4 + PILCUBIC = 5 # Padding Mode, Border Type diff --git a/tests/ut/cpp/dataset/image_process_test.cc b/tests/ut/cpp/dataset/image_process_test.cc index 8131205057..7038c5a53c 100644 --- a/tests/ut/cpp/dataset/image_process_test.cc +++ b/tests/ut/cpp/dataset/image_process_test.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,14 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include +#include +#include #include "common/common.h" #include "lite_cv/lite_mat.h" #include "lite_cv/image_process.h" -#include -#include - -#include +#include "minddata/dataset/kernels/image/resize_cubic_op.h" using namespace mindspore::dataset; class MindDataImageProcess : public UT::Common { @@ -184,6 +184,24 @@ TEST_F(MindDataImageProcess, test3C) { CompareMat(cv_image, lite_norm_mat_cut); } +TEST_F(MindDataImageProcess, testCubic3C) { + std::string filename = "data/dataset/apple.jpg"; + cv::Mat image = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR); + cv::Mat rgb_mat; + cv::cvtColor(image, rgb_mat, CV_BGR2RGB); + + LiteMat imIn, imOut; + int32_t output_width = 24; + int32_t output_height = 24; + imIn.Init(rgb_mat.cols, rgb_mat.rows, rgb_mat.channels(), rgb_mat.data, LDataType::UINT8); + imOut.Init(output_width, output_height, 3, LDataType::UINT8); + + bool ret = ResizeCubic(imIn, imOut, output_width, output_height); + + ASSERT_TRUE(ret == true); + return; +} + bool ReadYUV(const char *filename, int w, int h, uint8_t **data) { FILE *f = fopen(filename, "rb"); if (f == nullptr) {