| @@ -117,6 +117,7 @@ PYBIND_REGISTER(InterpolationMode, 0, ([](const py::module *m) { | |||||
| .value("DE_INTER_CUBIC", InterpolationMode::kCubic) | .value("DE_INTER_CUBIC", InterpolationMode::kCubic) | ||||
| .value("DE_INTER_AREA", InterpolationMode::kArea) | .value("DE_INTER_AREA", InterpolationMode::kArea) | ||||
| .value("DE_INTER_NEAREST_NEIGHBOUR", InterpolationMode::kNearestNeighbour) | .value("DE_INTER_NEAREST_NEIGHBOUR", InterpolationMode::kNearestNeighbour) | ||||
| .value("DE_INTER_PILCUBIC", InterpolationMode::kCubicPil) | |||||
| .export_values(); | .export_values(); | ||||
| })); | })); | ||||
| @@ -48,7 +48,7 @@ enum class ImageBatchFormat { kNHWC = 0, kNCHW = 1 }; | |||||
| enum class ImageFormat { HWC = 0, CHW = 1, HW = 2 }; | enum class ImageFormat { HWC = 0, CHW = 1, HW = 2 }; | ||||
| // Possible interpolation modes | // Possible interpolation modes | ||||
| enum class InterpolationMode { kLinear = 0, kNearestNeighbour = 1, kCubic = 2, kArea = 3 }; | |||||
| enum class InterpolationMode { kLinear = 0, kNearestNeighbour = 1, kCubic = 2, kArea = 3, kCubicPil = 4 }; | |||||
| // Possible JiebaMode modes | // Possible JiebaMode modes | ||||
| enum class JiebaMode { kMix = 0, kMp = 1, kHmm = 2 }; | enum class JiebaMode { kMix = 0, kMp = 1, kHmm = 2 }; | ||||
| @@ -56,6 +56,7 @@ add_library(kernels-image OBJECT | |||||
| random_resize_with_bbox_op.cc | random_resize_with_bbox_op.cc | ||||
| random_color_op.cc | random_color_op.cc | ||||
| rotate_op.cc | rotate_op.cc | ||||
| resize_cubic_op.cc | |||||
| ) | ) | ||||
| if(ENABLE_ACL) | if(ENABLE_ACL) | ||||
| add_dependencies(kernels-image kernels-soft-dvpp-image kernels-dvpp-image) | add_dependencies(kernels-image kernels-soft-dvpp-image kernels-dvpp-image) | ||||
| @@ -21,11 +21,12 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <opencv2/imgcodecs.hpp> | #include <opencv2/imgcodecs.hpp> | ||||
| #include "utils/ms_utils.h" | #include "utils/ms_utils.h" | ||||
| #include "minddata/dataset/kernels/image/math_utils.h" | |||||
| #include "minddata/dataset/include/constants.h" | |||||
| #include "minddata/dataset/core/cv_tensor.h" | #include "minddata/dataset/core/cv_tensor.h" | ||||
| #include "minddata/dataset/core/tensor.h" | #include "minddata/dataset/core/tensor.h" | ||||
| #include "minddata/dataset/core/tensor_shape.h" | #include "minddata/dataset/core/tensor_shape.h" | ||||
| #include "minddata/dataset/include/constants.h" | |||||
| #include "minddata/dataset/kernels/image/math_utils.h" | |||||
| #include "minddata/dataset/kernels/image/resize_cubic_op.h" | |||||
| #include "minddata/dataset/util/random.h" | #include "minddata/dataset/util/random.h" | ||||
| #define MAX_INT_PRECISION 16777216 // float int precision is 16777216 | #define MAX_INT_PRECISION 16777216 // float int precision is 16777216 | ||||
| @@ -110,6 +111,19 @@ Status Resize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out | |||||
| RETURN_STATUS_UNEXPECTED("Resize: input tensor is not in shape of <H,W,C> or <H,W>"); | RETURN_STATUS_UNEXPECTED("Resize: input tensor is not in shape of <H,W,C> or <H,W>"); | ||||
| } | } | ||||
| if (mode == InterpolationMode::kCubicPil) { | |||||
| LiteMat imIn, imOut; | |||||
| std::shared_ptr<Tensor> output_tensor; | |||||
| TensorShape new_shape = TensorShape({output_height, output_width, 3}); | |||||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(new_shape, input_cv->type(), &output_tensor)); | |||||
| uint8_t *buffer = reinterpret_cast<uint8_t *>(&(*output_tensor->begin<uint8_t>())); | |||||
| imOut.Init(output_width, output_height, input_cv->shape()[2], reinterpret_cast<void *>(buffer), LDataType::UINT8); | |||||
| imIn.Init(input_cv->shape()[1], input_cv->shape()[0], input_cv->shape()[2], input_cv->mat().data, LDataType::UINT8); | |||||
| ResizeCubic(imIn, imOut, output_width, output_height); | |||||
| *output = output_tensor; | |||||
| return Status::OK(); | |||||
| } | |||||
| cv::Mat in_image = input_cv->mat(); | cv::Mat in_image = input_cv->mat(); | ||||
| // resize image too large or too small | // resize image too large or too small | ||||
| if (output_height > in_image.rows * 1000 || output_width > in_image.cols * 1000) { | if (output_height > in_image.rows * 1000 || output_width > in_image.cols * 1000) { | ||||
| @@ -569,6 +583,24 @@ Status CropAndResize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tenso | |||||
| cv::Rect roi(x, y, crop_width, crop_height); | cv::Rect roi(x, y, crop_width, crop_height); | ||||
| auto cv_mode = GetCVInterpolationMode(mode); | auto cv_mode = GetCVInterpolationMode(mode); | ||||
| cv::Mat cv_in = input_cv->mat(); | cv::Mat cv_in = input_cv->mat(); | ||||
| if (mode == InterpolationMode::kCubicPil) { | |||||
| cv::Mat input_roi = cv_in(roi); | |||||
| std::shared_ptr<CVTensor> input_image; | |||||
| RETURN_IF_NOT_OK(CVTensor::CreateFromMat(input_roi, &input_image)); | |||||
| LiteMat imIn, imOut; | |||||
| std::shared_ptr<Tensor> output_tensor; | |||||
| TensorShape new_shape = TensorShape({target_height, target_width, 3}); | |||||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(new_shape, input_cv->type(), &output_tensor)); | |||||
| uint8_t *buffer = reinterpret_cast<uint8_t *>(&(*output_tensor->begin<uint8_t>())); | |||||
| imOut.Init(target_width, target_height, input_cv->shape()[2], reinterpret_cast<void *>(buffer), LDataType::UINT8); | |||||
| imIn.Init(input_image->shape()[1], input_image->shape()[0], input_image->shape()[2], input_image->mat().data, | |||||
| LDataType::UINT8); | |||||
| ResizeCubic(imIn, imOut, target_width, target_height); | |||||
| *output = output_tensor; | |||||
| return Status::OK(); | |||||
| } | |||||
| TensorShape shape{target_height, target_width}; | TensorShape shape{target_height, target_width}; | ||||
| int num_channels = input_cv->shape()[2]; | int num_channels = input_cv->shape()[2]; | ||||
| if (input_cv->Rank() == 3) shape = shape.AppendDim(num_channels); | if (input_cv->Rank() == 3) shape = shape.AppendDim(num_channels); | ||||
| @@ -0,0 +1,272 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/kernels/image/resize_cubic_op.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| // using 8 bits for result | |||||
| constexpr uint8_t PrecisionBits = 22; | |||||
| // construct lookup table | |||||
| static const std::vector<uint8_t> _clip8_table = []() { | |||||
| std::vector<uint8_t> v1(896, 0); | |||||
| std::vector<uint8_t> v2(384, 255); | |||||
| for (int i = 0; i < 256; i++) { | |||||
| v1[i + 640] = i; | |||||
| } | |||||
| v1.insert(v1.end(), v2.begin(), v2.end()); | |||||
| return v1; | |||||
| }(); | |||||
| static const uint8_t *clip8_table = &_clip8_table[640]; | |||||
| static inline uint8_t clip8(int input) { return clip8_table[input >> PrecisionBits]; } | |||||
| static inline double cubic_interp(double x) { | |||||
| double a = -0.5; | |||||
| if (x < 0.0) { | |||||
| x = -x; | |||||
| } | |||||
| if (x < 1.0) { | |||||
| return ((a + 2.0) * x - (a + 3.0)) * x * x + 1; | |||||
| } | |||||
| if (x < 2.0) { | |||||
| return (((x - 5) * x + 8) * x - 4) * a; | |||||
| } | |||||
| return 0.0; | |||||
| } | |||||
| struct interpolation { | |||||
| double (*interpolation)(double x); | |||||
| double threshold; | |||||
| }; | |||||
| int calc_coeff(int input_size, int out_size, int input0, int input1, struct interpolation *interp, | |||||
| std::vector<int> ®ions, std::vector<double> &coeffs_interp) { | |||||
| double threshold, scale, interp_scale; | |||||
| int kernel_size; | |||||
| scale = static_cast<double>((input1 - input0)) / out_size; | |||||
| if (scale < 1.0) { | |||||
| interp_scale = 1.0; | |||||
| } else { | |||||
| interp_scale = scale; | |||||
| } | |||||
| // obtain size | |||||
| threshold = interp->threshold * interp_scale; | |||||
| // coefficients number | |||||
| kernel_size = static_cast<int>(ceil(threshold)) * 2 + 1; | |||||
| if (out_size > INT_MAX / (kernel_size * static_cast<int>(sizeof(double)))) { | |||||
| MS_LOG(WARNING) << "Unable to allocator memory as output Image size is so large."; | |||||
| return 0; | |||||
| } | |||||
| // coefficient array | |||||
| std::vector<double> coeffs(out_size * kernel_size, 0.0); | |||||
| std::vector<int> region(out_size * 2, 0); | |||||
| for (int xx = 0; xx < out_size; xx++) { | |||||
| double center = input0 + (xx + 0.5) * scale; | |||||
| double mm = 0.0, ss = 1.0 / interp_scale; | |||||
| int x; | |||||
| // Round for x_min | |||||
| int x_min = static_cast<int>((center - threshold + 0.5)); | |||||
| if (x_min < 0) { | |||||
| x_min = 0; | |||||
| } | |||||
| // Round for x_max | |||||
| int x_max = static_cast<int>((center + threshold + 0.5)); | |||||
| if (x_max > input_size) { | |||||
| x_max = input_size; | |||||
| } | |||||
| x_max -= x_min; | |||||
| double *coeff = &coeffs[xx * kernel_size]; | |||||
| for (x = 0; x < x_max; x++) { | |||||
| double m = interp->interpolation((x + x_min - center + 0.5) * ss); | |||||
| coeff[x] = m; | |||||
| mm += m; | |||||
| } | |||||
| for (x = 0; x < x_max; x++) { | |||||
| if (mm != 0.0) { | |||||
| coeff[x] /= mm; | |||||
| } | |||||
| } | |||||
| // Remaining values should stay empty if they are used despite of x_max. | |||||
| for (; x < kernel_size; x++) { | |||||
| coeff[x] = 0; | |||||
| } | |||||
| region[xx * 2 + 0] = x_min; | |||||
| region[xx * 2 + 1] = x_max; | |||||
| } | |||||
| regions = std::move(region); | |||||
| coeffs_interp = std::move(coeffs); | |||||
| return kernel_size; | |||||
| } | |||||
| void normalize_coeff(int out_size, int kernel_size, const std::vector<double> &prekk, std::vector<int> &kk) { | |||||
| for (int x = 0; x < out_size * kernel_size; x++) { | |||||
| if (prekk[x] < 0) { | |||||
| kk[x] = static_cast<int>((-0.5 + prekk[x] * (1 << PrecisionBits))); | |||||
| } else { | |||||
| kk[x] = static_cast<int>((0.5 + prekk[x] * (1 << PrecisionBits))); | |||||
| } | |||||
| } | |||||
| } | |||||
| Status ImagingHorizontalInterp(LiteMat &output, LiteMat input, int offset, int kernel_size, | |||||
| const std::vector<int> ®ions, const std::vector<double> &prekk) { | |||||
| int ss0, ss1, ss2; | |||||
| int32_t *k; | |||||
| // normalize previous calculated coefficients | |||||
| std::vector<int> kk(prekk.begin(), prekk.end()); | |||||
| normalize_coeff(output.width_, kernel_size, prekk, kk); | |||||
| uint8_t *input_ptr = input; | |||||
| uint8_t *output_ptr = output; | |||||
| int32_t input_width = input.width_ * 3; | |||||
| int32_t output_width = output.width_ * 3; | |||||
| for (int yy = 0; yy < output.height_; yy++) { | |||||
| // obtain the ptr of output, and put calculated value into it | |||||
| uint8_t *bgr_buf = output_ptr; | |||||
| for (int xx = 0; xx < output.width_; xx++) { | |||||
| int x_min = regions[xx * 2 + 0]; | |||||
| int x_max = regions[xx * 2 + 1]; | |||||
| k = &kk[xx * kernel_size]; | |||||
| ss0 = ss1 = ss2 = 1 << (PrecisionBits - 1); | |||||
| for (int x = 0; x < x_max; x++) { | |||||
| ss0 += (input_ptr[(yy + offset) * input_width + (x + x_min) * 3 + 0]) * k[x]; | |||||
| ss1 += (input_ptr[(yy + offset) * input_width + (x + x_min) * 3 + 1]) * k[x]; | |||||
| ss2 += (input_ptr[(yy + offset) * input_width + (x + x_min) * 3 + 2]) * k[x]; | |||||
| } | |||||
| bgr_buf[0] = clip8(ss0); | |||||
| bgr_buf[1] = clip8(ss1); | |||||
| bgr_buf[2] = clip8(ss2); | |||||
| bgr_buf += 3; | |||||
| } | |||||
| output_ptr += output_width; | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status ImagingVerticalInterp(LiteMat &output, LiteMat input, int offset, int kernel_size, | |||||
| const std::vector<int> ®ions, const std::vector<double> &prekk) { | |||||
| int ss0, ss1, ss2; | |||||
| // normalize previous calculated coefficients | |||||
| std::vector<int> kk(prekk.begin(), prekk.end()); | |||||
| normalize_coeff(output.height_, kernel_size, prekk, kk); | |||||
| uint8_t *input_ptr = input; | |||||
| uint8_t *output_ptr = output; | |||||
| const int32_t input_width = input.width_ * 3; | |||||
| const int32_t output_width = output.width_ * 3; | |||||
| for (int yy = 0; yy < output.height_; yy++) { | |||||
| // obtain the ptr of output, and put calculated value into it | |||||
| uint8_t *bgr_buf = output_ptr; | |||||
| int32_t *k = &kk[yy * kernel_size]; | |||||
| int y_min = regions[yy * 2 + 0]; | |||||
| int y_max = regions[yy * 2 + 1]; | |||||
| for (int xx = 0; xx < output.width_; xx++) { | |||||
| ss0 = ss1 = ss2 = 1 << (PrecisionBits - 1); | |||||
| for (int y = 0; y < y_max; y++) { | |||||
| ss0 += (input_ptr[(y + y_min) * input_width + xx * 3 + 0]) * k[y]; | |||||
| ss1 += (input_ptr[(y + y_min) * input_width + xx * 3 + 1]) * k[y]; | |||||
| ss2 += (input_ptr[(y + y_min) * input_width + xx * 3 + 2]) * k[y]; | |||||
| } | |||||
| bgr_buf[0] = clip8(ss0); | |||||
| bgr_buf[1] = clip8(ss1); | |||||
| bgr_buf[2] = clip8(ss2); | |||||
| bgr_buf += 3; | |||||
| } | |||||
| output_ptr += output_width; | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| bool ImageInterpolation(LiteMat input, LiteMat &output, int x_size, int y_size, struct interpolation *interp, | |||||
| int rect[4]) { | |||||
| int horizontal_interp, vertical_interp, horiz_kernel, vert_kernel, rect_y0, rect_y1; | |||||
| std::vector<int> horiz_region, vert_region; | |||||
| std::vector<double> horiz_coeff, vert_coeff; | |||||
| LiteMat temp; | |||||
| horizontal_interp = x_size != input.width_ || rect[2] != x_size || rect[0]; | |||||
| vertical_interp = y_size != input.height_ || rect[3] != y_size || rect[1]; | |||||
| horiz_kernel = calc_coeff(input.width_, x_size, rect[0], rect[2], interp, horiz_region, horiz_coeff); | |||||
| if (!horiz_kernel) { | |||||
| return false; | |||||
| } | |||||
| vert_kernel = calc_coeff(input.height_, y_size, rect[1], rect[3], interp, vert_region, vert_coeff); | |||||
| if (!vert_kernel) { | |||||
| return false; | |||||
| } | |||||
| // first and last used row in the input image | |||||
| rect_y0 = vert_region[0]; | |||||
| rect_y1 = vert_region[y_size * 2 - 1] + vert_region[y_size * 2 - 2]; | |||||
| // two-direction resize, horizontal resize | |||||
| if (horizontal_interp) { | |||||
| // Shift region for vertical resize | |||||
| for (int i = 0; i < y_size; i++) { | |||||
| vert_region[i * 2] -= rect_y0; | |||||
| } | |||||
| temp.Init(x_size, rect_y1 - rect_y0, 3); | |||||
| ImagingHorizontalInterp(temp, input, rect_y0, horiz_kernel, horiz_region, horiz_coeff); | |||||
| if (temp.IsEmpty()) { | |||||
| return false; | |||||
| } | |||||
| output = input = temp; | |||||
| } | |||||
| /* vertical resize */ | |||||
| if (vertical_interp) { | |||||
| output.Init(input.width_, y_size, 3); | |||||
| if (!output.IsEmpty()) { | |||||
| ImagingVerticalInterp(output, input, 0, vert_kernel, vert_region, vert_coeff); | |||||
| } | |||||
| if (output.IsEmpty()) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool ResizeCubic(const LiteMat &input, LiteMat &dst, int dst_w, int dst_h) { | |||||
| if (input.data_type_ != LDataType::UINT8 || input.channel_ != 3) { | |||||
| MS_LOG(ERROR) << "Unsupported data type, only support input image of uint8 dtype and 3 channel."; | |||||
| return false; | |||||
| } | |||||
| int x_size = dst_w, y_size = dst_h; | |||||
| int rect[4] = {0, 0, input.width_, input.height_}; | |||||
| LiteMat output; | |||||
| struct interpolation interp = {cubic_interp, 2.0}; | |||||
| bool res = ImageInterpolation(input, output, x_size, y_size, &interp, rect); | |||||
| memcpy_s(dst.data_ptr_, output.size_, output.data_ptr_, output.size_); | |||||
| return res; | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,61 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RESIZE_CUBIC_OP_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RESIZE_CUBIC_OP_H_ | |||||
| #include <float.h> | |||||
| #include <math.h> | |||||
| #include <limits.h> | |||||
| #include <string.h> | |||||
| #include <cmath> | |||||
| #include <vector> | |||||
| #include <utility> | |||||
| #include <random> | |||||
| #include "lite_cv/lite_mat.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| #include "utils/log_adapter.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| /// \brief Calculate the coefficient for interpolation firstly | |||||
| int calc_coeff(int input_size, int out_size, int input0, int input1, struct interpolation *interp, | |||||
| std::vector<int> ®ions, std::vector<double> &coeffs_interp); | |||||
| /// \brief Normalize the coefficient for interpolation | |||||
| void normalize_coeff(int out_size, int kernel_size, const std::vector<double> &prekk, std::vector<int> &kk); | |||||
| /// \brief Apply horizontal interpolation on input image | |||||
| Status ImagingHorizontalInterp(LiteMat &output, LiteMat input, int offset, int kernel_size, | |||||
| const std::vector<int> ®ions, const std::vector<double> &prekk); | |||||
| /// \brief Apply Vertical interpolation on input image | |||||
| Status ImagingVerticalInterp(LiteMat &output, LiteMat input, int offset, int kernel_size, | |||||
| const std::vector<int> ®ions, const std::vector<double> &prekk); | |||||
| /// \brief Mainly logic of Cubic interpolation | |||||
| bool ImageInterpolation(LiteMat input, LiteMat &output, int x_size, int y_size, struct interpolation *interp, | |||||
| int rect[4]); | |||||
| /// \brief Apply cubic interpolation on input image and obtain the output image | |||||
| /// \param[in] input Input image | |||||
| /// \param[out] dst Output image | |||||
| /// \param[in] dst_w expected Output image width | |||||
| /// \param[in] dst_h expected Output image height | |||||
| bool ResizeCubic(const LiteMat &input, LiteMat &dst, int dst_w, int dst_h); | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RESIZE_CUBIC_OP_H_ | |||||
| @@ -84,7 +84,8 @@ DE_C_IMAGE_BATCH_FORMAT = {ImageBatchFormat.NHWC: cde.ImageBatchFormat.DE_IMAGE_ | |||||
| DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR, | DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR, | ||||
| Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR, | Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR, | ||||
| Inter.CUBIC: cde.InterpolationMode.DE_INTER_CUBIC, | Inter.CUBIC: cde.InterpolationMode.DE_INTER_CUBIC, | ||||
| Inter.AREA: cde.InterpolationMode.DE_INTER_AREA} | |||||
| Inter.AREA: cde.InterpolationMode.DE_INTER_AREA, | |||||
| Inter.PILCUBIC: cde.InterpolationMode.DE_INTER_PILCUBIC} | |||||
| def parse_padding(padding): | def parse_padding(padding): | ||||
| @@ -930,6 +931,10 @@ class RandomResizedCrop(ImageTensorOperation): | |||||
| - Inter.BICUBIC, means interpolation method is bicubic interpolation. | - Inter.BICUBIC, means interpolation method is bicubic interpolation. | ||||
| - Inter.AREA, means interpolation method is pixel area interpolation. | |||||
| - Inter.PILCUBIC, means interpolation method is bicubic interpolation like implemented in pillow. | |||||
| max_attempts (int, optional): The maximum number of attempts to propose a valid | max_attempts (int, optional): The maximum number of attempts to propose a valid | ||||
| crop_area (default=10). If exceeded, fall back to use center_crop instead. | crop_area (default=10). If exceeded, fall back to use center_crop instead. | ||||
| @@ -1314,6 +1319,8 @@ class Resize(ImageTensorOperation): | |||||
| - Inter.AREA, means interpolation method is pixel area interpolation. | - Inter.AREA, means interpolation method is pixel area interpolation. | ||||
| - Inter.PILCUBIC, means interpolation method is bicubic interpolation like implemented in pillow. | |||||
| Examples: | Examples: | ||||
| >>> from mindspore.dataset.vision import Inter | >>> from mindspore.dataset.vision import Inter | ||||
| >>> decode_op = c_vision.Decode() | >>> decode_op = c_vision.Decode() | ||||
| @@ -23,6 +23,7 @@ class Inter(IntEnum): | |||||
| BILINEAR = LINEAR = 2 | BILINEAR = LINEAR = 2 | ||||
| BICUBIC = CUBIC = 3 | BICUBIC = CUBIC = 3 | ||||
| AREA = 4 | AREA = 4 | ||||
| PILCUBIC = 5 | |||||
| # Padding Mode, Border Type | # Padding Mode, Border Type | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -13,14 +13,14 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include <opencv2/opencv.hpp> | |||||
| #include <opencv2/imgproc/types_c.h> | |||||
| #include <fstream> | |||||
| #include "common/common.h" | #include "common/common.h" | ||||
| #include "lite_cv/lite_mat.h" | #include "lite_cv/lite_mat.h" | ||||
| #include "lite_cv/image_process.h" | #include "lite_cv/image_process.h" | ||||
| #include <opencv2/opencv.hpp> | |||||
| #include <opencv2/imgproc/types_c.h> | |||||
| #include <fstream> | |||||
| #include "minddata/dataset/kernels/image/resize_cubic_op.h" | |||||
| using namespace mindspore::dataset; | using namespace mindspore::dataset; | ||||
| class MindDataImageProcess : public UT::Common { | class MindDataImageProcess : public UT::Common { | ||||
| @@ -184,6 +184,24 @@ TEST_F(MindDataImageProcess, test3C) { | |||||
| CompareMat(cv_image, lite_norm_mat_cut); | CompareMat(cv_image, lite_norm_mat_cut); | ||||
| } | } | ||||
| TEST_F(MindDataImageProcess, testCubic3C) { | |||||
| std::string filename = "data/dataset/apple.jpg"; | |||||
| cv::Mat image = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR); | |||||
| cv::Mat rgb_mat; | |||||
| cv::cvtColor(image, rgb_mat, CV_BGR2RGB); | |||||
| LiteMat imIn, imOut; | |||||
| int32_t output_width = 24; | |||||
| int32_t output_height = 24; | |||||
| imIn.Init(rgb_mat.cols, rgb_mat.rows, rgb_mat.channels(), rgb_mat.data, LDataType::UINT8); | |||||
| imOut.Init(output_width, output_height, 3, LDataType::UINT8); | |||||
| bool ret = ResizeCubic(imIn, imOut, output_width, output_height); | |||||
| ASSERT_TRUE(ret == true); | |||||
| return; | |||||
| } | |||||
| bool ReadYUV(const char *filename, int w, int h, uint8_t **data) { | bool ReadYUV(const char *filename, int w, int h, uint8_t **data) { | ||||
| FILE *f = fopen(filename, "rb"); | FILE *f = fopen(filename, "rb"); | ||||
| if (f == nullptr) { | if (f == nullptr) { | ||||