|
|
|
@@ -142,8 +142,29 @@ def test_random_posterize_exception_bit(): |
|
|
|
logger.info("Got an exception in DE: {}".format(str(e))) |
|
|
|
assert str(e) == "Size of bits should be a single integer or a list/tuple (min, max) of length 2." |
|
|
|
|
|
|
|
def test_rescale_with_random_posterize(): |
|
|
|
""" |
|
|
|
Test RandomPosterize: only support CV_8S/CV_8U |
|
|
|
""" |
|
|
|
logger.info("test_rescale_with_random_posterize") |
|
|
|
|
|
|
|
DATA_DIR_10 = "../data/dataset/testCifar10Data" |
|
|
|
dataset = ds.Cifar10Dataset(DATA_DIR_10) |
|
|
|
|
|
|
|
rescale_op = c_vision.Rescale((1.0 / 255.0), 0.0) |
|
|
|
dataset = dataset.map(input_columns=["image"], operations=rescale_op) |
|
|
|
|
|
|
|
random_posterize_op = c_vision.RandomPosterize((4, 8)) |
|
|
|
dataset = dataset.map(input_columns=["image"], operations=random_posterize_op, num_parallel_workers=1) |
|
|
|
|
|
|
|
try: |
|
|
|
_ = dataset.output_shapes() |
|
|
|
except RuntimeError as e: |
|
|
|
logger.info("Got an exception in DE: {}".format(str(e))) |
|
|
|
assert "Input image data type can not be float" in str(e) |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
skip_test_random_posterize_op_c(plot=True) |
|
|
|
skip_test_random_posterize_op_fixed_point_c(plot=True) |
|
|
|
test_random_posterize_exception_bit() |
|
|
|
test_rescale_with_random_posterize() |