Browse Source

!21357 [MD] fix SlicePatches when dealing with some special tensor.

Merge pull request !21357 from liyong126/fix_bug
tags/v1.5.0-rc1
i-robot Gitee 4 years ago
parent
commit
51a1ca61f2
18 changed files with 104 additions and 38 deletions
  1. +14
    -6
      mindspore/ccsrc/minddata/dataset/core/cv_tensor.cc
  2. +2
    -1
      mindspore/ccsrc/minddata/dataset/core/cv_tensor.h
  3. +1
    -1
      mindspore/ccsrc/minddata/dataset/core/data_type.cc
  4. +1
    -1
      mindspore/ccsrc/minddata/dataset/core/de_tensor.cc
  5. +10
    -9
      mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc
  6. +2
    -1
      mindspore/ccsrc/minddata/dataset/kernels/image/posterize_op.cc
  7. +1
    -1
      mindspore/ccsrc/minddata/dataset/kernels/image/random_color_op.cc
  8. +1
    -1
      mindspore/ccsrc/minddata/dataset/kernels/image/sharpness_op.cc
  9. +2
    -1
      mindspore/ccsrc/minddata/dataset/kernels/image/soft_dvpp/soft_dvpp_decode_random_crop_resize_jpeg_op.cc
  10. +2
    -1
      mindspore/ccsrc/minddata/dataset/kernels/image/soft_dvpp/soft_dvpp_decode_resize_jpeg_op.cc
  11. +1
    -1
      mindspore/ccsrc/minddata/dataset/kernels/image/solarize_op.cc
  12. +2
    -2
      tests/ut/cpp/dataset/common/bboxop_common.cc
  13. +1
    -1
      tests/ut/cpp/dataset/common/cvop_common.cc
  14. +2
    -2
      tests/ut/cpp/dataset/random_color_op_test.cc
  15. +1
    -1
      tests/ut/cpp/dataset/rgba_to_bgr_op_test.cc
  16. +1
    -1
      tests/ut/cpp/dataset/rgba_to_rgb_op_test.cc
  17. +8
    -7
      tests/ut/cpp/dataset/tensor_test.cc
  18. +52
    -0
      tests/ut/python/dataset/test_slice_patches.py

+ 14
- 6
mindspore/ccsrc/minddata/dataset/core/cv_tensor.cc View File

@@ -40,12 +40,21 @@ Status CVTensor::CreateEmpty(const TensorShape &shape, DataType type, CVTensorPt
return (*out)->MatInit((*out)->GetMutableBuffer(), (*out)->shape_, (*out)->type_, &(*out)->mat_);
}

Status CVTensor::CreateFromMat(const cv::Mat &mat, CVTensorPtr *out) {
Status CVTensor::CreateFromMat(const cv::Mat &mat, const dsize_t rank, CVTensorPtr *out) {
TensorPtr out_tensor;
cv::Mat mat_local = mat;
// if the input Mat's memory is not continuous, copy it to one block of memory
if (!mat.isContinuous()) mat_local = mat.clone();
TensorShape shape(mat.size, mat_local.type());
if (!mat.isContinuous()) {
mat_local = mat.clone();
}
TensorShape shape({});
if (mat.dims == 2 && rank == 2) {
shape = TensorShape({mat.rows, mat.cols});
} else if (mat.dims == 2 && rank == 3) {
shape = TensorShape({mat.rows, mat.cols, mat.channels()});
} else {
RETURN_STATUS_UNEXPECTED("Error in creating CVTensor: Invalid input rank or cv::mat dimension.");
}
DataType type = DataType::FromCVType(mat_local.type());
RETURN_IF_NOT_OK(CreateFromMemory(shape, type, mat_local.data, &out_tensor));
*out = AsCVTensor(out_tensor);
@@ -55,14 +64,13 @@ Status CVTensor::CreateFromMat(const cv::Mat &mat, CVTensorPtr *out) {
std::pair<std::array<int, 2>, int> CVTensor::IsValidImage(const TensorShape &shape, const DataType &type) {
std::array<int, 2> size = {1, 1};
if (shape.Rank() <= 2 || (shape.Rank() == 3 && shape[2] <= CV_CN_MAX)) {
uint8_t ch = 1;
uint16_t ch = 1;
if (shape.Rank() == 3) {
ch = static_cast<uint8_t>(shape[2]);
ch = static_cast<uint16_t>(shape[2]);
}
if (shape.Rank() > 0) size[0] = static_cast<int>(shape[0]);
if (shape.Rank() > 1) size[1] = static_cast<int>(shape[1]);
if (type.AsCVType() == kCVInvalidType) return std::make_pair(size, -1);

int cv_type = CV_MAKETYPE(type.AsCVType(), ch);
return std::make_pair(size, cv_type);
}


+ 2
- 1
mindspore/ccsrc/minddata/dataset/core/cv_tensor.h View File

@@ -53,9 +53,10 @@ class CVTensor : public Tensor {
/// Create CV tensor from cv::Mat
/// \note This constructor allocates a new space in the memory and copies the CV::Mat buffer into it.
/// \param mat [in] cv::Mat to be copied into the new tensor.
/// \param shape [in] the rank of output CVTensor.
/// \param out [out] Generated tensor
/// \return Status code
static Status CreateFromMat(const cv::Mat &mat, CVTensorPtr *out);
static Status CreateFromMat(const cv::Mat &mat, const dsize_t rank, CVTensorPtr *out);

~CVTensor() override = default;



+ 1
- 1
mindspore/ccsrc/minddata/dataset/core/data_type.cc View File

@@ -61,7 +61,7 @@ uint8_t DataType::AsCVType() const {
}

return res;
} // namespace dataset
}

DataType DataType::FromCVType(int cv_type) {
auto depth = static_cast<uchar>(cv_type) & static_cast<uchar>(CV_MAT_DEPTH_MASK);


+ 1
- 1
mindspore/ccsrc/minddata/dataset/core/de_tensor.cc View File

@@ -76,7 +76,7 @@ size_t DETensor::DataSize() const {
}
#endif
EXCEPTION_IF_NULL(tensor_impl_);
return static_cast<uint32_t>(tensor_impl_->SizeInBytes());
return static_cast<size_t>(tensor_impl_->SizeInBytes());
}

const std::vector<int64_t> &DETensor::Shape() const { return shape_; }


+ 10
- 9
mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc View File

@@ -189,7 +189,7 @@ Status DecodeCv(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *o
}
cv::cvtColor(img_mat, img_mat, static_cast<int>(cv::COLOR_BGR2RGB));
std::shared_ptr<CVTensor> output_cv;
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(img_mat, &output_cv));
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(img_mat, 3, &output_cv));
*output = std::static_pointer_cast<Tensor>(output_cv);
return Status::OK();
} catch (const cv::Exception &e) {
@@ -600,7 +600,7 @@ Status CropAndResize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tenso
if (mode == InterpolationMode::kCubicPil) {
cv::Mat input_roi = cv_in(roi);
std::shared_ptr<CVTensor> input_image;
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(input_roi, &input_image));
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(input_roi, input_cv->Rank(), &input_image));
LiteMat imIn, imOut;
std::shared_ptr<Tensor> output_tensor;
TensorShape new_shape = TensorShape({target_height, target_width, 3});
@@ -676,7 +676,7 @@ Status Rotate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
// use memcpy and don't compute the new shape since openCV has a rounding problem
cv::warpAffine(input_img, output_img, rot, bbox.size(), GetCVInterpolationMode(interpolation),
cv::BORDER_CONSTANT, fill_color);
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(output_img, &output_cv));
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(output_img, input_cv->Rank(), &output_cv));
RETURN_UNEXPECTED_IF_NULL(output_cv);
}
*output = std::static_pointer_cast<Tensor>(output_cv);
@@ -999,7 +999,7 @@ Status AutoContrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor
cv::merge(image_result, result);
result.convertTo(result, input_cv->mat().type());
std::shared_ptr<CVTensor> output_cv;
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(result, &output_cv));
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(result, input_cv->Rank(), &output_cv));
(*output) = std::static_pointer_cast<Tensor>(output_cv);
RETURN_IF_NOT_OK((*output)->Reshape(input_cv->shape()));
} catch (const cv::Exception &e) {
@@ -1100,7 +1100,7 @@ Status Equalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *o
cv::Mat result;
cv::merge(image_result, result);
std::shared_ptr<CVTensor> output_cv;
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(result, &output_cv));
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(result, input_cv->Rank(), &output_cv));
(*output) = std::static_pointer_cast<Tensor>(output_cv);
RETURN_IF_NOT_OK((*output)->Reshape(input_cv->shape()));
} catch (const cv::Exception &e) {
@@ -1196,7 +1196,7 @@ Status Pad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output
cv::copyMakeBorder(input_cv->mat(), out_image, pad_top, pad_bottom, pad_left, pad_right, b_type);
}
std::shared_ptr<CVTensor> output_cv;
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(out_image, &output_cv));
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(out_image, input_cv->Rank(), &output_cv));
// pad the dimension if shape information is only 2 dimensional, this is grayscale
int num_channels = input_cv->shape()[CHANNEL_INDEX];
if (input_cv->Rank() == DEFAULT_IMAGE_RANK && num_channels == MIN_IMAGE_CHANNELS &&
@@ -1341,7 +1341,7 @@ Status GaussianBlur(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor
cv::GaussianBlur(input_cv->mat(), output_cv_mat, cv::Size(kernel_x, kernel_y), static_cast<double>(sigma_x),
static_cast<double>(sigma_y));
std::shared_ptr<CVTensor> output_cv;
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(output_cv_mat, &output_cv));
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(output_cv_mat, input_cv->Rank(), &output_cv));
(*output) = std::static_pointer_cast<Tensor>(output_cv);
return Status::OK();
} catch (const cv::Exception &e) {
@@ -1414,8 +1414,9 @@ Status SlicePatches(const std::shared_ptr<Tensor> &input, std::vector<std::share
for (int i = 0; i < num_height; ++i) {
for (int j = 0; j < num_width; ++j) {
std::shared_ptr<CVTensor> patch_cv;
cv::Rect patch(j * patch_w, i * patch_h, patch_w, patch_h);
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(out_img(patch), &patch_cv));
cv::Rect rect(j * patch_w, i * patch_h, patch_w, patch_h);
cv::Mat patch(out_img(rect));
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(patch, input_cv->Rank(), &patch_cv));
(*output).push_back(std::static_pointer_cast<Tensor>(patch_cv));
}
}


+ 2
- 1
mindspore/ccsrc/minddata/dataset/kernels/image/posterize_op.cc View File

@@ -46,7 +46,8 @@ Status PosterizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_pt
input->type().ToString());
cv::LUT(in_image, lut_vector, output_img);
std::shared_ptr<CVTensor> result_tensor;
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(output_img, &result_tensor));

RETURN_IF_NOT_OK(CVTensor::CreateFromMat(output_img, input_cv->Rank(), &result_tensor));
*output = std::static_pointer_cast<Tensor>(result_tensor);
return Status::OK();
}


+ 1
- 1
mindspore/ccsrc/minddata/dataset/kernels/image/random_color_op.cc View File

@@ -46,7 +46,7 @@ Status RandomColorOp::Compute(const std::shared_ptr<Tensor> &in, std::shared_ptr
cv::Mat cv_out;
cv::merge(temp, 3, cv_out);
std::shared_ptr<CVTensor> cvt_out;
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(cv_out, &cvt_out));
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(cv_out, cvt_in->Rank(), &cvt_out));
if (abs(t - 0.0) < eps) {
// return grayscale
*out = std::static_pointer_cast<Tensor>(cvt_out);


+ 1
- 1
mindspore/ccsrc/minddata/dataset/kernels/image/sharpness_op.cc View File

@@ -63,7 +63,7 @@ Status SharpnessOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_pt
cv::addWeighted(input_img, alpha_, result, 1.0 - alpha_, 0.0, result);

std::shared_ptr<CVTensor> output_cv;
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(result, &output_cv));
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(result, input_cv->Rank(), &output_cv));
RETURN_UNEXPECTED_IF_NULL(output_cv);

*output = std::static_pointer_cast<Tensor>(output_cv);


+ 2
- 1
mindspore/ccsrc/minddata/dataset/kernels/image/soft_dvpp/soft_dvpp_decode_random_crop_resize_jpeg_op.cc View File

@@ -74,7 +74,8 @@ Status SoftDvppDecodeRandomCropResizeJpegOp::Compute(const std::shared_ptr<Tenso
error_info += std::to_string(ret) + ", please check the log information for more details.";
CHECK_FAIL_RETURN_UNEXPECTED(ret == 0, error_info);
std::shared_ptr<CVTensor> cv_tensor = nullptr;
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(out_rgb_img, &cv_tensor));

RETURN_IF_NOT_OK(CVTensor::CreateFromMat(out_rgb_img, 3, &cv_tensor));
*output = std::static_pointer_cast<Tensor>(cv_tensor);
} catch (const cv::Exception &e) {
std::string error = "SoftDvppDecodeRandomCropResizeJpeg:" + std::string(e.what());


+ 2
- 1
mindspore/ccsrc/minddata/dataset/kernels/image/soft_dvpp/soft_dvpp_decode_resize_jpeg_op.cc View File

@@ -66,7 +66,8 @@ Status SoftDvppDecodeResizeJpegOp::Compute(const std::shared_ptr<Tensor> &input,
error_info += std::to_string(ret) + ", please check the log information for more details.";
CHECK_FAIL_RETURN_UNEXPECTED(ret == 0, error_info);
std::shared_ptr<CVTensor> cv_tensor = nullptr;
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(out_rgb_img, &cv_tensor));

RETURN_IF_NOT_OK(CVTensor::CreateFromMat(out_rgb_img, 3, &cv_tensor));
*output = std::static_pointer_cast<Tensor>(cv_tensor);
} catch (const cv::Exception &e) {
std::string error = "SoftDvppDecodeResizeJpeg:" + std::string(e.what());


+ 1
- 1
mindspore/ccsrc/minddata/dataset/kernels/image/solarize_op.cc View File

@@ -41,7 +41,7 @@ Status SolarizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr

std::shared_ptr<CVTensor> mask_mat_tensor;
std::shared_ptr<CVTensor> output_cv_tensor;
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(input_cv->mat(), &mask_mat_tensor));
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(input_img, input_cv->Rank(), &mask_mat_tensor));

RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv_tensor));
RETURN_UNEXPECTED_IF_NULL(mask_mat_tensor);


+ 2
- 2
tests/ut/cpp/dataset/common/bboxop_common.cc View File

@@ -164,8 +164,8 @@ void BBoxOpCommon::CompareActualAndExpected(const std::string &op_name) {
EXPECT_TRUE(remove(actual_path.c_str()) == 0);
// compare using ==operator by Tensor
std::shared_ptr<CVTensor> expect_img_t, actual_img_t;
CVTensor::CreateFromMat(expect_img, &expect_img_t);
CVTensor::CreateFromMat(actual_img, &actual_img_t);
CVTensor::CreateFromMat(expect_img, 3, &expect_img_t);
CVTensor::CreateFromMat(actual_img, 3, &actual_img_t);
if (actual_img.data) {
EXPECT_EQ(*expect_img_t == *actual_img_t, true);
} else {


+ 1
- 1
tests/ut/cpp/dataset/common/cvop_common.cc View File

@@ -55,7 +55,7 @@ void CVOpCommon::GetInputImage(std::string filename) {
Tensor::CreateFromFile(filename, &raw_input_tensor_);
raw_cv_image_ = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR);
std::shared_ptr<CVTensor> input_cv_tensor;
CVTensor::CreateFromMat(raw_cv_image_, &input_cv_tensor);
CVTensor::CreateFromMat(raw_cv_image_, 3, &input_cv_tensor);
input_tensor_ = std::dynamic_pointer_cast<Tensor>(input_cv_tensor);
SwapRedAndBlue(input_tensor_, &input_tensor_);
if (raw_cv_image_.data) {


+ 2
- 2
tests/ut/cpp/dataset/random_color_op_test.cc View File

@@ -43,7 +43,7 @@ class MindDataTestRandomColorOp : public UT::CVOP::CVOpCommon {
cv::Mat cv_out;
cv::merge(temp, 3, cv_out);
std::shared_ptr<CVTensor> cvt_out;
CVTensor::CreateFromMat(cv_out, &cvt_out);
CVTensor::CreateFromMat(cv_out, 3, &cvt_out);
gray_tensor = std::static_pointer_cast<Tensor>(cvt_out);
}
TensorShape shape;
@@ -96,4 +96,4 @@ TEST_F(MindDataTestRandomColorOp, TestOp3) {
auto s = op.Compute(input_tensor, &output_tensor);
EXPECT_TRUE(s.IsOk());
}
}
}

+ 1
- 1
tests/ut/cpp/dataset/rgba_to_bgr_op_test.cc View File

@@ -48,7 +48,7 @@ TEST_F(MindDataTestRgbaToBgrOp, TestOp1) {
// create new tensor to test conversion
std::shared_ptr<Tensor> rgba_input;
std::shared_ptr<CVTensor> input_cv_tensor;
CVTensor::CreateFromMat(rgba_image, &input_cv_tensor);
CVTensor::CreateFromMat(rgba_image, 3, &input_cv_tensor);
rgba_input = std::dynamic_pointer_cast<Tensor>(input_cv_tensor);

Status s = op->Compute(rgba_input, &output_tensor_);


+ 1
- 1
tests/ut/cpp/dataset/rgba_to_rgb_op_test.cc View File

@@ -48,7 +48,7 @@ TEST_F(MindDataTestRgbaToRgbOp, TestOp1) {
// create new tensor to test conversion
std::shared_ptr<Tensor> rgba_input;
std::shared_ptr<CVTensor> input_cv_tensor;
CVTensor::CreateFromMat(rgba_image, &input_cv_tensor);
CVTensor::CreateFromMat(rgba_image, 3, &input_cv_tensor);
rgba_input = std::dynamic_pointer_cast<Tensor>(input_cv_tensor);

Status s = op->Compute(rgba_input, &output_tensor_);


+ 8
- 7
tests/ut/cpp/dataset/tensor_test.cc View File

@@ -303,7 +303,8 @@ TEST_F(MindDataTestTensorDE, CVTensorFromMat) {
m.at<uint8_t>(1, 0) = 30;
m.at<uint8_t>(1, 1) = 40;
std::shared_ptr<CVTensor> cvt;
CVTensor::CreateFromMat(m, &cvt);
TensorShape shape{2, 2};
CVTensor::CreateFromMat(m, 2, &cvt);
std::shared_ptr<Tensor> t;
Tensor::CreateEmpty(TensorShape({2, 2}), DataType(DataType::DE_UINT8), &t);
t->SetItemAt<uint8_t>({0, 0}, 10);
@@ -318,7 +319,7 @@ TEST_F(MindDataTestTensorDE, CVTensorFromMat) {
m2.at<uint8_t>(2) = 30;
m2.at<uint8_t>(3) = 40;
std::shared_ptr<CVTensor> cvt2;
CVTensor::CreateFromMat(m2, &cvt2);
CVTensor::CreateFromMat(m2, 2, &cvt2);
std::shared_ptr<Tensor> t2;
Tensor::CreateEmpty(TensorShape({4}), DataType(DataType::DE_UINT8), &t2);
t2->SetItemAt<uint8_t>({0}, 10);
@@ -360,7 +361,7 @@ TEST_F(MindDataTestTensorDE, CVTensorMatSlice) {
m.at<int32_t>(1, 1) = 50;
m.at<int32_t>(1, 2) = 60;
std::shared_ptr<CVTensor> cvt;
CVTensor::CreateFromMat(m, &cvt);
CVTensor::CreateFromMat(m, 2, &cvt);
cv::Mat mat;
cvt->MatAtIndex({1}, &mat);
cv::Mat m2(3, 1, CV_32S);
@@ -368,17 +369,17 @@ TEST_F(MindDataTestTensorDE, CVTensorMatSlice) {
m2.at<int32_t>(1) = 50;
m2.at<int32_t>(2) = 60;
std::shared_ptr<CVTensor> cvt2;
CVTensor::CreateFromMat(mat, &cvt2);
CVTensor::CreateFromMat(mat, 2, &cvt2);
std::shared_ptr<CVTensor> cvt3;
CVTensor::CreateFromMat(m2, &cvt3);
CVTensor::CreateFromMat(m2, 2, &cvt3);

ASSERT_TRUE(*cvt2 == *cvt3);
cvt->MatAtIndex({0}, &mat);
m2.at<int32_t>(0) = 10;
m2.at<int32_t>(1) = 20;
m2.at<int32_t>(2) = 30;
CVTensor::CreateFromMat(mat, &cvt2);
CVTensor::CreateFromMat(m2, &cvt3);
CVTensor::CreateFromMat(mat, 2, &cvt2);
CVTensor::CreateFromMat(m2, 2, &cvt3);
ASSERT_TRUE(*cvt2 == *cvt3);
}



+ 52
- 0
tests/ut/python/dataset/test_slice_patches.py View File

@@ -140,6 +140,54 @@ def test_slice_patches_exception_01():
logger.info("Got an exception in SlicePatches: {}".format(str(e)))
assert "Input fill_value is not within" in str(e)

def test_slice_patches_06():
image = np.random.randint(0, 255, (158, 126, 1)).astype(np.int32)
slice_patches_op = c_vision.SlicePatches(2, 8)
patches = slice_patches_op(image)
assert len(patches) == 16
assert patches[0].shape == (79, 16, 1)

def test_slice_patches_07():
image = np.random.randint(0, 255, (158, 126)).astype(np.int32)
slice_patches_op = c_vision.SlicePatches(2, 8)
patches = slice_patches_op(image)
assert len(patches) == 16
assert patches[0].shape == (79, 16)

def test_slice_patches_08():
np_data = np.random.randint(0, 255, (1, 56, 82, 256)).astype(np.uint8)
dataset = ds.NumpySlicesDataset(np_data, column_names=["image"])
slice_patches_op = c_vision.SlicePatches(2, 2)
dataset = dataset.map(input_columns=["image"], output_columns=["img0", "img1", "img2", "img3"],
column_order=["img0", "img1", "img2", "img3"],
operations=slice_patches_op)
for item in dataset.create_dict_iterator(output_numpy=True):
patch_shape = item['img0'].shape
assert patch_shape == (28, 41, 256)

def test_slice_patches_09():
image = np.random.randint(0, 255, (56, 82, 256)).astype(np.uint8)
slice_patches_op = c_vision.SlicePatches(4, 3, mode.SliceMode.PAD)
patches = slice_patches_op(image)
assert len(patches) == 12
assert patches[0].shape == (14, 28, 256)

def skip_test_slice_patches_10():
image = np.random.randint(0, 255, (7000, 7000, 255)).astype(np.uint8)
slice_patches_op = c_vision.SlicePatches(10, 13, mode.SliceMode.DROP)
patches = slice_patches_op(image)
assert patches[0].shape == (700, 538, 255)

def skip_test_slice_patches_11():
np_data = np.random.randint(0, 255, (1, 7000, 7000, 256)).astype(np.uint8)
dataset = ds.NumpySlicesDataset(np_data, column_names=["image"])
slice_patches_op = c_vision.SlicePatches(10, 13, mode.SliceMode.DROP)
cols = ['img' + str(x) for x in range(10*13)]
dataset = dataset.map(input_columns=["image"], output_columns=cols,
column_order=cols, operations=slice_patches_op)
for item in dataset.create_dict_iterator(output_numpy=True):
patch_shape = item['img0'].shape
assert patch_shape == (700, 538, 256)

def slice_patches(image, num_h, num_w, pad_or_drop, fill_value):
""" help function which slice patches with numpy """
@@ -174,4 +222,8 @@ if __name__ == "__main__":
test_slice_patches_03(plot=True)
test_slice_patches_04(plot=True)
test_slice_patches_05(plot=True)
test_slice_patches_06()
test_slice_patches_07()
test_slice_patches_08()
test_slice_patches_09()
test_slice_patches_exception_01()

Loading…
Cancel
Save