| @@ -134,6 +134,9 @@ std::shared_ptr<TensorOperation> CenterCrop::Parse(const MapTargetDevice &env) { | |||
| return std::make_shared<CenterCropOperation>(data_->size_); | |||
| } | |||
| // RGB2GRAY Transform Operation. | |||
| std::shared_ptr<TensorOperation> RGB2GRAY::Parse() { return std::make_shared<RgbToGrayOperation>(); } | |||
| // Crop Transform Operation. | |||
| struct Crop::Data { | |||
| Data(const std::vector<int32_t> &coordinates, const std::vector<int32_t> &size) | |||
| @@ -91,6 +91,22 @@ class CenterCrop : public TensorTransform { | |||
| std::shared_ptr<Data> data_; | |||
| }; | |||
| /// \brief RGB2GRAY TensorTransform. | |||
| /// \notes Convert RGB image or color image to grayscale image | |||
| class RGB2GRAY : public TensorTransform { | |||
| public: | |||
| /// \brief Constructor. | |||
| RGB2GRAY() = default; | |||
| /// \brief Destructor. | |||
| ~RGB2GRAY() = default; | |||
| protected: | |||
| /// \brief Function to convert TensorTransform object into a TensorOperation object. | |||
| /// \return Shared pointer to TensorOperation object. | |||
| std::shared_ptr<TensorOperation> Parse() override; | |||
| }; | |||
| /// \brief Crop TensorTransform. | |||
| /// \notes Crop an image based on location and crop size | |||
| class Crop : public TensorTransform { | |||
| @@ -44,6 +44,7 @@ add_library(kernels-image OBJECT | |||
| random_sharpness_op.cc | |||
| rescale_op.cc | |||
| resize_op.cc | |||
| rgb_to_gray_op.cc | |||
| rgba_to_bgr_op.cc | |||
| rgba_to_rgb_op.cc | |||
| sharpness_op.cc | |||
| @@ -1096,6 +1096,23 @@ Status RgbaToBgr(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> * | |||
| } | |||
| } | |||
| Status RgbToGray(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||
| try { | |||
| std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(std::move(input)); | |||
| if (input_cv->Rank() != 3 || input_cv->shape()[2] != 3) { | |||
| RETURN_STATUS_UNEXPECTED("RgbToGray: image shape is not <H,W,C> or channel is not 3."); | |||
| } | |||
| TensorShape out_shape = TensorShape({input_cv->shape()[0], input_cv->shape()[1]}); | |||
| std::shared_ptr<CVTensor> output_cv; | |||
| RETURN_IF_NOT_OK(CVTensor::CreateEmpty(out_shape, input_cv->type(), &output_cv)); | |||
| cv::cvtColor(input_cv->mat(), output_cv->mat(), static_cast<int>(cv::COLOR_RGB2GRAY)); | |||
| *output = std::static_pointer_cast<Tensor>(output_cv); | |||
| return Status::OK(); | |||
| } catch (const cv::Exception &e) { | |||
| RETURN_STATUS_UNEXPECTED("RgbToGray: " + std::string(e.what())); | |||
| } | |||
| } | |||
| Status GetJpegImageInfo(const std::shared_ptr<Tensor> &input, int *img_width, int *img_height) { | |||
| struct jpeg_decompress_struct cinfo {}; | |||
| struct JpegErrorManagerCustom jerr {}; | |||
| @@ -293,6 +293,12 @@ Status RgbaToRgb(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> * | |||
| /// \return Status code | |||
| Status RgbaToBgr(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output); | |||
| /// \brief Take in a 3 channel image in RBG to GRAY | |||
| /// \param[in] input The input image | |||
| /// \param[out] output The output image | |||
| /// \return Status code | |||
| Status RgbToGray(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output); | |||
| /// \brief Get jpeg image width and height | |||
| /// \param input: CVTensor containing the not decoded image 1D bytes | |||
| /// \param img_width: the jpeg image width | |||
| @@ -1672,5 +1672,25 @@ bool GetAffineTransform(std::vector<Point> src_point, std::vector<Point> dst_poi | |||
| return true; | |||
| } | |||
| bool ConvertRgbToGray(const LiteMat &src, LDataType data_type, int w, int h, LiteMat &mat) { | |||
| if (data_type == LDataType::UINT8) { | |||
| if (mat.IsEmpty()) { | |||
| mat.Init(w, h, 1, LDataType::UINT8); | |||
| } | |||
| unsigned char *ptr = mat; | |||
| const unsigned char *data_ptr = src; | |||
| for (int y = 0; y < h; y++) { | |||
| for (int x = 0; x < w; x++) { | |||
| *ptr = (data_ptr[2] * B2GRAY + data_ptr[1] * G2GRAY + data_ptr[0] * R2GRAY + GRAYSHIFT_DELTA) >> GRAYSHIFT; | |||
| ptr++; | |||
| data_ptr += 3; | |||
| } | |||
| } | |||
| } else { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -137,6 +137,9 @@ bool ConvRowCol(const LiteMat &src, const LiteMat &kx, const LiteMat &ky, LiteMa | |||
| /// \brief Filter the image by a Sobel kernel | |||
| bool Sobel(const LiteMat &src, LiteMat &dst, int flag_x, int flag_y, int ksize, PaddBorderType pad_type); | |||
| /// \brief Convert RGB image or color image to grayscale image | |||
| bool ConvertRgbToGray(const LiteMat &src, LDataType data_type, int w, int h, LiteMat &mat); | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // IMAGE_PROCESS_H_ | |||
| @@ -421,6 +421,40 @@ Status Resize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out | |||
| return Status::OK(); | |||
| } | |||
| Status RgbToGray(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||
| if (input->Rank() != 3) { | |||
| RETURN_STATUS_UNEXPECTED("RgbToGray: input image is not in shape of <H,W,C>"); | |||
| } | |||
| if (input->type() != DataType::DE_UINT8) { | |||
| RETURN_STATUS_UNEXPECTED("RgbToGray: image datatype is not uint8."); | |||
| } | |||
| try { | |||
| int output_height = input->shape()[0]; | |||
| int output_width = input->shape()[1]; | |||
| LiteMat lite_mat_rgb(input->shape()[1], input->shape()[0], input->shape()[2], | |||
| const_cast<void *>(reinterpret_cast<const void *>(input->GetBuffer())), | |||
| GetLiteCVDataType(input->type())); | |||
| LiteMat lite_mat_convert; | |||
| std::shared_ptr<Tensor> output_tensor; | |||
| TensorShape new_shape = TensorShape({output_height, output_width, 1}); | |||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(new_shape, input->type(), &output_tensor)); | |||
| uint8_t *buffer = reinterpret_cast<uint8_t *>(&(*output_tensor->begin<uint8_t>())); | |||
| lite_mat_convert.Init(output_width, output_height, 1, reinterpret_cast<void *>(buffer), | |||
| GetLiteCVDataType(input->type())); | |||
| bool ret = | |||
| ConvertRgbToGray(lite_mat_rgb, GetLiteCVDataType(input->type()), output_width, output_height, lite_mat_convert); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(ret, "RgbToGray: RGBToGRAY failed."); | |||
| *output = output_tensor; | |||
| } catch (std::runtime_error &e) { | |||
| RETURN_STATUS_UNEXPECTED("RgbToGray: " + std::string(e.what())); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status Pad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const int32_t &pad_top, | |||
| const int32_t &pad_bottom, const int32_t &pad_left, const int32_t &pad_right, const BorderType &border_types, | |||
| uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) { | |||
| @@ -95,6 +95,12 @@ Status Resize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out | |||
| int32_t output_width, double fx = 0.0, double fy = 0.0, | |||
| InterpolationMode mode = InterpolationMode::kLinear); | |||
| /// \brief Take in a 3 channel image in RBG to GRAY | |||
| /// \param[in] input The input image | |||
| /// \param[out] output The output image | |||
| /// \return Status code | |||
| Status RgbToGray(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output); | |||
| /// \brief Pads the input image and puts the padded image in the output | |||
| /// \param[in] input: input Tensor | |||
| /// \param[out] output: padded Tensor | |||
| @@ -0,0 +1,32 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/kernels/image/rgb_to_gray_op.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/kernels/image/image_utils.h" | |||
| #else | |||
| #include "minddata/dataset/kernels/image/lite_image_utils.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| Status RgbToGrayOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||
| IO_CHECK(input, output); | |||
| return RgbToGray(input, output); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RGB_TO_GRAY_OP_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RGB_TO_GRAY_OP_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <string> | |||
| #include "minddata/dataset/core/tensor.h" | |||
| #include "minddata/dataset/kernels/tensor_op.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class RgbToGrayOp : public TensorOp { | |||
| public: | |||
| RgbToGrayOp() = default; | |||
| ~RgbToGrayOp() override = default; | |||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | |||
| std::string Name() const override { return kRgbToGrayOp; } | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RGB_TO_GRAY_OP_H_ | |||
| @@ -72,6 +72,7 @@ | |||
| #include "minddata/dataset/kernels/image/rgba_to_bgr_op.h" | |||
| #include "minddata/dataset/kernels/image/rgba_to_rgb_op.h" | |||
| #endif | |||
| #include "minddata/dataset/kernels/image/rgb_to_gray_op.h" | |||
| #include "minddata/dataset/kernels/image/rotate_op.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/kernels/image/soft_dvpp/soft_dvpp_decode_random_crop_resize_jpeg_op.h" | |||
| @@ -232,6 +233,11 @@ Status CenterCropOperation::to_json(nlohmann::json *out_json) { | |||
| return Status::OK(); | |||
| } | |||
| // RGB2GRAYOperation | |||
| Status RgbToGrayOperation::ValidateParams() { return Status::OK(); } | |||
| std::shared_ptr<TensorOp> RgbToGrayOperation::Build() { return std::make_shared<RgbToGrayOp>(); } | |||
| // CropOperation. | |||
| CropOperation::CropOperation(std::vector<int32_t> coordinates, std::vector<int32_t> size) | |||
| : coordinates_(coordinates), size_(size) {} | |||
| @@ -74,6 +74,7 @@ constexpr char kResizeOperation[] = "Resize"; | |||
| constexpr char kResizeWithBBoxOperation[] = "ResizeWithBBox"; | |||
| constexpr char kRgbaToBgrOperation[] = "RgbaToBgr"; | |||
| constexpr char kRgbaToRgbOperation[] = "RgbaToRgb"; | |||
| constexpr char kRgbToGrayOperation[] = "RgbToGray"; | |||
| constexpr char kRotateOperation[] = "Rotate"; | |||
| constexpr char kSoftDvppDecodeRandomCropResizeJpegOperation[] = "SoftDvppDecodeRandomCropResizeJpeg"; | |||
| constexpr char kSoftDvppDecodeResizeJpegOperation[] = "SoftDvppDecodeResizeJpeg"; | |||
| @@ -163,6 +164,19 @@ class CenterCropOperation : public TensorOperation { | |||
| std::vector<int32_t> size_; | |||
| }; | |||
| class RgbToGrayOperation : public TensorOperation { | |||
| public: | |||
| RgbToGrayOperation() = default; | |||
| ~RgbToGrayOperation() = default; | |||
| std::shared_ptr<TensorOp> Build() override; | |||
| Status ValidateParams() override; | |||
| std::string Name() const override { return kRgbToGrayOperation; } | |||
| }; | |||
| class CropOperation : public TensorOperation { | |||
| public: | |||
| CropOperation(std::vector<int32_t> coordinates, std::vector<int32_t> size); | |||
| @@ -97,6 +97,7 @@ constexpr char kResizeOp[] = "ResizeOp"; | |||
| constexpr char kResizeWithBBoxOp[] = "ResizeWithBBoxOp"; | |||
| constexpr char kRgbaToBgrOp[] = "RgbaToBgrOp"; | |||
| constexpr char kRgbaToRgbOp[] = "RgbaToRgbOp"; | |||
| constexpr char kRgbToGrayOp[] = "RgbToGrayOp"; | |||
| constexpr char kSharpnessOp[] = "SharpnessOp"; | |||
| constexpr char kSolarizeOp[] = "SolarizeOp"; | |||
| constexpr char kSwapRedBlueOp[] = "SwapRedBlueOp"; | |||
| @@ -88,6 +88,22 @@ class CenterCrop : public TensorTransform { | |||
| std::shared_ptr<Data> data_; | |||
| }; | |||
| /// \brief RGB2GRAY TensorTransform. | |||
| /// \notes Convert RGB image or color image to grayscale image | |||
| class RGB2GRAY : public TensorTransform { | |||
| public: | |||
| /// \brief Constructor. | |||
| RGB2GRAY() = default; | |||
| /// \brief Destructor. | |||
| ~RGB2GRAY() = default; | |||
| protected: | |||
| /// \brief Function to convert TensorTransform object into a TensorOperation object. | |||
| /// \return Shared pointer to TensorOperation object. | |||
| std::shared_ptr<TensorOperation> Parse() override; | |||
| }; | |||
| /// \brief Crop TensorTransform. | |||
| /// \notes Crop an image based on location and crop size | |||
| class Crop : public TensorTransform { | |||
| @@ -191,13 +191,14 @@ if(BUILD_MINDDATA STREQUAL "full") | |||
| ${MINDDATA_DIR}/util/cond_var.cc | |||
| ${MINDDATA_DIR}/engine/data_schema.cc | |||
| ${MINDDATA_DIR}/kernels/tensor_op.cc | |||
| ${MINDDATA_DIR}/kernels/image/affine_op.cc | |||
| ${MINDDATA_DIR}/kernels/image/lite_image_utils.cc | |||
| ${MINDDATA_DIR}/kernels/image/center_crop_op.cc | |||
| ${MINDDATA_DIR}/kernels/image/crop_op.cc | |||
| ${MINDDATA_DIR}/kernels/image/decode_op.cc | |||
| ${MINDDATA_DIR}/kernels/image/normalize_op.cc | |||
| ${MINDDATA_DIR}/kernels/image/affine_op.cc | |||
| ${MINDDATA_DIR}/kernels/image/resize_op.cc | |||
| ${MINDDATA_DIR}/kernels/image/rgb_to_gray_op.cc | |||
| ${MINDDATA_DIR}/kernels/image/rotate_op.cc | |||
| ${MINDDATA_DIR}/kernels/image/random_affine_op.cc | |||
| ${MINDDATA_DIR}/kernels/image/math_utils.cc | |||
| @@ -279,6 +280,7 @@ elseif(BUILD_MINDDATA STREQUAL "wrapper") | |||
| ${MINDDATA_DIR}/kernels/image/crop_op.cc | |||
| ${MINDDATA_DIR}/kernels/image/normalize_op.cc | |||
| ${MINDDATA_DIR}/kernels/image/resize_op.cc | |||
| ${MINDDATA_DIR}/kernels/image/rgb_to_gray_op.cc | |||
| ${MINDDATA_DIR}/kernels/image/rotate_op.cc | |||
| ${MINDDATA_DIR}/kernels/data/compose_op.cc | |||
| ${MINDDATA_DIR}/kernels/data/duplicate_op.cc | |||
| @@ -377,6 +379,7 @@ elseif(BUILD_MINDDATA STREQUAL "lite") | |||
| "${MINDDATA_DIR}/kernels/image/random_vertical_flip_with_bbox_op.cc" | |||
| "${MINDDATA_DIR}/kernels/image/random_sharpness_op.cc" | |||
| "${MINDDATA_DIR}/kernels/image/rescale_op.cc" | |||
| "${MINDDATA_DIR}/kernels/image/rgb_to_gray_op.cc" | |||
| "${MINDDATA_DIR}/kernels/image/rgba_to_bgr_op.cc" | |||
| "${MINDDATA_DIR}/kernels/image/rgba_to_rgb_op.cc" | |||
| "${MINDDATA_DIR}/kernels/image/sharpness_op.cc" | |||
| @@ -195,3 +195,39 @@ TEST_F(MindDataTestPipeline, TestResizeWithBBoxSuccess) { | |||
| // Manually terminate the pipeline | |||
| iter->Stop(); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestRGB2GRAYSucess) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRGB2GRAYSucess."; | |||
| // Create an ImageFolder Dataset | |||
| std::string folder_path = datasets_root_path_ + "/testPK/data/"; | |||
| std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<SequentialSampler>(0, 1)); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create objects for the tensor ops | |||
| std::shared_ptr<TensorTransform> convert(new mindspore::dataset::vision::RGB2GRAY()); | |||
| ds = ds->Map({convert}); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row | |||
| std::unordered_map<std::string, mindspore::MSTensor> row; | |||
| iter->GetNextRow(&row); | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| i++; | |||
| auto image = row["image"]; | |||
| MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); | |||
| iter->GetNextRow(&row); | |||
| } | |||
| EXPECT_EQ(i, 1); | |||
| // Manually terminate the pipeline | |||
| iter->Stop(); | |||
| } | |||
| @@ -13,7 +13,7 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/common.h" | |||
| #include "lite_cv/lite_mat.h" | |||
| #include "lite_cv/image_process.h" | |||
| @@ -1714,3 +1714,25 @@ TEST_F(MindDataImageProcess, TestSobelFlag) { | |||
| distance_x = sqrt(distance_x / total_size); | |||
| EXPECT_EQ(distance_x, 0.0f); | |||
| } | |||
| TEST_F(MindDataImageProcess, testConvertRgbToGray) { | |||
| std::string filename = "data/dataset/apple.jpg"; | |||
| cv::Mat image = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR); | |||
| cv::Mat rgb_mat; | |||
| cv::Mat rgb_mat1; | |||
| cv::cvtColor(image, rgb_mat, CV_BGR2GRAY); | |||
| cv::imwrite("./opencv_image.jpg", rgb_mat); | |||
| cv::cvtColor(image, rgb_mat1, CV_BGR2RGB); | |||
| LiteMat lite_mat_rgb; | |||
| lite_mat_rgb.Init(rgb_mat1.cols, rgb_mat1.rows, rgb_mat1.channels(), rgb_mat1.data, LDataType::UINT8); | |||
| LiteMat lite_mat_gray; | |||
| bool ret = ConvertRgbToGray(lite_mat_rgb, LDataType::UINT8, image.cols, image.rows, lite_mat_gray); | |||
| ASSERT_TRUE(ret == true); | |||
| cv::Mat dst_image(lite_mat_gray.height_, lite_mat_gray.width_, CV_8UC1, lite_mat_gray.data_ptr_); | |||
| cv::imwrite("./mindspore_image.jpg", dst_image); | |||
| CompareMat(rgb_mat, lite_mat_gray); | |||
| } | |||