Merge pull request !1908 from Tinazhang/ut-normalizetags/v0.5.0-beta
| @@ -141,9 +141,9 @@ def test_crop_grayscale(height=375, width=375): | |||
| if __name__ == "__main__": | |||
| test_center_crop_op(600, 600, True) | |||
| test_center_crop_op(600, 600, plot=True) | |||
| test_center_crop_op(300, 600) | |||
| test_center_crop_op(600, 300) | |||
| test_center_crop_md5() | |||
| test_center_crop_comp(True) | |||
| test_center_crop_comp(plot=True) | |||
| test_crop_grayscale() | |||
| @@ -12,26 +12,54 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Testing Normalize op in DE | |||
| """ | |||
| import numpy as np | |||
| import matplotlib.pyplot as plt | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as vision | |||
| import mindspore.dataset.transforms.vision.c_transforms as c_vision | |||
| import mindspore.dataset.transforms.vision.py_transforms as py_vision | |||
| from mindspore import log as logger | |||
| from util import diff_mse, save_and_check_md5 | |||
| DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | |||
| SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | |||
| GENERATE_GOLDEN = False | |||
| def visualize_mse(image_de_normalized, image_np_normalized, mse, image_original): | |||
| """ | |||
| visualizes the image using DE op and Numpy op | |||
| """ | |||
| plt.subplot(141) | |||
| plt.imshow(image_original) | |||
| plt.title("Original image") | |||
| plt.subplot(142) | |||
| plt.imshow(image_de_normalized) | |||
| plt.title("DE normalized image") | |||
| plt.subplot(143) | |||
| plt.imshow(image_np_normalized) | |||
| plt.title("Numpy normalized image") | |||
| plt.subplot(144) | |||
| plt.imshow(image_de_normalized - image_np_normalized) | |||
| plt.title("Difference image, mse : {}".format(mse)) | |||
| plt.show() | |||
| def normalize_np(image): | |||
| def normalize_np(image, mean, std): | |||
| """ | |||
| Apply the normalization | |||
| """ | |||
| # DE decodes the image in RGB by deafult, hence | |||
| # the values here are in RGB | |||
| image = np.array(image, np.float32) | |||
| image = image - np.array([121.0, 115.0, 100.0]) | |||
| image = image * (1.0 / np.array([70.0, 68.0, 71.0])) | |||
| image = image - np.array(mean) | |||
| image = image * (1.0 / np.array(std)) | |||
| return image | |||
| @@ -41,7 +69,7 @@ def get_normalized(image_id): | |||
| Reads the image using DE ops and then normalizes using Numpy | |||
| """ | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| decode_op = vision.Decode() | |||
| decode_op = c_vision.Decode() | |||
| data1 = data1.map(input_columns=["image"], operations=decode_op) | |||
| num_iter = 0 | |||
| for item in data1.create_dict_iterator(): | |||
| @@ -52,15 +80,61 @@ def get_normalized(image_id): | |||
| return None | |||
| def test_normalize_op(): | |||
| def util_test_normalize(mean, std, op_type): | |||
| """ | |||
| Test Normalize | |||
| Utility function for testing Normalize. Input arguments are given by other tests | |||
| """ | |||
| logger.info("Test Normalize") | |||
| if op_type == "cpp": | |||
| # define map operations | |||
| decode_op = c_vision.Decode() | |||
| normalize_op = c_vision.Normalize(mean, std) | |||
| # Generate dataset | |||
| data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data = data.map(input_columns=["image"], operations=decode_op) | |||
| data = data.map(input_columns=["image"], operations=normalize_op) | |||
| elif op_type == "python": | |||
| # define map operations | |||
| transforms = [ | |||
| py_vision.Decode(), | |||
| py_vision.ToTensor(), | |||
| py_vision.Normalize(mean, std) | |||
| ] | |||
| transform = py_vision.ComposeOp(transforms) | |||
| # Generate dataset | |||
| data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data = data.map(input_columns=["image"], operations=transform()) | |||
| else: | |||
| raise ValueError("Wrong parameter value") | |||
| return data | |||
| def util_test_normalize_grayscale(num_output_channels, mean, std): | |||
| """ | |||
| Utility function for testing Normalize. Input arguments are given by other tests | |||
| """ | |||
| transforms = [ | |||
| py_vision.Decode(), | |||
| py_vision.Grayscale(num_output_channels), | |||
| py_vision.ToTensor(), | |||
| py_vision.Normalize(mean, std) | |||
| ] | |||
| transform = py_vision.ComposeOp(transforms) | |||
| # Generate dataset | |||
| data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data = data.map(input_columns=["image"], operations=transform()) | |||
| return data | |||
| def test_normalize_op_c(plot=False): | |||
| """ | |||
| Test Normalize in cpp transformations | |||
| """ | |||
| logger.info("Test Normalize in cpp") | |||
| mean = [121.0, 115.0, 100.0] | |||
| std = [70.0, 68.0, 71.0] | |||
| # define map operations | |||
| decode_op = vision.Decode() | |||
| normalize_op = vision.Normalize([121.0, 115.0, 100.0], [70.0, 68.0, 71.0]) | |||
| decode_op = c_vision.Decode() | |||
| normalize_op = c_vision.Normalize(mean, std) | |||
| # First dataset | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| @@ -74,36 +148,64 @@ def test_normalize_op(): | |||
| num_iter = 0 | |||
| for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): | |||
| image_de_normalized = item1["image"] | |||
| image_np_normalized = normalize_np(item2["image"]) | |||
| diff = image_de_normalized - image_np_normalized | |||
| mse = np.sum(np.power(diff, 2)) | |||
| image_original = item2["image"] | |||
| image_np_normalized = normalize_np(image_original, mean, std) | |||
| mse = diff_mse(image_de_normalized, image_np_normalized) | |||
| logger.info("image_{}, mse: {}".format(num_iter + 1, mse)) | |||
| assert mse < 0.01 | |||
| # Uncomment these blocks to see visual results | |||
| # plt.subplot(131) | |||
| # plt.imshow(image_de_normalized) | |||
| # plt.title("DE normalize image") | |||
| # | |||
| # plt.subplot(132) | |||
| # plt.imshow(image_np_normalized) | |||
| # plt.title("Numpy normalized image") | |||
| # | |||
| # plt.subplot(133) | |||
| # plt.imshow(diff) | |||
| # plt.title("Difference image, mse : {}".format(mse)) | |||
| # | |||
| # plt.show() | |||
| if plot: | |||
| visualize_mse(image_de_normalized, image_np_normalized, mse, image_original) | |||
| num_iter += 1 | |||
| def test_normalize_op_py(plot=False): | |||
| """ | |||
| Test Normalize in python transformations | |||
| """ | |||
| logger.info("Test Normalize in python") | |||
| mean = [0.475, 0.45, 0.392] | |||
| std = [0.275, 0.267, 0.278] | |||
| # define map operations | |||
| transforms = [ | |||
| py_vision.Decode(), | |||
| py_vision.ToTensor() | |||
| ] | |||
| transform = py_vision.ComposeOp(transforms) | |||
| normalize_op = py_vision.Normalize(mean, std) | |||
| # First dataset | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data1 = data1.map(input_columns=["image"], operations=transform()) | |||
| data1 = data1.map(input_columns=["image"], operations=normalize_op) | |||
| # Second dataset | |||
| data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data2 = data2.map(input_columns=["image"], operations=transform()) | |||
| num_iter = 0 | |||
| for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): | |||
| image_de_normalized = (item1["image"].transpose(1, 2, 0) * 255).astype(np.uint8) | |||
| image_np_normalized = (normalize_np(item2["image"].transpose(1, 2, 0), mean, std) * 255).astype(np.uint8) | |||
| image_original = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8) | |||
| mse = diff_mse(image_de_normalized, image_np_normalized) | |||
| logger.info("image_{}, mse: {}".format(num_iter + 1, mse)) | |||
| assert mse < 0.01 | |||
| if plot: | |||
| visualize_mse(image_de_normalized, image_np_normalized, mse, image_original) | |||
| num_iter += 1 | |||
| def test_decode_op(): | |||
| """ | |||
| Test Decode op | |||
| """ | |||
| logger.info("Test Decode") | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image", "label"], num_parallel_workers=1, | |||
| shuffle=False) | |||
| # define map operations | |||
| decode_op = vision.Decode() | |||
| decode_op = c_vision.Decode() | |||
| # apply map operations on images | |||
| data1 = data1.map(input_columns=["image"], operations=decode_op) | |||
| @@ -112,22 +214,21 @@ def test_decode_op(): | |||
| for item in data1.create_dict_iterator(): | |||
| logger.info("Looping inside iterator {}".format(num_iter)) | |||
| _ = item["image"] | |||
| # plt.subplot(131) | |||
| # plt.imshow(image) | |||
| # plt.title("DE image") | |||
| # plt.show() | |||
| num_iter += 1 | |||
| def test_decode_normalize_op(): | |||
| """ | |||
| Test Decode op followed by Normalize op | |||
| """ | |||
| logger.info("Test [Decode, Normalize] in one Map") | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image", "label"], num_parallel_workers=1, | |||
| shuffle=False) | |||
| # define map operations | |||
| decode_op = vision.Decode() | |||
| normalize_op = vision.Normalize([121.0, 115.0, 100.0], [70.0, 68.0, 71.0]) | |||
| decode_op = c_vision.Decode() | |||
| normalize_op = c_vision.Normalize([121.0, 115.0, 100.0], [70.0, 68.0, 71.0]) | |||
| # apply map operations on images | |||
| data1 = data1.map(input_columns=["image"], operations=[decode_op, normalize_op]) | |||
| @@ -136,14 +237,139 @@ def test_decode_normalize_op(): | |||
| for item in data1.create_dict_iterator(): | |||
| logger.info("Looping inside iterator {}".format(num_iter)) | |||
| _ = item["image"] | |||
| # plt.subplot(131) | |||
| # plt.imshow(image) | |||
| # plt.title("DE image") | |||
| # plt.show() | |||
| num_iter += 1 | |||
| def test_normalize_md5_01(): | |||
| """ | |||
| Test Normalize with md5 check: valid mean and std | |||
| expected to pass | |||
| """ | |||
| logger.info("test_normalize_md5_01") | |||
| data_c = util_test_normalize([121.0, 115.0, 100.0], [70.0, 68.0, 71.0], "cpp") | |||
| data_py = util_test_normalize([0.475, 0.45, 0.392], [0.275, 0.267, 0.278], "python") | |||
| # check results with md5 comparison | |||
| filename1 = "normalize_01_c_result.npz" | |||
| filename2 = "normalize_01_py_result.npz" | |||
| save_and_check_md5(data_c, filename1, generate_golden=GENERATE_GOLDEN) | |||
| save_and_check_md5(data_py, filename2, generate_golden=GENERATE_GOLDEN) | |||
| def test_normalize_md5_02(): | |||
| """ | |||
| Test Normalize with md5 check: len(mean)=len(std)=1 with RGB images | |||
| expected to pass | |||
| """ | |||
| logger.info("test_normalize_md5_02") | |||
| data_py = util_test_normalize([0.475], [0.275], "python") | |||
| # check results with md5 comparison | |||
| filename2 = "normalize_02_py_result.npz" | |||
| save_and_check_md5(data_py, filename2, generate_golden=GENERATE_GOLDEN) | |||
| def test_normalize_exception_unequal_size_c(): | |||
| """ | |||
| Test Normalize in c transformation: len(mean) != len(std) | |||
| expected to raise ValueError | |||
| """ | |||
| logger.info("test_normalize_exception_unequal_size_c") | |||
| try: | |||
| _ = c_vision.Normalize([100, 250, 125], [50, 50, 75, 75]) | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert str(e) == "Length of mean and std must be equal" | |||
| def test_normalize_exception_unequal_size_py(): | |||
| """ | |||
| Test Normalize in python transformation: len(mean) != len(std) | |||
| expected to raise ValueError | |||
| """ | |||
| logger.info("test_normalize_exception_unequal_size_py") | |||
| try: | |||
| _ = py_vision.Normalize([0.50, 0.30, 0.75], [0.18, 0.32, 0.71, 0.72]) | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert str(e) == "Length of mean and std must be equal" | |||
| def test_normalize_exception_invalid_size_py(): | |||
| """ | |||
| Test Normalize in python transformation: len(mean)=len(std)=2 | |||
| expected to raise RuntimeError | |||
| """ | |||
| logger.info("test_normalize_exception_invalid_size_py") | |||
| data = util_test_normalize([0.75, 0.25], [0.18, 0.32], "python") | |||
| try: | |||
| _ = data.create_dict_iterator().get_next() | |||
| except RuntimeError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Length of mean and std must both be 1 or" in str(e) | |||
| def test_normalize_exception_invalid_range_py(): | |||
| """ | |||
| Test Normalize in python transformation: value is not in range [0,1] | |||
| expected to raise ValueError | |||
| """ | |||
| logger.info("test_normalize_exception_invalid_range_py") | |||
| try: | |||
| _ = py_vision.Normalize([0.75, 1.25, 0.5], [0.1, 0.18, 1.32]) | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Input is not within the required range" in str(e) | |||
| def test_normalize_grayscale_md5_01(): | |||
| """ | |||
| Test Normalize with md5 check: len(mean)=len(std)=1 with 1 channel grayscale images | |||
| expected to pass | |||
| """ | |||
| logger.info("test_normalize_grayscale_md5_01") | |||
| data = util_test_normalize_grayscale(1, [0.5], [0.175]) | |||
| # check results with md5 comparison | |||
| filename = "normalize_03_py_result.npz" | |||
| save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN) | |||
| def test_normalize_grayscale_md5_02(): | |||
| """ | |||
| Test Normalize with md5 check: len(mean)=len(std)=3 with 3 channel grayscale images | |||
| expected to pass | |||
| """ | |||
| logger.info("test_normalize_grayscale_md5_02") | |||
| data = util_test_normalize_grayscale(3, [0.5, 0.5, 0.5], [0.175, 0.235, 0.512]) | |||
| # check results with md5 comparison | |||
| filename = "normalize_04_py_result.npz" | |||
| save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN) | |||
| def test_normalize_grayscale_exception(): | |||
| """ | |||
| Test Normalize: len(mean)=len(std)=3 with 1 channel grayscale images | |||
| expected to raise RuntimeError | |||
| """ | |||
| logger.info("test_normalize_grayscale_exception") | |||
| try: | |||
| _ = util_test_normalize_grayscale(1, [0.5, 0.5, 0.5], [0.175, 0.235, 0.512]) | |||
| except RuntimeError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Input is not within the required range" in str(e) | |||
| if __name__ == "__main__": | |||
| test_decode_op() | |||
| test_decode_normalize_op() | |||
| test_normalize_op() | |||
| test_normalize_op_c(plot=True) | |||
| test_normalize_op_py(plot=True) | |||
| test_normalize_md5_01() | |||
| test_normalize_md5_02() | |||
| test_normalize_exception_unequal_size_c() | |||
| test_normalize_exception_unequal_size_py() | |||
| test_normalize_exception_invalid_size_py() | |||
| test_normalize_exception_invalid_range_py() | |||
| test_normalize_grayscale_md5_01() | |||
| test_normalize_grayscale_md5_02() | |||
| test_normalize_grayscale_exception() | |||
| @@ -0,0 +1,207 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Testing RandomAffine op in DE | |||
| """ | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.py_transforms as py_vision | |||
| from mindspore import log as logger | |||
| from util import visualize, save_and_check_md5, \ | |||
| config_get_set_seed, config_get_set_num_parallel_workers | |||
| GENERATE_GOLDEN = False | |||
| DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | |||
| SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | |||
| def test_random_affine_op(plot=False): | |||
| """ | |||
| Test RandomAffine in python transformations | |||
| """ | |||
| logger.info("test_random_affine_op") | |||
| # define map operations | |||
| transforms1 = [ | |||
| py_vision.Decode(), | |||
| py_vision.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1)), | |||
| py_vision.ToTensor() | |||
| ] | |||
| transform1 = py_vision.ComposeOp(transforms1) | |||
| transforms2 = [ | |||
| py_vision.Decode(), | |||
| py_vision.ToTensor() | |||
| ] | |||
| transform2 = py_vision.ComposeOp(transforms2) | |||
| # First dataset | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data1 = data1.map(input_columns=["image"], operations=transform1()) | |||
| # Second dataset | |||
| data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data2 = data2.map(input_columns=["image"], operations=transform2()) | |||
| image_affine = [] | |||
| image_original = [] | |||
| for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): | |||
| image1 = (item1["image"].transpose(1, 2, 0) * 255).astype(np.uint8) | |||
| image2 = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8) | |||
| image_affine.append(image1) | |||
| image_original.append(image2) | |||
| if plot: | |||
| visualize(image_original, image_affine) | |||
| def test_random_affine_md5(): | |||
| """ | |||
| Test RandomAffine with md5 comparison | |||
| """ | |||
| logger.info("test_random_affine_md5") | |||
| original_seed = config_get_set_seed(55) | |||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||
| # define map operations | |||
| transforms = [ | |||
| py_vision.Decode(), | |||
| py_vision.RandomAffine(degrees=(-5, 15), translate=(0.1, 0.3), | |||
| scale=(0.9, 1.1), shear=(-10, 10, -5, 5)), | |||
| py_vision.ToTensor() | |||
| ] | |||
| transform = py_vision.ComposeOp(transforms) | |||
| # Generate dataset | |||
| data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data = data.map(input_columns=["image"], operations=transform()) | |||
| # check results with md5 comparison | |||
| filename = "random_affine_01_result.npz" | |||
| save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN) | |||
| # Restore configuration | |||
| ds.config.set_seed(original_seed) | |||
| ds.config.set_num_parallel_workers((original_num_parallel_workers)) | |||
| def test_random_affine_exception_negative_degrees(): | |||
| """ | |||
| Test RandomAffine: input degrees in negative, expected to raise ValueError | |||
| """ | |||
| logger.info("test_random_affine_exception_negative_degrees") | |||
| try: | |||
| _ = py_vision.RandomAffine(degrees=-15) | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert str(e) == "If degrees is a single number, it cannot be negative." | |||
| def test_random_affine_exception_translation_range(): | |||
| """ | |||
| Test RandomAffine: translation value is not in [0, 1], expected to raise ValueError | |||
| """ | |||
| logger.info("test_random_affine_exception_translation_range") | |||
| try: | |||
| _ = py_vision.RandomAffine(degrees=15, translate=(0.1, 1.5)) | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert str(e) == "translation values should be between 0 and 1" | |||
| def test_random_affine_exception_scale_value(): | |||
| """ | |||
| Test RandomAffine: scale is not positive, expected to raise ValueError | |||
| """ | |||
| logger.info("test_random_affine_exception_scale_value") | |||
| try: | |||
| _ = py_vision.RandomAffine(degrees=15, scale=(0.0, 1.1)) | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert str(e) == "scale values should be positive" | |||
| def test_random_affine_exception_shear_value(): | |||
| """ | |||
| Test RandomAffine: shear is a number but is not positive, expected to raise ValueError | |||
| """ | |||
| logger.info("test_random_affine_exception_shear_value") | |||
| try: | |||
| _ = py_vision.RandomAffine(degrees=15, shear=-5) | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert str(e) == "If shear is a single number, it must be positive." | |||
| def test_random_affine_exception_degrees_size(): | |||
| """ | |||
| Test RandomAffine: degrees is a list or tuple and its length is not 2, | |||
| expected to raise TypeError | |||
| """ | |||
| logger.info("test_random_affine_exception_degrees_size") | |||
| try: | |||
| _ = py_vision.RandomAffine(degrees=[15]) | |||
| except TypeError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert str(e) == "If degrees is a sequence, the length must be 2." | |||
| def test_random_affine_exception_translate_size(): | |||
| """ | |||
| Test RandomAffine: translate is not list or a tuple of length 2, | |||
| expected to raise TypeError | |||
| """ | |||
| logger.info("test_random_affine_exception_translate_size") | |||
| try: | |||
| _ = py_vision.RandomAffine(degrees=15, translate=(0.1)) | |||
| except TypeError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert str(e) == "translate should be a list or tuple of length 2." | |||
| def test_random_affine_exception_scale_size(): | |||
| """ | |||
| Test RandomAffine: scale is not a list or tuple of length 2, | |||
| expected to raise TypeError | |||
| """ | |||
| logger.info("test_random_affine_exception_scale_size") | |||
| try: | |||
| _ = py_vision.RandomAffine(degrees=15, scale=(0.5)) | |||
| except TypeError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert str(e) == "scale should be a list or tuple of length 2." | |||
| def test_random_affine_exception_shear_size(): | |||
| """ | |||
| Test RandomAffine: shear is not a list or tuple of length 2 or 4, | |||
| expected to raise TypeError | |||
| """ | |||
| logger.info("test_random_affine_exception_shear_size") | |||
| try: | |||
| _ = py_vision.RandomAffine(degrees=15, shear=(-5, 5, 10)) | |||
| except TypeError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert str(e) == "shear should be a list or tuple and it must be of length 2 or 4." | |||
| if __name__ == "__main__": | |||
| test_random_affine_op(plot=True) | |||
| test_random_affine_md5() | |||
| test_random_affine_exception_negative_degrees() | |||
| test_random_affine_exception_translation_range() | |||
| test_random_affine_exception_scale_value() | |||
| test_random_affine_exception_shear_value() | |||
| test_random_affine_exception_degrees_size() | |||
| test_random_affine_exception_translate_size() | |||
| test_random_affine_exception_scale_size() | |||
| test_random_affine_exception_shear_size() | |||
| @@ -0,0 +1,133 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Testing RandomApply op in DE | |||
| """ | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.py_transforms as py_vision | |||
| from mindspore import log as logger | |||
| from util import visualize, config_get_set_seed, \ | |||
| config_get_set_num_parallel_workers, save_and_check_md5 | |||
| GENERATE_GOLDEN = False | |||
| DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | |||
| SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | |||
| def test_random_apply_op(plot=False): | |||
| """ | |||
| Test RandomApply in python transformations | |||
| """ | |||
| logger.info("test_random_apply_op") | |||
| # define map operations | |||
| transforms_list = [py_vision.CenterCrop(64), py_vision.RandomRotation(30)] | |||
| transforms1 = [ | |||
| py_vision.Decode(), | |||
| py_vision.RandomApply(transforms_list, prob=0.6), | |||
| py_vision.ToTensor() | |||
| ] | |||
| transform1 = py_vision.ComposeOp(transforms1) | |||
| transforms2 = [ | |||
| py_vision.Decode(), | |||
| py_vision.ToTensor() | |||
| ] | |||
| transform2 = py_vision.ComposeOp(transforms2) | |||
| # First dataset | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data1 = data1.map(input_columns=["image"], operations=transform1()) | |||
| # Second dataset | |||
| data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data2 = data2.map(input_columns=["image"], operations=transform2()) | |||
| image_apply = [] | |||
| image_original = [] | |||
| for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): | |||
| image1 = (item1["image"].transpose(1, 2, 0) * 255).astype(np.uint8) | |||
| image2 = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8) | |||
| image_apply.append(image1) | |||
| image_original.append(image2) | |||
| if plot: | |||
| visualize(image_original, image_apply) | |||
| def test_random_apply_md5(): | |||
| """ | |||
| Test RandomApply op with md5 check | |||
| """ | |||
| logger.info("test_random_apply_md5") | |||
| original_seed = config_get_set_seed(10) | |||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||
| # define map operations | |||
| transforms_list = [py_vision.CenterCrop(64), py_vision.RandomRotation(30)] | |||
| transforms = [ | |||
| py_vision.Decode(), | |||
| # Note: using default value "prob=0.5" | |||
| py_vision.RandomApply(transforms_list), | |||
| py_vision.ToTensor() | |||
| ] | |||
| transform = py_vision.ComposeOp(transforms) | |||
| # Generate dataset | |||
| data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data = data.map(input_columns=["image"], operations=transform()) | |||
| # check results with md5 comparison | |||
| filename = "random_apply_01_result.npz" | |||
| save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN) | |||
| # Restore configuration | |||
| ds.config.set_seed(original_seed) | |||
| ds.config.set_num_parallel_workers((original_num_parallel_workers)) | |||
| def test_random_apply_exception_random_crop_badinput(): | |||
| """ | |||
| Test RandomApply: test invalid input for one of the transform functions, | |||
| expected to raise error | |||
| """ | |||
| logger.info("test_random_apply_exception_random_crop_badinput") | |||
| original_seed = config_get_set_seed(200) | |||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||
| # define map operations | |||
| transforms_list = [py_vision.Resize([32, 32]), | |||
| py_vision.RandomCrop(100), # crop size > image size | |||
| py_vision.RandomRotation(30)] | |||
| transforms = [ | |||
| py_vision.Decode(), | |||
| py_vision.RandomApply(transforms_list, prob=0.6), | |||
| py_vision.ToTensor() | |||
| ] | |||
| transform = py_vision.ComposeOp(transforms) | |||
| # Generate dataset | |||
| data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data = data.map(input_columns=["image"], operations=transform()) | |||
| try: | |||
| _ = data.create_dict_iterator().get_next() | |||
| except RuntimeError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Crop size" in str(e) | |||
| # Restore configuration | |||
| ds.config.set_seed(original_seed) | |||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||
| if __name__ == '__main__': | |||
| test_random_apply_op(plot=True) | |||
| test_random_apply_md5() | |||
| test_random_apply_exception_random_crop_badinput() | |||
| @@ -0,0 +1,136 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Testing RandomChoice op in DE | |||
| """ | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.py_transforms as py_vision | |||
| from mindspore import log as logger | |||
| from util import visualize, diff_mse | |||
| DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | |||
| SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | |||
| def test_random_choice_op(plot=False): | |||
| """ | |||
| Test RandomChoice in python transformations | |||
| """ | |||
| logger.info("test_random_choice_op") | |||
| # define map operations | |||
| transforms_list = [py_vision.CenterCrop(64), py_vision.RandomRotation(30)] | |||
| transforms1 = [ | |||
| py_vision.Decode(), | |||
| py_vision.RandomChoice(transforms_list), | |||
| py_vision.ToTensor() | |||
| ] | |||
| transform1 = py_vision.ComposeOp(transforms1) | |||
| transforms2 = [ | |||
| py_vision.Decode(), | |||
| py_vision.ToTensor() | |||
| ] | |||
| transform2 = py_vision.ComposeOp(transforms2) | |||
| # First dataset | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data1 = data1.map(input_columns=["image"], operations=transform1()) | |||
| # Second dataset | |||
| data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data2 = data2.map(input_columns=["image"], operations=transform2()) | |||
| image_choice = [] | |||
| image_original = [] | |||
| for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): | |||
| image1 = (item1["image"].transpose(1, 2, 0) * 255).astype(np.uint8) | |||
| image2 = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8) | |||
| image_choice.append(image1) | |||
| image_original.append(image2) | |||
| if plot: | |||
| visualize(image_original, image_choice) | |||
| def test_random_choice_comp(plot=False): | |||
| """ | |||
| Test RandomChoice and compare with single CenterCrop results | |||
| """ | |||
| logger.info("test_random_choice_comp") | |||
| # define map operations | |||
| transforms_list = [py_vision.CenterCrop(64)] | |||
| transforms1 = [ | |||
| py_vision.Decode(), | |||
| py_vision.RandomChoice(transforms_list), | |||
| py_vision.ToTensor() | |||
| ] | |||
| transform1 = py_vision.ComposeOp(transforms1) | |||
| transforms2 = [ | |||
| py_vision.Decode(), | |||
| py_vision.CenterCrop(64), | |||
| py_vision.ToTensor() | |||
| ] | |||
| transform2 = py_vision.ComposeOp(transforms2) | |||
| # First dataset | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data1 = data1.map(input_columns=["image"], operations=transform1()) | |||
| # Second dataset | |||
| data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data2 = data2.map(input_columns=["image"], operations=transform2()) | |||
| image_choice = [] | |||
| image_original = [] | |||
| for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): | |||
| image1 = (item1["image"].transpose(1, 2, 0) * 255).astype(np.uint8) | |||
| image2 = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8) | |||
| image_choice.append(image1) | |||
| image_original.append(image2) | |||
| mse = diff_mse(image1, image2) | |||
| assert mse == 0 | |||
| if plot: | |||
| visualize(image_original, image_choice) | |||
| def test_random_choice_exception_random_crop_badinput(): | |||
| """ | |||
| Test RandomChoice: hit error in RandomCrop with greater crop size, | |||
| expected to raise error | |||
| """ | |||
| logger.info("test_random_choice_exception_random_crop_badinput") | |||
| # define map operations | |||
| # note: crop size[5000, 5000] > image size[4032, 2268] | |||
| transforms_list = [py_vision.RandomCrop(5000)] | |||
| transforms = [ | |||
| py_vision.Decode(), | |||
| py_vision.RandomChoice(transforms_list), | |||
| py_vision.ToTensor() | |||
| ] | |||
| transform = py_vision.ComposeOp(transforms) | |||
| # Generate dataset | |||
| data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data = data.map(input_columns=["image"], operations=transform()) | |||
| try: | |||
| _ = data.create_dict_iterator().get_next() | |||
| except RuntimeError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Crop size" in str(e) | |||
| if __name__ == '__main__': | |||
| test_random_choice_op(plot=True) | |||
| test_random_choice_comp(plot=True) | |||
| test_random_choice_exception_random_crop_badinput() | |||
| @@ -0,0 +1,100 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Testing RandomOrder op in DE | |||
| """ | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.py_transforms as py_vision | |||
| from mindspore import log as logger | |||
| from util import visualize, diff_mse, config_get_set_seed, \ | |||
| config_get_set_num_parallel_workers, save_and_check_md5 | |||
| GENERATE_GOLDEN = False | |||
| DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | |||
| SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | |||
| def test_random_order_op(plot=False): | |||
| """ | |||
| Test RandomOrder in python transformations | |||
| """ | |||
| logger.info("test_random_order_op") | |||
| # define map operations | |||
| transforms_list = [py_vision.CenterCrop(64), py_vision.RandomRotation(30)] | |||
| transforms1 = [ | |||
| py_vision.Decode(), | |||
| py_vision.RandomOrder(transforms_list), | |||
| py_vision.ToTensor() | |||
| ] | |||
| transform1 = py_vision.ComposeOp(transforms1) | |||
| transforms2 = [ | |||
| py_vision.Decode(), | |||
| py_vision.ToTensor() | |||
| ] | |||
| transform2 = py_vision.ComposeOp(transforms2) | |||
| # First dataset | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data1 = data1.map(input_columns=["image"], operations=transform1()) | |||
| # Second dataset | |||
| data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data2 = data2.map(input_columns=["image"], operations=transform2()) | |||
| image_order = [] | |||
| image_original = [] | |||
| for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): | |||
| image1 = (item1["image"].transpose(1, 2, 0) * 255).astype(np.uint8) | |||
| image2 = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8) | |||
| image_order.append(image1) | |||
| image_original.append(image2) | |||
| if plot: | |||
| visualize(image_original, image_order) | |||
| def test_random_order_md5(): | |||
| """ | |||
| Test RandomOrder op with md5 check | |||
| """ | |||
| logger.info("test_random_order_md5") | |||
| original_seed = config_get_set_seed(8) | |||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||
| # define map operations | |||
| transforms_list = [py_vision.RandomCrop(64), py_vision.RandomRotation(30)] | |||
| transforms = [ | |||
| py_vision.Decode(), | |||
| py_vision.RandomOrder(transforms_list), | |||
| py_vision.ToTensor() | |||
| ] | |||
| transform = py_vision.ComposeOp(transforms) | |||
| # Generate dataset | |||
| data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data = data.map(input_columns=["image"], operations=transform()) | |||
| # check results with md5 comparison | |||
| filename = "random_order_01_result.npz" | |||
| save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN) | |||
| # Restore configuration | |||
| ds.config.set_seed(original_seed) | |||
| ds.config.set_num_parallel_workers((original_num_parallel_workers)) | |||
| if __name__ == '__main__': | |||
| test_random_order_op(plot=True) | |||
| test_random_order_md5() | |||
| @@ -0,0 +1,128 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Testing RandomPerspective op in DE | |||
| """ | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.py_transforms as py_vision | |||
| from mindspore.dataset.transforms.vision.utils import Inter | |||
| from mindspore import log as logger | |||
| from util import visualize, save_and_check_md5, \ | |||
| config_get_set_seed, config_get_set_num_parallel_workers | |||
| GENERATE_GOLDEN = False | |||
| DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | |||
| SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | |||
| def test_random_perspective_op(plot=False): | |||
| """ | |||
| Test RandomPerspective in python transformations | |||
| """ | |||
| logger.info("test_random_perspective_op") | |||
| # define map operations | |||
| transforms1 = [ | |||
| py_vision.Decode(), | |||
| py_vision.RandomPerspective(), | |||
| py_vision.ToTensor() | |||
| ] | |||
| transform1 = py_vision.ComposeOp(transforms1) | |||
| transforms2 = [ | |||
| py_vision.Decode(), | |||
| py_vision.ToTensor() | |||
| ] | |||
| transform2 = py_vision.ComposeOp(transforms2) | |||
| # First dataset | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data1 = data1.map(input_columns=["image"], operations=transform1()) | |||
| # Second dataset | |||
| data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data2 = data2.map(input_columns=["image"], operations=transform2()) | |||
| image_perspective = [] | |||
| image_original = [] | |||
| for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): | |||
| image1 = (item1["image"].transpose(1, 2, 0) * 255).astype(np.uint8) | |||
| image2 = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8) | |||
| image_perspective.append(image1) | |||
| image_original.append(image2) | |||
| if plot: | |||
| visualize(image_original, image_perspective) | |||
| def skip_test_random_perspective_md5(): | |||
| """ | |||
| Test RandomPerspective with md5 comparison | |||
| """ | |||
| logger.info("test_random_perspective_md5") | |||
| original_seed = config_get_set_seed(5) | |||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||
| # define map operations | |||
| transforms = [ | |||
| py_vision.Decode(), | |||
| py_vision.RandomPerspective(distortion_scale=0.3, prob=0.7, | |||
| interpolation=Inter.BILINEAR), | |||
| py_vision.ToTensor() | |||
| ] | |||
| transform = py_vision.ComposeOp(transforms) | |||
| # Generate dataset | |||
| data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data = data.map(input_columns=["image"], operations=transform()) | |||
| # check results with md5 comparison | |||
| filename = "random_perspective_01_result.npz" | |||
| save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN) | |||
| # Restore configuration | |||
| ds.config.set_seed(original_seed) | |||
| ds.config.set_num_parallel_workers((original_num_parallel_workers)) | |||
| def test_random_perspective_exception_distortion_scale_range(): | |||
| """ | |||
| Test RandomPerspective: distortion_scale is not in [0, 1], expected to raise ValueError | |||
| """ | |||
| logger.info("test_random_perspective_exception_distortion_scale_range") | |||
| try: | |||
| _ = py_vision.RandomPerspective(distortion_scale=1.5) | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert str(e) == "Input is not within the required range" | |||
| def test_random_perspective_exception_prob_range(): | |||
| """ | |||
| Test RandomPerspective: prob is not in [0, 1], expected to raise ValueError | |||
| """ | |||
| logger.info("test_random_perspective_exception_prob_range") | |||
| try: | |||
| _ = py_vision.RandomPerspective(prob=1.2) | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert str(e) == "Input is not within the required range" | |||
| if __name__ == "__main__": | |||
| test_random_perspective_op(plot=True) | |||
| skip_test_random_perspective_md5() | |||
| test_random_perspective_exception_distortion_scale_range() | |||
| test_random_perspective_exception_prob_range() | |||