| @@ -232,6 +232,17 @@ bool ResizeBilinear(const LiteMat &src, LiteMat &dst, int dst_w, int dst_h) { | |||
| return true; | |||
| } | |||
| static bool ConvertBGR(const unsigned char *data, LDataType data_type, int w, int h, LiteMat &mat) { | |||
| if (data_type == LDataType::UINT8) { | |||
| mat.Init(w, h, 3, LDataType::UINT8); | |||
| unsigned char *dst_ptr = mat; | |||
| (void)memcpy(dst_ptr, data, w * h * 3 * sizeof(unsigned char)); | |||
| } else { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| static bool ConvertRGBAToBGR(const unsigned char *data, LDataType data_type, int w, int h, LiteMat &mat) { | |||
| if (data_type == LDataType::UINT8) { | |||
| mat.Init(w, h, 3, LDataType::UINT8); | |||
| @@ -272,6 +283,76 @@ static bool ConvertRGBAToRGB(const unsigned char *data, LDataType data_type, int | |||
| return true; | |||
| } | |||
| static bool ConvertYUV420SPToBGR(const uint8_t *data, LDataType data_type, bool flag, int w, int h, LiteMat &mat) { | |||
| if (data == nullptr || w <= 0 || h <= 0) { | |||
| return false; | |||
| } | |||
| if (data_type == LDataType::UINT8) { | |||
| mat.Init(w, h, 3, LDataType::UINT8); | |||
| const uint8_t *y_ptr = data; | |||
| const uint8_t *uv_ptr = y_ptr + w * h; | |||
| uint8_t *bgr_ptr = mat; | |||
| int bgr_stride = 3 * w; | |||
| for (int y = 0; y < h; ++y) { | |||
| uint8_t *bgr_buf = bgr_ptr; | |||
| const uint8_t *uv_buf = uv_ptr; | |||
| const uint8_t *y_buf = y_ptr; | |||
| uint8_t u; | |||
| uint8_t v; | |||
| for (int x = 0; x < w - 1; x += 2) { | |||
| if (flag) { | |||
| // NV21 | |||
| u = uv_buf[1]; | |||
| v = uv_buf[0]; | |||
| } else { | |||
| // NV12 | |||
| u = uv_buf[0]; | |||
| v = uv_buf[1]; | |||
| } | |||
| uint32_t tmp_y = (uint32_t)(y_buf[0] * YSCALE * YTOG) >> 16; | |||
| // b | |||
| bgr_buf[0] = std::clamp((int32_t)(-(u * UTOB) + tmp_y + BTOB) >> 6, 0, 255); | |||
| // g | |||
| bgr_buf[1] = std::clamp((int32_t)(-(u * UTOG + v * VTOG) + tmp_y + BTOG) >> 6, 0, 255); | |||
| // r | |||
| bgr_buf[2] = std::clamp((int32_t)(-(v * VTOR) + tmp_y + BTOR) >> 6, 0, 255); | |||
| tmp_y = (uint32_t)(y_buf[1] * YSCALE * YTOG) >> 16; | |||
| bgr_buf[3] = std::clamp((int32_t)(-(u * UTOB) + tmp_y + BTOB) >> 6, 0, 255); | |||
| bgr_buf[4] = std::clamp((int32_t)(-(u * UTOG + v * VTOG) + tmp_y + BTOG) >> 6, 0, 255); | |||
| bgr_buf[5] = std::clamp((int32_t)(-(v * VTOR) + tmp_y + BTOR) >> 6, 0, 255); | |||
| y_buf += 2; | |||
| uv_buf += 2; | |||
| bgr_buf += 6; | |||
| } | |||
| if (w & 1) { | |||
| if (flag) { | |||
| // NV21 | |||
| u = uv_buf[1]; | |||
| v = uv_buf[0]; | |||
| } else { | |||
| // NV12 | |||
| u = uv_buf[0]; | |||
| v = uv_buf[1]; | |||
| } | |||
| uint32_t tmp_y = (uint32_t)(y_buf[0] * YSCALE * YTOG) >> 16; | |||
| bgr_buf[0] = std::clamp((int32_t)(-(u * UTOB) + tmp_y + BTOB) >> 6, 0, 255); | |||
| bgr_buf[1] = std::clamp((int32_t)(-(u * UTOG + v * VTOG) + tmp_y + BTOG) >> 6, 0, 255); | |||
| bgr_buf[2] = std::clamp((int32_t)(-(v * VTOR) + tmp_y + BTOR) >> 6, 0, 255); | |||
| } | |||
| bgr_ptr += bgr_stride; | |||
| y_ptr += w; | |||
| if (y & 1) { | |||
| uv_ptr += w; | |||
| } | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| static bool ConvertRGBAToGRAY(const unsigned char *data, LDataType data_type, int w, int h, LiteMat &mat) { | |||
| if (data_type == LDataType::UINT8) { | |||
| mat.Init(w, h, 1, LDataType::UINT8); | |||
| @@ -300,12 +381,24 @@ bool InitFromPixel(const unsigned char *data, LPixelType pixel_type, LDataType d | |||
| if (w <= 0 || h <= 0) { | |||
| return false; | |||
| } | |||
| if (data_type != LDataType::UINT8) { | |||
| return false; | |||
| } | |||
| if (pixel_type == LPixelType::RGBA2BGR) { | |||
| return ConvertRGBAToBGR(data, data_type, w, h, m); | |||
| } else if (pixel_type == LPixelType::RGBA2GRAY) { | |||
| return ConvertRGBAToGRAY(data, data_type, w, h, m); | |||
| } else if (pixel_type == LPixelType::RGBA2RGB) { | |||
| return ConvertRGBAToRGB(data, data_type, w, h, m); | |||
| } else if (pixel_type == LPixelType::NV212BGR) { | |||
| return ConvertYUV420SPToBGR(data, data_type, true, w, h, m); | |||
| } else if (pixel_type == LPixelType::NV122BGR) { | |||
| return ConvertYUV420SPToBGR(data, data_type, false, w, h, m); | |||
| } else if (pixel_type == LPixelType::BGR) { | |||
| return ConvertBGR(data, data_type, w, h, m); | |||
| } else if (pixel_type == LPixelType::RGB) { | |||
| return ConvertBGR(data, data_type, w, h, m); | |||
| } else { | |||
| return false; | |||
| } | |||
| @@ -322,8 +415,8 @@ bool ConvertTo(const LiteMat &src, LiteMat &dst, double scale) { | |||
| float *dst_start_p = dst; | |||
| for (int h = 0; h < src.height_; h++) { | |||
| for (int w = 0; w < src.width_; w++) { | |||
| uint32_t index = (h * src.width_ + w) * src.channel_; | |||
| for (int c = 0; c < src.channel_; c++) { | |||
| int index = (h * src.width_ + w) * src.channel_; | |||
| dst_start_p[index + c] = (static_cast<float>(src_start_p[index + c] * scale)); | |||
| } | |||
| } | |||
| @@ -418,8 +511,9 @@ bool SubStractMeanNormalize(const LiteMat &src, LiteMat &dst, const std::vector< | |||
| if ((!mean.empty()) && std.empty()) { | |||
| for (int h = 0; h < src.height_; h++) { | |||
| for (int w = 0; w < src.width_; w++) { | |||
| uint32_t src_start = (h * src.width_ + w) * src.channel_; | |||
| for (int c = 0; c < src.channel_; c++) { | |||
| int index = (h * src.width_ + w) * src.channel_ + c; | |||
| uint32_t index = src_start + c; | |||
| dst_start_p[index] = src_start_p[index] - mean[c]; | |||
| } | |||
| } | |||
| @@ -427,8 +521,9 @@ bool SubStractMeanNormalize(const LiteMat &src, LiteMat &dst, const std::vector< | |||
| } else if (mean.empty() && (!std.empty())) { | |||
| for (int h = 0; h < src.height_; h++) { | |||
| for (int w = 0; w < src.width_; w++) { | |||
| uint32_t src_start = (h * src.width_ + w) * src.channel_; | |||
| for (int c = 0; c < src.channel_; c++) { | |||
| int index = (h * src.width_ + w) * src.channel_ + c; | |||
| uint32_t index = src_start + c; | |||
| dst_start_p[index] = src_start_p[index] / std[c]; | |||
| } | |||
| } | |||
| @@ -436,8 +531,9 @@ bool SubStractMeanNormalize(const LiteMat &src, LiteMat &dst, const std::vector< | |||
| } else if ((!mean.empty()) && (!std.empty())) { | |||
| for (int h = 0; h < src.height_; h++) { | |||
| for (int w = 0; w < src.width_; w++) { | |||
| uint32_t src_start = (h * src.width_ + w) * src.channel_; | |||
| for (int c = 0; c < src.channel_; c++) { | |||
| int index = (h * src.width_ + w) * src.channel_ + c; | |||
| uint32_t index = src_start + c; | |||
| dst_start_p[index] = (src_start_p[index] - mean[c]) / std[c]; | |||
| } | |||
| } | |||
| @@ -458,7 +554,7 @@ static void PadWithConstant(const LiteMat &src, LiteMat &dst, const int top, con | |||
| // padd top | |||
| for (int h = 0; h < top; h++) { | |||
| for (int w = 0; w < dst.width_; w++) { | |||
| int index = (h * dst.width_ + w) * dst.channel_; | |||
| uint32_t index = (h * dst.width_ + w) * dst.channel_; | |||
| if (dst.channel_ == 1) { | |||
| dst_start_p[index] = fill_b_or_gray; | |||
| } else if (dst.channel_ == 3) { | |||
| @@ -472,7 +568,7 @@ static void PadWithConstant(const LiteMat &src, LiteMat &dst, const int top, con | |||
| // padd bottom | |||
| for (int h = dst.height_ - bottom; h < dst.height_; h++) { | |||
| for (int w = 0; w < dst.width_; w++) { | |||
| int index = (h * dst.width_ + w) * dst.channel_; | |||
| uint32_t index = (h * dst.width_ + w) * dst.channel_; | |||
| if (dst.channel_ == 1) { | |||
| dst_start_p[index] = fill_b_or_gray; | |||
| } else if (dst.channel_ == 3) { | |||
| @@ -487,7 +583,7 @@ static void PadWithConstant(const LiteMat &src, LiteMat &dst, const int top, con | |||
| // padd left | |||
| for (int h = top; h < dst.height_ - bottom; h++) { | |||
| for (int w = 0; w < left; w++) { | |||
| int index = (h * dst.width_ + w) * dst.channel_; | |||
| uint32_t index = (h * dst.width_ + w) * dst.channel_; | |||
| if (dst.channel_ == 1) { | |||
| dst_start_p[index] = fill_b_or_gray; | |||
| } else if (dst.channel_ == 3) { | |||
| @@ -502,7 +598,7 @@ static void PadWithConstant(const LiteMat &src, LiteMat &dst, const int top, con | |||
| // padd right | |||
| for (int h = top; h < dst.height_ - bottom; h++) { | |||
| for (int w = dst.width_ - right; w < dst.width_; w++) { | |||
| int index = (h * dst.width_ + w) * dst.channel_; | |||
| uint32_t index = (h * dst.width_ + w) * dst.channel_; | |||
| if (dst.channel_ == 1) { | |||
| dst_start_p[index] = fill_b_or_gray; | |||
| } else if (dst.channel_ == 3) { | |||
| @@ -522,6 +618,86 @@ static void PadWithConstant(const LiteMat &src, LiteMat &dst, const int top, con | |||
| } | |||
| } | |||
| bool ExtractChannel(const LiteMat &src, LiteMat &dst, int col) { | |||
| if (src.IsEmpty() || col < 0 || col > src.dims_ - 1) { | |||
| return false; | |||
| } | |||
| (void)dst.Init(src.width_, src.height_, 1, src.data_type_); | |||
| if (src.data_type_ == LDataType::FLOAT32) { | |||
| const float *src_start_p = src; | |||
| float *dst_start_p = dst; | |||
| for (int h = 0; h < src.height_; h++) { | |||
| uint32_t src_start = h * src.width_ * src.channel_ + col; | |||
| uint32_t dst_start = h * dst.width_; | |||
| for (int w = 0; w < src.width_; w++) { | |||
| uint32_t src_index = src_start + w * src.channel_; | |||
| uint32_t dst_index = dst_start + w; | |||
| dst_start_p[dst_index] = src_start_p[src_index]; | |||
| } | |||
| } | |||
| return true; | |||
| } else if (src.data_type_ == LDataType::UINT8) { | |||
| const uint8_t *src_start_p = src; | |||
| uint8_t *dst_start_p = dst; | |||
| for (int h = 0; h < src.height_; h++) { | |||
| uint32_t src_start = h * src.width_ * src.channel_ + col; | |||
| uint32_t dst_start = h * dst.width_; | |||
| for (int w = 0; w < src.width_; w++) { | |||
| uint32_t src_index = src_start + w * src.channel_; | |||
| uint32_t dst_index = dst_start + w; | |||
| dst_start_p[dst_index] = src_start_p[src_index]; | |||
| } | |||
| } | |||
| return true; | |||
| } else { | |||
| return false; | |||
| } | |||
| return false; | |||
| } | |||
| bool Split(const LiteMat &src, std::vector<LiteMat> &mv) { | |||
| if (src.data_type_ == LDataType::FLOAT32) { | |||
| const float *src_start_p = src; | |||
| for (int c = 0; c < src.channel_; c++) { | |||
| LiteMat dst; | |||
| (void)dst.Init(src.width_, src.height_, 1, src.data_type_); | |||
| float *dst_start_p = dst; | |||
| for (int h = 0; h < src.height_; h++) { | |||
| uint32_t src_start = h * src.width_ * src.channel_; | |||
| uint32_t dst_start = h * dst.width_; | |||
| for (int w = 0; w < src.width_; w++) { | |||
| uint32_t src_index = src_start + w * src.channel_ + c; | |||
| uint32_t dst_index = dst_start + w; | |||
| dst_start_p[dst_index] = src_start_p[src_index]; | |||
| } | |||
| } | |||
| mv.emplace_back(dst); | |||
| } | |||
| return true; | |||
| } else if (src.data_type_ == LDataType::UINT8) { | |||
| const uint8_t *src_start_p = src; | |||
| for (int c = 0; c < src.channel_; c++) { | |||
| LiteMat dst; | |||
| (void)dst.Init(src.width_, src.height_, 1, src.data_type_); | |||
| uint8_t *dst_start_p = dst; | |||
| for (int h = 0; h < src.height_; h++) { | |||
| uint32_t src_start = h * src.width_ * src.channel_; | |||
| uint32_t dst_start = h * dst.width_; | |||
| for (int w = 0; w < src.width_; w++) { | |||
| uint32_t src_index = src_start + w * src.channel_ + c; | |||
| uint32_t dst_index = dst_start + w; | |||
| dst_start_p[dst_index] = src_start_p[src_index]; | |||
| } | |||
| } | |||
| mv.emplace_back(dst); | |||
| } | |||
| return true; | |||
| } else { | |||
| return false; | |||
| } | |||
| return false; | |||
| } | |||
| bool Pad(const LiteMat &src, LiteMat &dst, int top, int bottom, int left, int right, PaddBorderType pad_type, | |||
| uint8_t fill_b_or_gray, uint8_t fill_g, uint8_t fill_r) { | |||
| if (top <= 0 || bottom <= 0 || left <= 0 || right <= 0) { | |||
| @@ -35,6 +35,17 @@ namespace dataset { | |||
| #define B2GRAY 29 | |||
| #define GRAYSHIFT 8 | |||
| #define YSCALE 0x0101 | |||
| #define UTOB (-128) | |||
| #define UTOG 25 | |||
| #define VTOR (-102) | |||
| #define VTOG 52 | |||
| #define YTOG 18997 | |||
| #define YTOGB (-1160) | |||
| #define BTOB (UTOB * 128 + YTOGB) | |||
| #define BTOG (UTOG * 128 + VTOG * 128 + YTOGB) | |||
| #define BTOR (VTOR * 128 + YTOGB) | |||
| enum PaddBorderType { PADD_BORDER_CONSTANT = 0, PADD_BORDER_REPLICATE = 1 }; | |||
| struct BoxesConfig { | |||
| @@ -70,6 +81,10 @@ bool SubStractMeanNormalize(const LiteMat &src, LiteMat &dst, const std::vector< | |||
| bool Pad(const LiteMat &src, LiteMat &dst, int top, int bottom, int left, int right, PaddBorderType pad_type, | |||
| uint8_t fill_b_or_gray, uint8_t fill_g, uint8_t fill_r); | |||
| bool ExtractChannel(const LiteMat &src, LiteMat &dst, int col); | |||
| bool Split(const LiteMat &src, std::vector<LiteMat> &mv); | |||
| /// \brief Apply affine transformation for 1 channel image | |||
| bool Affine(LiteMat &src, LiteMat &out_img, double M[6], std::vector<size_t> dsize, UINT8_C1 borderValue); | |||
| @@ -92,6 +92,8 @@ enum LPixelType { | |||
| RGBA2GRAY = 3, | |||
| RGBA2BGR = 4, | |||
| RGBA2RGB = 5, | |||
| NV212BGR = 6, | |||
| NV122BGR = 7, | |||
| }; | |||
| class LDataType { | |||
| @@ -159,7 +161,6 @@ class LDataType { | |||
| class LiteMat { | |||
| // Class that represents a lite Mat of a Image. | |||
| // -# The pixel type of Lite Mat is RGBRGB...RGB. | |||
| public: | |||
| LiteMat(); | |||
| @@ -19,7 +19,6 @@ | |||
| #include "lite_cv/image_process.h" | |||
| #include <opencv2/opencv.hpp> | |||
| #include <opencv2/imgproc/types_c.h> | |||
| #include "utils/log_adapter.h" | |||
| #include <fstream> | |||
| @@ -43,32 +42,22 @@ void CompareMat(cv::Mat cv_mat, LiteMat lite_mat) { | |||
| ASSERT_TRUE(cv_c == lite_c); | |||
| } | |||
| LiteMat Lite3CImageProcess(LiteMat &lite_mat_bgr) { | |||
| void Lite3CImageProcess(LiteMat &lite_mat_bgr, LiteMat &lite_norm_mat_cut) { | |||
| bool ret; | |||
| LiteMat lite_mat_resize; | |||
| ret = ResizeBilinear(lite_mat_bgr, lite_mat_resize, 256, 256); | |||
| if (!ret) { | |||
| MS_LOG(ERROR) << "ResizeBilinear error"; | |||
| } | |||
| ASSERT_TRUE(ret == true); | |||
| LiteMat lite_mat_convert_float; | |||
| ret = ConvertTo(lite_mat_resize, lite_mat_convert_float, 1.0); | |||
| if (!ret) { | |||
| MS_LOG(ERROR) << "ConvertTo error"; | |||
| } | |||
| ASSERT_TRUE(ret == true); | |||
| LiteMat lite_mat_crop; | |||
| ret = Crop(lite_mat_convert_float, lite_mat_crop, 16, 16, 224, 224); | |||
| if (!ret) { | |||
| MS_LOG(ERROR) << "Crop error"; | |||
| } | |||
| ASSERT_TRUE(ret == true); | |||
| std::vector<float> means = {0.485, 0.456, 0.406}; | |||
| std::vector<float> stds = {0.229, 0.224, 0.225}; | |||
| LiteMat lite_norm_mat_cut; | |||
| SubStractMeanNormalize(lite_mat_crop, lite_norm_mat_cut, means, stds); | |||
| return lite_norm_mat_cut; | |||
| return; | |||
| } | |||
| cv::Mat cv3CImageProcess(cv::Mat &image) { | |||
| @@ -103,11 +92,25 @@ cv::Mat cv3CImageProcess(cv::Mat &image) { | |||
| return imgR2; | |||
| } | |||
| TEST_F(MindDataImageProcess, testRGB) { | |||
| std::string filename = "data/dataset/apple.jpg"; | |||
| cv::Mat image = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR); | |||
| cv::Mat rgba_mat; | |||
| cv::cvtColor(image, rgba_mat, CV_BGR2RGB); | |||
| bool ret = false; | |||
| LiteMat lite_mat_rgb; | |||
| ret = InitFromPixel(rgba_mat.data, LPixelType::RGB, LDataType::UINT8, rgba_mat.cols, rgba_mat.rows, lite_mat_rgb); | |||
| ASSERT_TRUE(ret == true); | |||
| cv::Mat dst_image(lite_mat_rgb.height_, lite_mat_rgb.width_, CV_8UC3, lite_mat_rgb.data_ptr_); | |||
| } | |||
| TEST_F(MindDataImageProcess, test3C) { | |||
| std::string filename = "data/dataset/apple.jpg"; | |||
| cv::Mat image = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR); | |||
| cv::Mat cv_image = cv3CImageProcess(image); | |||
| // cv::imwrite("/home/xlei/test_3cv.jpg", cv_image); | |||
| // convert to RGBA for Android bitmap(rgba) | |||
| cv::Mat rgba_mat; | |||
| @@ -117,34 +120,142 @@ TEST_F(MindDataImageProcess, test3C) { | |||
| LiteMat lite_mat_bgr; | |||
| ret = | |||
| InitFromPixel(rgba_mat.data, LPixelType::RGBA2BGR, LDataType::UINT8, rgba_mat.cols, rgba_mat.rows, lite_mat_bgr); | |||
| if (!ret) { | |||
| MS_LOG(ERROR) << "Init From RGBA error"; | |||
| } | |||
| LiteMat lite_norm_mat_cut = Lite3CImageProcess(lite_mat_bgr); | |||
| ASSERT_TRUE(ret == true); | |||
| LiteMat lite_norm_mat_cut; | |||
| Lite3CImageProcess(lite_mat_bgr, lite_norm_mat_cut); | |||
| cv::Mat dst_image(lite_norm_mat_cut.height_, lite_norm_mat_cut.width_, CV_32FC3, lite_norm_mat_cut.data_ptr_); | |||
| // cv::imwrite("/home/xlei/test_3clite.jpg", dst_image); | |||
| CompareMat(cv_image, lite_norm_mat_cut); | |||
| } | |||
| LiteMat Lite1CImageProcess(LiteMat &lite_mat_bgr) { | |||
| LiteMat lite_mat_resize; | |||
| ResizeBilinear(lite_mat_bgr, lite_mat_resize, 256, 256); | |||
| LiteMat lite_mat_convert_float; | |||
| ConvertTo(lite_mat_resize, lite_mat_convert_float); | |||
| bool ReadYUV(const char *filename, int w, int h, uint8_t **data) { | |||
| FILE *f = fopen(filename, "rb"); | |||
| if (f == nullptr) { | |||
| return false; | |||
| } | |||
| fseek(f, 0, SEEK_END); | |||
| int size = ftell(f); | |||
| int expect_size = w * h + 2 * ((w + 1) / 2) * ((h + 1) / 2); | |||
| if (size != expect_size) { | |||
| fclose(f); | |||
| return false; | |||
| } | |||
| fseek(f, 0, SEEK_SET); | |||
| *data = (uint8_t *)malloc(size); | |||
| size_t re = fread(*data, 1, size, f); | |||
| if (re != size) { | |||
| fclose(f); | |||
| return false; | |||
| } | |||
| fclose(f); | |||
| return true; | |||
| } | |||
| LiteMat lite_mat_cut; | |||
| TEST_F(MindDataImageProcess, testNV21ToBGR) { | |||
| // ffmpeg -i ./data/dataset/apple.jpg -s 1024*800 -pix_fmt nv21 ./data/dataset/yuv/test_nv21.yuv | |||
| const char *filename = "data/dataset/yuv/test_nv21.yuv"; | |||
| int w = 1024; | |||
| int h = 800; | |||
| uint8_t *yuv_data = nullptr; | |||
| bool ret = ReadYUV(filename, w, h, &yuv_data); | |||
| ASSERT_TRUE(ret == true); | |||
| Crop(lite_mat_convert_float, lite_mat_cut, 16, 16, 224, 224); | |||
| cv::Mat yuvimg(h * 3 / 2, w, CV_8UC1); | |||
| memcpy(yuvimg.data, yuv_data, w * h * 3 / 2); | |||
| cv::Mat rgbimage; | |||
| std::vector<float> means = {0.485}; | |||
| std::vector<float> stds = {0.229}; | |||
| cv::cvtColor(yuvimg, rgbimage, cv::COLOR_YUV2BGR_NV21); | |||
| LiteMat lite_norm_mat_cut; | |||
| LiteMat lite_mat_bgr; | |||
| SubStractMeanNormalize(lite_mat_cut, lite_norm_mat_cut, means, stds); | |||
| return lite_norm_mat_cut; | |||
| ret = InitFromPixel(yuv_data, LPixelType::NV212BGR, LDataType::UINT8, w, h, lite_mat_bgr); | |||
| ASSERT_TRUE(ret == true); | |||
| cv::Mat dst_image(lite_mat_bgr.height_, lite_mat_bgr.width_, CV_8UC3, lite_mat_bgr.data_ptr_); | |||
| } | |||
| TEST_F(MindDataImageProcess, testNV12ToBGR) { | |||
| // ffmpeg -i ./data/dataset/apple.jpg -s 1024*800 -pix_fmt nv12 ./data/dataset/yuv/test_nv12.yuv | |||
| const char *filename = "data/dataset/yuv/test_nv12.yuv"; | |||
| int w = 1024; | |||
| int h = 800; | |||
| uint8_t *yuv_data = nullptr; | |||
| bool ret = ReadYUV(filename, w, h, &yuv_data); | |||
| ASSERT_TRUE(ret == true); | |||
| cv::Mat yuvimg(h * 3 / 2, w, CV_8UC1); | |||
| memcpy(yuvimg.data, yuv_data, w * h * 3 / 2); | |||
| cv::Mat rgbimage; | |||
| cv::cvtColor(yuvimg, rgbimage, cv::COLOR_YUV2BGR_NV12); | |||
| LiteMat lite_mat_bgr; | |||
| ret = InitFromPixel(yuv_data, LPixelType::NV122BGR, LDataType::UINT8, w, h, lite_mat_bgr); | |||
| ASSERT_TRUE(ret == true); | |||
| cv::Mat dst_image(lite_mat_bgr.height_, lite_mat_bgr.width_, CV_8UC3, lite_mat_bgr.data_ptr_); | |||
| } | |||
| TEST_F(MindDataImageProcess, testExtractChannel) { | |||
| std::string filename = "data/dataset/apple.jpg"; | |||
| cv::Mat src_image = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR); | |||
| cv::Mat dst_image; | |||
| cv::extractChannel(src_image, dst_image, 2); | |||
| // convert to RGBA for Android bitmap(rgba) | |||
| cv::Mat rgba_mat; | |||
| cv::cvtColor(src_image, rgba_mat, CV_BGR2RGBA); | |||
| bool ret = false; | |||
| LiteMat lite_mat_bgr; | |||
| ret = | |||
| InitFromPixel(rgba_mat.data, LPixelType::RGBA2BGR, LDataType::UINT8, rgba_mat.cols, rgba_mat.rows, lite_mat_bgr); | |||
| ASSERT_TRUE(ret == true); | |||
| LiteMat lite_B; | |||
| ret = ExtractChannel(lite_mat_bgr, lite_B, 0); | |||
| ASSERT_TRUE(ret == true); | |||
| LiteMat lite_R; | |||
| ret = ExtractChannel(lite_mat_bgr, lite_R, 2); | |||
| ASSERT_TRUE(ret == true); | |||
| cv::Mat dst_imageR(lite_R.height_, lite_R.width_, CV_8UC1, lite_R.data_ptr_); | |||
| // cv::imwrite("./test_lite_r.jpg", dst_imageR); | |||
| } | |||
| TEST_F(MindDataImageProcess, testSplit) { | |||
| std::string filename = "data/dataset/apple.jpg"; | |||
| cv::Mat src_image = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR); | |||
| std::vector<cv::Mat> dst_images; | |||
| cv::split(src_image, dst_images); | |||
| // convert to RGBA for Android bitmap(rgba) | |||
| cv::Mat rgba_mat; | |||
| cv::cvtColor(src_image, rgba_mat, CV_BGR2RGBA); | |||
| bool ret = false; | |||
| LiteMat lite_mat_bgr; | |||
| ret = | |||
| InitFromPixel(rgba_mat.data, LPixelType::RGBA2BGR, LDataType::UINT8, rgba_mat.cols, rgba_mat.rows, lite_mat_bgr); | |||
| ASSERT_TRUE(ret == true); | |||
| std::vector<LiteMat> lite_all; | |||
| ret = Split(lite_mat_bgr, lite_all); | |||
| ASSERT_TRUE(ret == true); | |||
| ASSERT_TRUE(lite_all.size() == 3); | |||
| LiteMat lite_r = lite_all[2]; | |||
| cv::Mat dst_imageR(lite_r.height_, lite_r.width_, CV_8UC1, lite_r.data_ptr_); | |||
| } | |||
| void Lite1CImageProcess(LiteMat &lite_mat_bgr, LiteMat &lite_norm_mat_cut) { | |||
| LiteMat lite_mat_resize; | |||
| int ret = ResizeBilinear(lite_mat_bgr, lite_mat_resize, 256, 256); | |||
| ASSERT_TRUE(ret == true); | |||
| LiteMat lite_mat_convert_float; | |||
| ret = ConvertTo(lite_mat_resize, lite_mat_convert_float); | |||
| ASSERT_TRUE(ret == true); | |||
| LiteMat lite_mat_cut; | |||
| ret = Crop(lite_mat_convert_float, lite_mat_cut, 16, 16, 224, 224); | |||
| ASSERT_TRUE(ret == true); | |||
| std::vector<float> means = {0.485}; | |||
| std::vector<float> stds = {0.229}; | |||
| ret = SubStractMeanNormalize(lite_mat_cut, lite_norm_mat_cut, means, stds); | |||
| ASSERT_TRUE(ret == true); | |||
| return; | |||
| } | |||
| cv::Mat cv1CImageProcess(cv::Mat &image) { | |||
| @@ -183,18 +294,17 @@ TEST_F(MindDataImageProcess, test1C) { | |||
| cv::Mat image = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR); | |||
| cv::Mat cv_image = cv1CImageProcess(image); | |||
| // cv::imwrite("/home/xlei/test_c1v.jpg", cv_image); | |||
| // convert to RGBA for Android bitmap(rgba) | |||
| cv::Mat rgba_mat; | |||
| cv::cvtColor(image, rgba_mat, CV_BGR2RGBA); | |||
| LiteMat lite_mat_bgr; | |||
| InitFromPixel(rgba_mat.data, LPixelType::RGBA2GRAY, LDataType::UINT8, rgba_mat.cols, rgba_mat.rows, lite_mat_bgr); | |||
| LiteMat lite_norm_mat_cut = Lite1CImageProcess(lite_mat_bgr); | |||
| bool ret = | |||
| InitFromPixel(rgba_mat.data, LPixelType::RGBA2GRAY, LDataType::UINT8, rgba_mat.cols, rgba_mat.rows, lite_mat_bgr); | |||
| ASSERT_TRUE(ret == true); | |||
| LiteMat lite_norm_mat_cut; | |||
| Lite1CImageProcess(lite_mat_bgr, lite_norm_mat_cut); | |||
| cv::Mat dst_image(lite_norm_mat_cut.height_, lite_norm_mat_cut.width_, CV_32FC1, lite_norm_mat_cut.data_ptr_); | |||
| // cv::imwrite("/home/xlei/test_c1lite.jpg", dst_image); | |||
| CompareMat(cv_image, lite_norm_mat_cut); | |||
| } | |||
| @@ -211,22 +321,20 @@ TEST_F(MindDataImageProcess, TestPadd) { | |||
| cv::Mat b_image; | |||
| cv::Scalar color = cv::Scalar(255, 255, 255); | |||
| cv::copyMakeBorder(resize_256_image, b_image, top, bottom, left, right, cv::BORDER_CONSTANT, color); | |||
| // cv::imwrite("/home/xlei/test_ccc.jpg", b_image); | |||
| cv::Mat rgba_mat; | |||
| cv::cvtColor(image, rgba_mat, CV_BGR2RGBA); | |||
| LiteMat lite_mat_bgr; | |||
| InitFromPixel(rgba_mat.data, LPixelType::RGBA2BGR, LDataType::UINT8, rgba_mat.cols, rgba_mat.rows, lite_mat_bgr); | |||
| bool ret = | |||
| InitFromPixel(rgba_mat.data, LPixelType::RGBA2BGR, LDataType::UINT8, rgba_mat.cols, rgba_mat.rows, lite_mat_bgr); | |||
| ASSERT_TRUE(ret == true); | |||
| LiteMat lite_mat_resize; | |||
| ResizeBilinear(lite_mat_bgr, lite_mat_resize, 256, 256); | |||
| ret = ResizeBilinear(lite_mat_bgr, lite_mat_resize, 256, 256); | |||
| ASSERT_TRUE(ret == true); | |||
| LiteMat makeborder; | |||
| Pad(lite_mat_resize, makeborder, top, bottom, left, right, PaddBorderType::PADD_BORDER_CONSTANT, 255, 255, 255); | |||
| ret = Pad(lite_mat_resize, makeborder, top, bottom, left, right, PaddBorderType::PADD_BORDER_CONSTANT, 255, 255, 255); | |||
| ASSERT_TRUE(ret == true); | |||
| cv::Mat dst_image(256 + top + bottom, 256 + left + right, CV_8UC3, makeborder.data_ptr_); | |||
| // cv::imwrite("/home/xlei/test_liteccc.jpg", dst_image); | |||
| } | |||
| TEST_F(MindDataImageProcess, TestGetDefaultBoxes) { | |||