Browse Source

Check input image type for random posterize

tags/v1.0.0
YangLuo 5 years ago
parent
commit
67f5c89cd6
2 changed files with 23 additions and 0 deletions
  1. +2
    -0
      mindspore/ccsrc/minddata/dataset/kernels/image/posterize_op.cc
  2. +21
    -0
      tests/ut/python/dataset/test_random_posterize.py

+ 2
- 0
mindspore/ccsrc/minddata/dataset/kernels/image/posterize_op.cc View File

@@ -40,6 +40,8 @@ Status PosterizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_pt
} }
cv::Mat in_image = input_cv->mat(); cv::Mat in_image = input_cv->mat();
cv::Mat output_img; cv::Mat output_img;
CHECK_FAIL_RETURN_UNEXPECTED(in_image.depth() == CV_8U || in_image.depth() == CV_8S,
"Input image data type can not be float, but got " + input->type().ToString());
cv::LUT(in_image, lut_vector, output_img); cv::LUT(in_image, lut_vector, output_img);
std::shared_ptr<CVTensor> result_tensor; std::shared_ptr<CVTensor> result_tensor;
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(output_img, &result_tensor)); RETURN_IF_NOT_OK(CVTensor::CreateFromMat(output_img, &result_tensor));


+ 21
- 0
tests/ut/python/dataset/test_random_posterize.py View File

@@ -142,8 +142,29 @@ def test_random_posterize_exception_bit():
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert str(e) == "Size of bits should be a single integer or a list/tuple (min, max) of length 2." assert str(e) == "Size of bits should be a single integer or a list/tuple (min, max) of length 2."


def test_rescale_with_random_posterize():
"""
Test RandomPosterize: only support CV_8S/CV_8U
"""
logger.info("test_rescale_with_random_posterize")

DATA_DIR_10 = "../data/dataset/testCifar10Data"
dataset = ds.Cifar10Dataset(DATA_DIR_10)

rescale_op = c_vision.Rescale((1.0 / 255.0), 0.0)
dataset = dataset.map(input_columns=["image"], operations=rescale_op)

random_posterize_op = c_vision.RandomPosterize((4, 8))
dataset = dataset.map(input_columns=["image"], operations=random_posterize_op, num_parallel_workers=1)

try:
_ = dataset.output_shapes()
except RuntimeError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Input image data type can not be float" in str(e)


if __name__ == "__main__": if __name__ == "__main__":
skip_test_random_posterize_op_c(plot=True) skip_test_random_posterize_op_c(plot=True)
skip_test_random_posterize_op_fixed_point_c(plot=True) skip_test_random_posterize_op_fixed_point_c(plot=True)
test_random_posterize_exception_bit() test_random_posterize_exception_bit()
test_rescale_with_random_posterize()

Loading…
Cancel
Save