Browse Source

!11249 【MD】fix bug for extractchannel when type is UINT16

From: @xulei2020
Reviewed-by: @heleiwang,@liucunwei
Signed-off-by: @liucunwei
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
dfa6daaa57
2 changed files with 13 additions and 5 deletions
  1. +3
    -3
      mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/image_process.cc
  2. +10
    -2
      tests/ut/cpp/dataset/image_process_test.cc

+ 3
- 3
mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/image_process.cc View File

@@ -674,16 +674,16 @@ bool ExtractChannel(LiteMat &src, LiteMat &dst, int col) {
return false; return false;
} }


if (dst.data_type_ == LDataType::FLOAT32 || dst.data_type_ == LDataType::UINT8) {
if (src.data_type_ == LDataType::FLOAT32 || src.data_type_ == LDataType::UINT8) {
if (dst.IsEmpty() || dst.width_ != src.width_ || dst.height_ != src.height_ || dst.channel_ != 1 || if (dst.IsEmpty() || dst.width_ != src.width_ || dst.height_ != src.height_ || dst.channel_ != 1 ||
dst.data_type_ != src.data_type_) { dst.data_type_ != src.data_type_) {
dst.Init(src.width_, src.height_, 1, src.data_type_); dst.Init(src.width_, src.height_, 1, src.data_type_);
} }
} }


if (dst.data_type_ == LDataType::FLOAT32) {
if (src.data_type_ == LDataType::FLOAT32) {
ExtractChannelImpl<float>(src, dst, src.height_, src.width_, src.channel_, col); ExtractChannelImpl<float>(src, dst, src.height_, src.width_, src.channel_, col);
} else if (dst.data_type_ == LDataType::UINT8) {
} else if (src.data_type_ == LDataType::UINT8) {
ExtractChannelImpl<uint8_t>(src, dst, src.height_, src.width_, src.channel_, col); ExtractChannelImpl<uint8_t>(src, dst, src.height_, src.width_, src.channel_, col);
} else { } else {
return false; return false;


+ 10
- 2
tests/ut/cpp/dataset/image_process_test.cc View File

@@ -408,7 +408,7 @@ TEST_F(MindDataImageProcess, TestPadd) {
size_t total_size = makeborder.height_ * makeborder.width_ * makeborder.channel_; size_t total_size = makeborder.height_ * makeborder.width_ * makeborder.channel_;
double distance = 0.0f; double distance = 0.0f;
for (size_t i = 0; i < total_size; i++) { for (size_t i = 0; i < total_size; i++) {
distance += pow((uint8_t)b_image.data[i] - ((uint8_t*)makeborder)[i], 2);
distance += pow((uint8_t)b_image.data[i] - ((uint8_t *)makeborder)[i], 2);
} }
distance = sqrt(distance / total_size); distance = sqrt(distance / total_size);
EXPECT_EQ(distance, 0.0f); EXPECT_EQ(distance, 0.0f);
@@ -439,7 +439,7 @@ TEST_F(MindDataImageProcess, TestPadZero) {
size_t total_size = makeborder.height_ * makeborder.width_ * makeborder.channel_; size_t total_size = makeborder.height_ * makeborder.width_ * makeborder.channel_;
double distance = 0.0f; double distance = 0.0f;
for (size_t i = 0; i < total_size; i++) { for (size_t i = 0; i < total_size; i++) {
distance += pow((uint8_t)b_image.data[i] - ((uint8_t*)makeborder)[i], 2);
distance += pow((uint8_t)b_image.data[i] - ((uint8_t *)makeborder)[i], 2);
} }
distance = sqrt(distance / total_size); distance = sqrt(distance / total_size);
EXPECT_EQ(distance, 0.0f); EXPECT_EQ(distance, 0.0f);
@@ -879,3 +879,11 @@ TEST_F(MindDataImageProcess, TestMultiplyFloat) {
static_cast<FLOAT32_C1 *>(dst_float.data_ptr_)[i].c1); static_cast<FLOAT32_C1 *>(dst_float.data_ptr_)[i].c1);
} }
} }

TEST_F(MindDataImageProcess, TestExtractChannel) {
LiteMat lite_single;
LiteMat lite_mat = LiteMat(1, 4, 3, LDataType::UINT16);

EXPECT_FALSE(ExtractChannel(lite_mat, lite_single, 0));
EXPECT_TRUE(lite_single.IsEmpty());
}

Loading…
Cancel
Save