Browse Source

Add RGB2GRAY operation

tags/v1.2.0-rc1
shenwei41 4 years ago
parent
commit
7b56d1772e
18 changed files with 280 additions and 2 deletions
  1. +3
    -0
      mindspore/ccsrc/minddata/dataset/api/vision.cc
  2. +16
    -0
      mindspore/ccsrc/minddata/dataset/include/vision_lite.h
  3. +1
    -0
      mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt
  4. +17
    -0
      mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc
  5. +6
    -0
      mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h
  6. +20
    -0
      mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/image_process.cc
  7. +3
    -0
      mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/image_process.h
  8. +34
    -0
      mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.cc
  9. +6
    -0
      mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.h
  10. +32
    -0
      mindspore/ccsrc/minddata/dataset/kernels/image/rgb_to_gray_op.cc
  11. +42
    -0
      mindspore/ccsrc/minddata/dataset/kernels/image/rgb_to_gray_op.h
  12. +6
    -0
      mindspore/ccsrc/minddata/dataset/kernels/ir/vision/vision_ir.cc
  13. +14
    -0
      mindspore/ccsrc/minddata/dataset/kernels/ir/vision/vision_ir.h
  14. +1
    -0
      mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h
  15. +16
    -0
      mindspore/ccsrc/minddata/dataset/liteapi/include/vision_lite.h
  16. +4
    -1
      mindspore/lite/minddata/CMakeLists.txt
  17. +36
    -0
      tests/ut/cpp/dataset/c_api_vision_r_to_z_test.cc
  18. +23
    -1
      tests/ut/cpp/dataset/image_process_test.cc

+ 3
- 0
mindspore/ccsrc/minddata/dataset/api/vision.cc View File

@@ -134,6 +134,9 @@ std::shared_ptr<TensorOperation> CenterCrop::Parse(const MapTargetDevice &env) {
return std::make_shared<CenterCropOperation>(data_->size_); return std::make_shared<CenterCropOperation>(data_->size_);
} }


// RGB2GRAY Transform Operation.
std::shared_ptr<TensorOperation> RGB2GRAY::Parse() { return std::make_shared<RgbToGrayOperation>(); }

// Crop Transform Operation. // Crop Transform Operation.
struct Crop::Data { struct Crop::Data {
Data(const std::vector<int32_t> &coordinates, const std::vector<int32_t> &size) Data(const std::vector<int32_t> &coordinates, const std::vector<int32_t> &size)


+ 16
- 0
mindspore/ccsrc/minddata/dataset/include/vision_lite.h View File

@@ -91,6 +91,22 @@ class CenterCrop : public TensorTransform {
std::shared_ptr<Data> data_; std::shared_ptr<Data> data_;
}; };


/// \brief RGB2GRAY TensorTransform.
/// \notes Convert RGB image or color image to grayscale image
class RGB2GRAY : public TensorTransform {
public:
/// \brief Constructor.
RGB2GRAY() = default;

/// \brief Destructor.
~RGB2GRAY() = default;

protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
};

/// \brief Crop TensorTransform. /// \brief Crop TensorTransform.
/// \notes Crop an image based on location and crop size /// \notes Crop an image based on location and crop size
class Crop : public TensorTransform { class Crop : public TensorTransform {


+ 1
- 0
mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt View File

@@ -44,6 +44,7 @@ add_library(kernels-image OBJECT
random_sharpness_op.cc random_sharpness_op.cc
rescale_op.cc rescale_op.cc
resize_op.cc resize_op.cc
rgb_to_gray_op.cc
rgba_to_bgr_op.cc rgba_to_bgr_op.cc
rgba_to_rgb_op.cc rgba_to_rgb_op.cc
sharpness_op.cc sharpness_op.cc


+ 17
- 0
mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc View File

@@ -1096,6 +1096,23 @@ Status RgbaToBgr(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
} }
} }


Status RgbToGray(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
try {
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(std::move(input));
if (input_cv->Rank() != 3 || input_cv->shape()[2] != 3) {
RETURN_STATUS_UNEXPECTED("RgbToGray: image shape is not <H,W,C> or channel is not 3.");
}
TensorShape out_shape = TensorShape({input_cv->shape()[0], input_cv->shape()[1]});
std::shared_ptr<CVTensor> output_cv;
RETURN_IF_NOT_OK(CVTensor::CreateEmpty(out_shape, input_cv->type(), &output_cv));
cv::cvtColor(input_cv->mat(), output_cv->mat(), static_cast<int>(cv::COLOR_RGB2GRAY));
*output = std::static_pointer_cast<Tensor>(output_cv);
return Status::OK();
} catch (const cv::Exception &e) {
RETURN_STATUS_UNEXPECTED("RgbToGray: " + std::string(e.what()));
}
}

Status GetJpegImageInfo(const std::shared_ptr<Tensor> &input, int *img_width, int *img_height) { Status GetJpegImageInfo(const std::shared_ptr<Tensor> &input, int *img_width, int *img_height) {
struct jpeg_decompress_struct cinfo {}; struct jpeg_decompress_struct cinfo {};
struct JpegErrorManagerCustom jerr {}; struct JpegErrorManagerCustom jerr {};


+ 6
- 0
mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h View File

@@ -293,6 +293,12 @@ Status RgbaToRgb(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
/// \return Status code /// \return Status code
Status RgbaToBgr(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output); Status RgbaToBgr(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);


/// \brief Take in a 3 channel image in RBG to GRAY
/// \param[in] input The input image
/// \param[out] output The output image
/// \return Status code
Status RgbToGray(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);

/// \brief Get jpeg image width and height /// \brief Get jpeg image width and height
/// \param input: CVTensor containing the not decoded image 1D bytes /// \param input: CVTensor containing the not decoded image 1D bytes
/// \param img_width: the jpeg image width /// \param img_width: the jpeg image width


+ 20
- 0
mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/image_process.cc View File

@@ -1672,5 +1672,25 @@ bool GetAffineTransform(std::vector<Point> src_point, std::vector<Point> dst_poi
return true; return true;
} }


bool ConvertRgbToGray(const LiteMat &src, LDataType data_type, int w, int h, LiteMat &mat) {
if (data_type == LDataType::UINT8) {
if (mat.IsEmpty()) {
mat.Init(w, h, 1, LDataType::UINT8);
}
unsigned char *ptr = mat;
const unsigned char *data_ptr = src;
for (int y = 0; y < h; y++) {
for (int x = 0; x < w; x++) {
*ptr = (data_ptr[2] * B2GRAY + data_ptr[1] * G2GRAY + data_ptr[0] * R2GRAY + GRAYSHIFT_DELTA) >> GRAYSHIFT;
ptr++;
data_ptr += 3;
}
}
} else {
return false;
}
return true;
}

} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 3
- 0
mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/image_process.h View File

@@ -137,6 +137,9 @@ bool ConvRowCol(const LiteMat &src, const LiteMat &kx, const LiteMat &ky, LiteMa
/// \brief Filter the image by a Sobel kernel /// \brief Filter the image by a Sobel kernel
bool Sobel(const LiteMat &src, LiteMat &dst, int flag_x, int flag_y, int ksize, PaddBorderType pad_type); bool Sobel(const LiteMat &src, LiteMat &dst, int flag_x, int flag_y, int ksize, PaddBorderType pad_type);


/// \brief Convert RGB image or color image to grayscale image
bool ConvertRgbToGray(const LiteMat &src, LDataType data_type, int w, int h, LiteMat &mat);

} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
#endif // IMAGE_PROCESS_H_ #endif // IMAGE_PROCESS_H_

+ 34
- 0
mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.cc View File

@@ -421,6 +421,40 @@ Status Resize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
return Status::OK(); return Status::OK();
} }


Status RgbToGray(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
if (input->Rank() != 3) {
RETURN_STATUS_UNEXPECTED("RgbToGray: input image is not in shape of <H,W,C>");
}
if (input->type() != DataType::DE_UINT8) {
RETURN_STATUS_UNEXPECTED("RgbToGray: image datatype is not uint8.");
}

try {
int output_height = input->shape()[0];
int output_width = input->shape()[1];

LiteMat lite_mat_rgb(input->shape()[1], input->shape()[0], input->shape()[2],
const_cast<void *>(reinterpret_cast<const void *>(input->GetBuffer())),
GetLiteCVDataType(input->type()));
LiteMat lite_mat_convert;
std::shared_ptr<Tensor> output_tensor;
TensorShape new_shape = TensorShape({output_height, output_width, 1});
RETURN_IF_NOT_OK(Tensor::CreateEmpty(new_shape, input->type(), &output_tensor));
uint8_t *buffer = reinterpret_cast<uint8_t *>(&(*output_tensor->begin<uint8_t>()));
lite_mat_convert.Init(output_width, output_height, 1, reinterpret_cast<void *>(buffer),
GetLiteCVDataType(input->type()));

bool ret =
ConvertRgbToGray(lite_mat_rgb, GetLiteCVDataType(input->type()), output_width, output_height, lite_mat_convert);
CHECK_FAIL_RETURN_UNEXPECTED(ret, "RgbToGray: RGBToGRAY failed.");

*output = output_tensor;
} catch (std::runtime_error &e) {
RETURN_STATUS_UNEXPECTED("RgbToGray: " + std::string(e.what()));
}
return Status::OK();
}

Status Pad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const int32_t &pad_top, Status Pad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const int32_t &pad_top,
const int32_t &pad_bottom, const int32_t &pad_left, const int32_t &pad_right, const BorderType &border_types, const int32_t &pad_bottom, const int32_t &pad_left, const int32_t &pad_right, const BorderType &border_types,
uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) { uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) {


+ 6
- 0
mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.h View File

@@ -95,6 +95,12 @@ Status Resize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
int32_t output_width, double fx = 0.0, double fy = 0.0, int32_t output_width, double fx = 0.0, double fy = 0.0,
InterpolationMode mode = InterpolationMode::kLinear); InterpolationMode mode = InterpolationMode::kLinear);


/// \brief Take in a 3 channel image in RBG to GRAY
/// \param[in] input The input image
/// \param[out] output The output image
/// \return Status code
Status RgbToGray(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);

/// \brief Pads the input image and puts the padded image in the output /// \brief Pads the input image and puts the padded image in the output
/// \param[in] input: input Tensor /// \param[in] input: input Tensor
/// \param[out] output: padded Tensor /// \param[out] output: padded Tensor


+ 32
- 0
mindspore/ccsrc/minddata/dataset/kernels/image/rgb_to_gray_op.cc View File

@@ -0,0 +1,32 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/rgb_to_gray_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/image_utils.h"
#else
#include "minddata/dataset/kernels/image/lite_image_utils.h"
#endif

namespace mindspore {
namespace dataset {

Status RgbToGrayOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
return RgbToGray(input, output);
}

} // namespace dataset
} // namespace mindspore

+ 42
- 0
mindspore/ccsrc/minddata/dataset/kernels/image/rgb_to_gray_op.h View File

@@ -0,0 +1,42 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RGB_TO_GRAY_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RGB_TO_GRAY_OP_H_

#include <memory>
#include <vector>
#include <string>

#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {
class RgbToGrayOp : public TensorOp {
public:
RgbToGrayOp() = default;

~RgbToGrayOp() override = default;

Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;

std::string Name() const override { return kRgbToGrayOp; }
};
} // namespace dataset
} // namespace mindspore

#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RGB_TO_GRAY_OP_H_

+ 6
- 0
mindspore/ccsrc/minddata/dataset/kernels/ir/vision/vision_ir.cc View File

@@ -72,6 +72,7 @@
#include "minddata/dataset/kernels/image/rgba_to_bgr_op.h" #include "minddata/dataset/kernels/image/rgba_to_bgr_op.h"
#include "minddata/dataset/kernels/image/rgba_to_rgb_op.h" #include "minddata/dataset/kernels/image/rgba_to_rgb_op.h"
#endif #endif
#include "minddata/dataset/kernels/image/rgb_to_gray_op.h"
#include "minddata/dataset/kernels/image/rotate_op.h" #include "minddata/dataset/kernels/image/rotate_op.h"
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/soft_dvpp/soft_dvpp_decode_random_crop_resize_jpeg_op.h" #include "minddata/dataset/kernels/image/soft_dvpp/soft_dvpp_decode_random_crop_resize_jpeg_op.h"
@@ -232,6 +233,11 @@ Status CenterCropOperation::to_json(nlohmann::json *out_json) {
return Status::OK(); return Status::OK();
} }


// RGB2GRAYOperation
Status RgbToGrayOperation::ValidateParams() { return Status::OK(); }

std::shared_ptr<TensorOp> RgbToGrayOperation::Build() { return std::make_shared<RgbToGrayOp>(); }

// CropOperation. // CropOperation.
CropOperation::CropOperation(std::vector<int32_t> coordinates, std::vector<int32_t> size) CropOperation::CropOperation(std::vector<int32_t> coordinates, std::vector<int32_t> size)
: coordinates_(coordinates), size_(size) {} : coordinates_(coordinates), size_(size) {}


+ 14
- 0
mindspore/ccsrc/minddata/dataset/kernels/ir/vision/vision_ir.h View File

@@ -74,6 +74,7 @@ constexpr char kResizeOperation[] = "Resize";
constexpr char kResizeWithBBoxOperation[] = "ResizeWithBBox"; constexpr char kResizeWithBBoxOperation[] = "ResizeWithBBox";
constexpr char kRgbaToBgrOperation[] = "RgbaToBgr"; constexpr char kRgbaToBgrOperation[] = "RgbaToBgr";
constexpr char kRgbaToRgbOperation[] = "RgbaToRgb"; constexpr char kRgbaToRgbOperation[] = "RgbaToRgb";
constexpr char kRgbToGrayOperation[] = "RgbToGray";
constexpr char kRotateOperation[] = "Rotate"; constexpr char kRotateOperation[] = "Rotate";
constexpr char kSoftDvppDecodeRandomCropResizeJpegOperation[] = "SoftDvppDecodeRandomCropResizeJpeg"; constexpr char kSoftDvppDecodeRandomCropResizeJpegOperation[] = "SoftDvppDecodeRandomCropResizeJpeg";
constexpr char kSoftDvppDecodeResizeJpegOperation[] = "SoftDvppDecodeResizeJpeg"; constexpr char kSoftDvppDecodeResizeJpegOperation[] = "SoftDvppDecodeResizeJpeg";
@@ -163,6 +164,19 @@ class CenterCropOperation : public TensorOperation {
std::vector<int32_t> size_; std::vector<int32_t> size_;
}; };


class RgbToGrayOperation : public TensorOperation {
public:
RgbToGrayOperation() = default;

~RgbToGrayOperation() = default;

std::shared_ptr<TensorOp> Build() override;

Status ValidateParams() override;

std::string Name() const override { return kRgbToGrayOperation; }
};

class CropOperation : public TensorOperation { class CropOperation : public TensorOperation {
public: public:
CropOperation(std::vector<int32_t> coordinates, std::vector<int32_t> size); CropOperation(std::vector<int32_t> coordinates, std::vector<int32_t> size);


+ 1
- 0
mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h View File

@@ -97,6 +97,7 @@ constexpr char kResizeOp[] = "ResizeOp";
constexpr char kResizeWithBBoxOp[] = "ResizeWithBBoxOp"; constexpr char kResizeWithBBoxOp[] = "ResizeWithBBoxOp";
constexpr char kRgbaToBgrOp[] = "RgbaToBgrOp"; constexpr char kRgbaToBgrOp[] = "RgbaToBgrOp";
constexpr char kRgbaToRgbOp[] = "RgbaToRgbOp"; constexpr char kRgbaToRgbOp[] = "RgbaToRgbOp";
constexpr char kRgbToGrayOp[] = "RgbToGrayOp";
constexpr char kSharpnessOp[] = "SharpnessOp"; constexpr char kSharpnessOp[] = "SharpnessOp";
constexpr char kSolarizeOp[] = "SolarizeOp"; constexpr char kSolarizeOp[] = "SolarizeOp";
constexpr char kSwapRedBlueOp[] = "SwapRedBlueOp"; constexpr char kSwapRedBlueOp[] = "SwapRedBlueOp";


+ 16
- 0
mindspore/ccsrc/minddata/dataset/liteapi/include/vision_lite.h View File

@@ -88,6 +88,22 @@ class CenterCrop : public TensorTransform {
std::shared_ptr<Data> data_; std::shared_ptr<Data> data_;
}; };


/// \brief RGB2GRAY TensorTransform.
/// \notes Convert RGB image or color image to grayscale image
class RGB2GRAY : public TensorTransform {
public:
/// \brief Constructor.
RGB2GRAY() = default;

/// \brief Destructor.
~RGB2GRAY() = default;

protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
};

/// \brief Crop TensorTransform. /// \brief Crop TensorTransform.
/// \notes Crop an image based on location and crop size /// \notes Crop an image based on location and crop size
class Crop : public TensorTransform { class Crop : public TensorTransform {


+ 4
- 1
mindspore/lite/minddata/CMakeLists.txt View File

@@ -191,13 +191,14 @@ if(BUILD_MINDDATA STREQUAL "full")
${MINDDATA_DIR}/util/cond_var.cc ${MINDDATA_DIR}/util/cond_var.cc
${MINDDATA_DIR}/engine/data_schema.cc ${MINDDATA_DIR}/engine/data_schema.cc
${MINDDATA_DIR}/kernels/tensor_op.cc ${MINDDATA_DIR}/kernels/tensor_op.cc
${MINDDATA_DIR}/kernels/image/affine_op.cc
${MINDDATA_DIR}/kernels/image/lite_image_utils.cc ${MINDDATA_DIR}/kernels/image/lite_image_utils.cc
${MINDDATA_DIR}/kernels/image/center_crop_op.cc ${MINDDATA_DIR}/kernels/image/center_crop_op.cc
${MINDDATA_DIR}/kernels/image/crop_op.cc ${MINDDATA_DIR}/kernels/image/crop_op.cc
${MINDDATA_DIR}/kernels/image/decode_op.cc ${MINDDATA_DIR}/kernels/image/decode_op.cc
${MINDDATA_DIR}/kernels/image/normalize_op.cc ${MINDDATA_DIR}/kernels/image/normalize_op.cc
${MINDDATA_DIR}/kernels/image/affine_op.cc
${MINDDATA_DIR}/kernels/image/resize_op.cc ${MINDDATA_DIR}/kernels/image/resize_op.cc
${MINDDATA_DIR}/kernels/image/rgb_to_gray_op.cc
${MINDDATA_DIR}/kernels/image/rotate_op.cc ${MINDDATA_DIR}/kernels/image/rotate_op.cc
${MINDDATA_DIR}/kernels/image/random_affine_op.cc ${MINDDATA_DIR}/kernels/image/random_affine_op.cc
${MINDDATA_DIR}/kernels/image/math_utils.cc ${MINDDATA_DIR}/kernels/image/math_utils.cc
@@ -279,6 +280,7 @@ elseif(BUILD_MINDDATA STREQUAL "wrapper")
${MINDDATA_DIR}/kernels/image/crop_op.cc ${MINDDATA_DIR}/kernels/image/crop_op.cc
${MINDDATA_DIR}/kernels/image/normalize_op.cc ${MINDDATA_DIR}/kernels/image/normalize_op.cc
${MINDDATA_DIR}/kernels/image/resize_op.cc ${MINDDATA_DIR}/kernels/image/resize_op.cc
${MINDDATA_DIR}/kernels/image/rgb_to_gray_op.cc
${MINDDATA_DIR}/kernels/image/rotate_op.cc ${MINDDATA_DIR}/kernels/image/rotate_op.cc
${MINDDATA_DIR}/kernels/data/compose_op.cc ${MINDDATA_DIR}/kernels/data/compose_op.cc
${MINDDATA_DIR}/kernels/data/duplicate_op.cc ${MINDDATA_DIR}/kernels/data/duplicate_op.cc
@@ -377,6 +379,7 @@ elseif(BUILD_MINDDATA STREQUAL "lite")
"${MINDDATA_DIR}/kernels/image/random_vertical_flip_with_bbox_op.cc" "${MINDDATA_DIR}/kernels/image/random_vertical_flip_with_bbox_op.cc"
"${MINDDATA_DIR}/kernels/image/random_sharpness_op.cc" "${MINDDATA_DIR}/kernels/image/random_sharpness_op.cc"
"${MINDDATA_DIR}/kernels/image/rescale_op.cc" "${MINDDATA_DIR}/kernels/image/rescale_op.cc"
"${MINDDATA_DIR}/kernels/image/rgb_to_gray_op.cc"
"${MINDDATA_DIR}/kernels/image/rgba_to_bgr_op.cc" "${MINDDATA_DIR}/kernels/image/rgba_to_bgr_op.cc"
"${MINDDATA_DIR}/kernels/image/rgba_to_rgb_op.cc" "${MINDDATA_DIR}/kernels/image/rgba_to_rgb_op.cc"
"${MINDDATA_DIR}/kernels/image/sharpness_op.cc" "${MINDDATA_DIR}/kernels/image/sharpness_op.cc"


+ 36
- 0
tests/ut/cpp/dataset/c_api_vision_r_to_z_test.cc View File

@@ -195,3 +195,39 @@ TEST_F(MindDataTestPipeline, TestResizeWithBBoxSuccess) {
// Manually terminate the pipeline // Manually terminate the pipeline
iter->Stop(); iter->Stop();
} }

TEST_F(MindDataTestPipeline, TestRGB2GRAYSucess) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRGB2GRAYSucess.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<SequentialSampler>(0, 1));
EXPECT_NE(ds, nullptr);

// Create objects for the tensor ops
std::shared_ptr<TensorTransform> convert(new mindspore::dataset::vision::RGB2GRAY());

ds = ds->Map({convert});
EXPECT_NE(ds, nullptr);

// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);

// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
iter->GetNextRow(&row);

uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
iter->GetNextRow(&row);
}

EXPECT_EQ(i, 1);

// Manually terminate the pipeline
iter->Stop();
}

+ 23
- 1
tests/ut/cpp/dataset/image_process_test.cc View File

@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "common/common.h" #include "common/common.h"
#include "lite_cv/lite_mat.h" #include "lite_cv/lite_mat.h"
#include "lite_cv/image_process.h" #include "lite_cv/image_process.h"
@@ -1714,3 +1714,25 @@ TEST_F(MindDataImageProcess, TestSobelFlag) {
distance_x = sqrt(distance_x / total_size); distance_x = sqrt(distance_x / total_size);
EXPECT_EQ(distance_x, 0.0f); EXPECT_EQ(distance_x, 0.0f);
} }

TEST_F(MindDataImageProcess, testConvertRgbToGray) {
std::string filename = "data/dataset/apple.jpg";
cv::Mat image = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR);
cv::Mat rgb_mat;
cv::Mat rgb_mat1;

cv::cvtColor(image, rgb_mat, CV_BGR2GRAY);
cv::imwrite("./opencv_image.jpg", rgb_mat);

cv::cvtColor(image, rgb_mat1, CV_BGR2RGB);

LiteMat lite_mat_rgb;
lite_mat_rgb.Init(rgb_mat1.cols, rgb_mat1.rows, rgb_mat1.channels(), rgb_mat1.data, LDataType::UINT8);
LiteMat lite_mat_gray;
bool ret = ConvertRgbToGray(lite_mat_rgb, LDataType::UINT8, image.cols, image.rows, lite_mat_gray);
ASSERT_TRUE(ret == true);

cv::Mat dst_image(lite_mat_gray.height_, lite_mat_gray.width_, CV_8UC1, lite_mat_gray.data_ptr_);
cv::imwrite("./mindspore_image.jpg", dst_image);
CompareMat(rgb_mat, lite_mat_gray);
}

Loading…
Cancel
Save