diff --git a/mindspore/dataset/transforms/vision/py_transforms.py b/mindspore/dataset/transforms/vision/py_transforms.py index 8d81f8f3b0..51bea80b21 100644 --- a/mindspore/dataset/transforms/vision/py_transforms.py +++ b/mindspore/dataset/transforms/vision/py_transforms.py @@ -1312,3 +1312,177 @@ class HsvToRgb: rgb_imgs (numpy.ndarray), Numpy RGB image with same shape of hsv_imgs. """ return util.hsv_to_rgbs(hsv_imgs, self.is_hwc) + + +class RandomColor: + """ + Adjust the color of the input PIL image by a random degree. + + Args: + degrees (sequence): Range of random color adjustment degrees. + It should be in (min, max) format (default=(0.1,1.9)). + + Examples: + >>> py_transforms.ComposeOp([py_transforms.Decode(), + >>> py_transforms.RandomColor(0.5,1.5), + >>> py_transforms.ToTensor()]) + """ + + def __init__(self, degrees=(0.1, 1.9)): + self.degrees = degrees + + def __call__(self, img): + """ + Call method. + + Args: + img (PIL Image): Image to be color adjusted. + + Returns: + img (PIL Image), Color adjusted image. + """ + + return util.random_color(img, self.degrees) + +class RandomSharpness: + """ + Adjust the sharpness of the input PIL image by a random degree. + + Args: + degrees (sequence): Range of random sharpness adjustment degrees. + It should be in (min, max) format (default=(0.1,1.9)). + + Examples: + >>> py_transforms.ComposeOp([py_transforms.Decode(), + >>> py_transforms.RandomColor(0.5,1.5), + >>> py_transforms.ToTensor()]) + + """ + + def __init__(self, degrees=(0.1, 1.9)): + self.degrees = degrees + + def __call__(self, img): + """ + Call method. + + Args: + img (PIL Image): Image to be sharpness adjusted. + + Returns: + img (PIL Image), Color adjusted image. + """ + + return util.random_sharpness(img, self.degrees) + + +class AutoContrast: + """ + Automatically maximize the contrast of the input PIL image. + + Examples: + >>> py_transforms.ComposeOp([py_transforms.Decode(), + >>> py_transforms.AutoContrast(), + >>> py_transforms.ToTensor()]) + + """ + + def __call__(self, img): + """ + Call method. + + Args: + img (PIL Image): Image to be augmented with AutoContrast. + + Returns: + img (PIL Image), Augmented image. + """ + + return util.auto_contrast(img) + + +class Invert: + """ + Invert colors of input PIL image. + + Examples: + >>> py_transforms.ComposeOp([py_transforms.Decode(), + >>> py_transforms.Invert(), + >>> py_transforms.ToTensor()]) + + """ + + def __call__(self, img): + """ + Call method. + + Args: + img (PIL Image): Image to be color Inverted. + + Returns: + img (PIL Image), Color inverted image. + """ + + return util.invert_color(img) + + +class Equalize: + """ + Equalize the histogram of input PIL image. + + Examples: + >>> py_transforms.ComposeOp([py_transforms.Decode(), + >>> py_transforms.Equalize(), + >>> py_transforms.ToTensor()]) + + """ + + def __call__(self, img): + """ + Call method. + + Args: + img (PIL Image): Image to be equalized. + + Returns: + img (PIL Image), Equalized image. + """ + + return util.equalize(img) + + +class UniformAugment: + """ + Uniformly select and apply a number of transforms sequentially from + a list of transforms. Randomly assigns a probability to each transform for + each image to decide whether apply it or not. + + Args: + transforms (list): List of transformations to be chosen from to apply. + num_ops (int, optional): number of transforms to sequentially apply (default=2). + + Examples: + >>> transforms_list = [py_transforms.CenterCrop(64), + >>> py_transforms.RandomColor(), + >>> py_transforms.RandomSharpness(), + >>> py_transforms.RandomRotation(30)] + >>> py_transforms.ComposeOp([py_transforms.Decode(), + >>> py_transforms.UniformAugment(transforms_list), + >>> py_transforms.ToTensor()]) + """ + + def __init__(self, transforms, num_ops=2): + self.transforms = transforms + self.num_ops = num_ops + + def __call__(self, img): + """ + Call method. + + Args: + img (PIL Image): Image to be applied transformation. + + Returns: + img (PIL Image), Transformed image. + """ + return util.uniform_augment(img, self.transforms, self.num_ops) diff --git a/mindspore/dataset/transforms/vision/py_transforms_util.py b/mindspore/dataset/transforms/vision/py_transforms_util.py index 10c71bbe38..54fb4c8274 100644 --- a/mindspore/dataset/transforms/vision/py_transforms_util.py +++ b/mindspore/dataset/transforms/vision/py_transforms_util.py @@ -1408,3 +1408,160 @@ def hsv_to_rgbs(np_hsv_imgs, is_hwc): if batch_size == 0: return hsv_to_rgb(np_hsv_imgs, is_hwc) return np.array([hsv_to_rgb(img, is_hwc) for img in np_hsv_imgs]) + + +def random_color(img, degrees): + + """ + Adjust the color of the input PIL image by a random degree. + + Args: + img (PIL Image): Image to be color adjusted. + degrees (sequence): Range of random color adjustment degrees. + It should be in (min, max) format (default=(0.1,1.9)). + + Returns: + img (PIL Image), Color adjusted image. + """ + + if not is_pil(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + if isinstance(degrees, (list, tuple)): + if len(degrees) != 2: + raise ValueError("Degrees must be a sequence length 2.") + if degrees[0] < 0: + raise ValueError("Degree value must be non-negative.") + if degrees[0] > degrees[1]: + raise ValueError("Degrees should be in (min,max) format. Got (max,min).") + + else: + raise TypeError("Degrees must be a sequence in (min,max) format.") + + v = (degrees[1] - degrees[0]) * random.random() + degrees[0] + return ImageEnhance.Color(img).enhance(v) + + +def random_sharpness(img, degrees): + + """ + Adjust the sharpness of the input PIL image by a random degree. + + Args: + img (PIL Image): Image to be sharpness adjusted. + degrees (sequence): Range of random sharpness adjustment degrees. + It should be in (min, max) format (default=(0.1,1.9)). + + Returns: + img (PIL Image), Sharpness adjusted image. + """ + + if not is_pil(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + if isinstance(degrees, (list, tuple)): + if len(degrees) != 2: + raise ValueError("Degrees must be a sequence length 2.") + if degrees[0] < 0: + raise ValueError("Degree value must be non-negative.") + if degrees[0] > degrees[1]: + raise ValueError("Degrees should be in (min,max) format. Got (max,min).") + + else: + raise TypeError("Degrees must be a sequence in (min,max) format.") + + v = (degrees[1] - degrees[0]) * random.random() + degrees[0] + return ImageEnhance.Sharpness(img).enhance(v) + + +def auto_contrast(img): + + """ + Automatically maximize the contrast of the input PIL image. + + Args: + img (PIL Image): Image to be augmented with AutoContrast. + + Returns: + img (PIL Image), Augmented image. + + """ + + if not is_pil(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + return ImageOps.autocontrast(img) + + +def invert_color(img): + + """ + Invert colors of input PIL image. + + Args: + img (PIL Image): Image to be color inverted. + + Returns: + img (PIL Image), Color inverted image. + + """ + + if not is_pil(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + return ImageOps.invert(img) + + +def equalize(img): + + """ + Equalize the histogram of input PIL image. + + Args: + img (PIL Image): Image to be equalized + + Returns: + img (PIL Image), Equalized image. + + """ + + if not is_pil(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + return ImageOps.equalize(img) + + +def uniform_augment(img, transforms, num_ops): + + """ + Uniformly select and apply a number of transforms sequentially from + a list of transforms. Randomly assigns a probability to each transform for + each image to decide whether apply it or not. + + Args: + img: Image to be applied transformation. + transforms (list): List of transformations to be chosen from to apply. + num_ops (int): number of transforms to sequentially aaply. + + Returns: + img, Transformed image. + """ + + if transforms is None: + raise ValueError("transforms is not provided.") + if not isinstance(transforms, list): + raise ValueError("The transforms needs to be a list.") + + if not isinstance(num_ops, int): + raise ValueError("Number of operations should be a positive integer.") + if num_ops < 1: + raise ValueError("Number of operators should equal or greater than one.") + + for _ in range(num_ops): + AugmentOp = random.choice(transforms) + pr = random.random() + if random.random() < pr: + img = AugmentOp(img.copy()) + transforms.remove(AugmentOp) + + return img diff --git a/tests/ut/python/dataset/test_autocontrast.py b/tests/ut/python/dataset/test_autocontrast.py new file mode 100644 index 0000000000..7dba2f21f6 --- /dev/null +++ b/tests/ut/python/dataset/test_autocontrast.py @@ -0,0 +1,101 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +import matplotlib.pyplot as plt +from mindspore import log as logger +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.vision.py_transforms as F + +DATA_DIR = "../data/dataset/testImageNetData/train/" + + +def visualize(image_original, image_auto_contrast): + """ + visualizes the image using DE op and Numpy op + """ + num = len(image_auto_contrast) + for i in range(num): + plt.subplot(2, num, i + 1) + plt.imshow(image_original[i]) + plt.title("Original image") + + plt.subplot(2, num, i + num + 1) + plt.imshow(image_auto_contrast[i]) + plt.title("DE AutoContrast image") + + plt.show() + + +def test_auto_contrast(plot=False): + """ + Test AutoContrast + """ + logger.info("Test AutoContrast") + + # Original Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_original = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.ToTensor()]) + + ds_original = ds.map(input_columns="image", + operations=transforms_original()) + + ds_original = ds_original.batch(512) + + for idx, (image,label) in enumerate(ds_original): + if idx == 0: + images_original = np.transpose(image, (0, 2,3,1)) + else: + images_original = np.append(images_original, + np.transpose(image, (0, 2,3,1)), + axis=0) + + # AutoContrast Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_auto_contrast = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.AutoContrast(), + F.ToTensor()]) + + ds_auto_contrast = ds.map(input_columns="image", + operations=transforms_auto_contrast()) + + ds_auto_contrast = ds_auto_contrast.batch(512) + + for idx, (image,label) in enumerate(ds_auto_contrast): + if idx == 0: + images_auto_contrast = np.transpose(image, (0, 2,3,1)) + else: + images_auto_contrast = np.append(images_auto_contrast, + np.transpose(image, (0, 2,3,1)), + axis=0) + + num_samples = images_original.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = np.mean((images_auto_contrast[i]-images_original[i])**2) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + if plot: + visualize(images_original, images_auto_contrast) + + +if __name__ == "__main__": + test_auto_contrast(plot=True) + diff --git a/tests/ut/python/dataset/test_equalize.py b/tests/ut/python/dataset/test_equalize.py new file mode 100644 index 0000000000..077c316d67 --- /dev/null +++ b/tests/ut/python/dataset/test_equalize.py @@ -0,0 +1,101 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +import matplotlib.pyplot as plt +from mindspore import log as logger +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.vision.py_transforms as F + +DATA_DIR = "../data/dataset/testImageNetData/train/" + + +def visualize(image_original, image_equalize): + """ + visualizes the image using DE op and Numpy op + """ + num = len(image_equalize) + for i in range(num): + plt.subplot(2, num, i + 1) + plt.imshow(image_original[i]) + plt.title("Original image") + + plt.subplot(2, num, i + num + 1) + plt.imshow(image_equalize[i]) + plt.title("DE Color Equalized image") + + plt.show() + + +def test_equalize(plot=False): + """ + Test Equalize + """ + logger.info("Test Equalize") + + # Original Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_original = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.ToTensor()]) + + ds_original = ds.map(input_columns="image", + operations=transforms_original()) + + ds_original = ds_original.batch(512) + + for idx, (image,label) in enumerate(ds_original): + if idx == 0: + images_original = np.transpose(image, (0, 2,3,1)) + else: + images_original = np.append(images_original, + np.transpose(image, (0, 2,3,1)), + axis=0) + + # Color Equalized Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_equalize = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.Equalize(), + F.ToTensor()]) + + ds_equalize = ds.map(input_columns="image", + operations=transforms_equalize()) + + ds_equalize = ds_equalize.batch(512) + + for idx, (image,label) in enumerate(ds_equalize): + if idx == 0: + images_equalize = np.transpose(image, (0, 2,3,1)) + else: + images_equalize = np.append(images_equalize, + np.transpose(image, (0, 2,3,1)), + axis=0) + + num_samples = images_original.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = np.mean((images_equalize[i]-images_original[i])**2) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + if plot: + visualize(images_original, images_equalize) + + +if __name__ == "__main__": + test_equalize(plot=True) + diff --git a/tests/ut/python/dataset/test_invert.py b/tests/ut/python/dataset/test_invert.py new file mode 100644 index 0000000000..a1bfd63431 --- /dev/null +++ b/tests/ut/python/dataset/test_invert.py @@ -0,0 +1,100 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +import matplotlib.pyplot as plt +from mindspore import log as logger +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.vision.py_transforms as F + +DATA_DIR = "../data/dataset/testImageNetData/train/" + +def visualize(image_original, image_invert): + """ + visualizes the image using DE op and Numpy op + """ + num = len(image_invert) + for i in range(num): + plt.subplot(2, num, i + 1) + plt.imshow(image_original[i]) + plt.title("Original image") + + plt.subplot(2, num, i + num + 1) + plt.imshow(image_invert[i]) + plt.title("DE Color Inverted image") + + plt.show() + + +def test_invert(plot=False): + """ + Test Invert + """ + logger.info("Test Invert") + + # Original Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_original = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.ToTensor()]) + + ds_original = ds.map(input_columns="image", + operations=transforms_original()) + + ds_original = ds_original.batch(512) + + for idx, (image,label) in enumerate(ds_original): + if idx == 0: + images_original = np.transpose(image, (0, 2,3,1)) + else: + images_original = np.append(images_original, + np.transpose(image, (0, 2,3,1)), + axis=0) + + # Color Inverted Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_invert = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.Invert(), + F.ToTensor()]) + + ds_invert = ds.map(input_columns="image", + operations=transforms_invert()) + + ds_invert = ds_invert.batch(512) + + for idx, (image,label) in enumerate(ds_invert): + if idx == 0: + images_invert = np.transpose(image, (0, 2,3,1)) + else: + images_invert = np.append(images_invert, + np.transpose(image, (0, 2,3,1)), + axis=0) + + num_samples = images_original.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = np.mean((images_invert[i]-images_original[i])**2) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + if plot: + visualize(images_original, images_invert) + + +if __name__ == "__main__": + test_invert(plot=True) + diff --git a/tests/ut/python/dataset/test_random_color.py b/tests/ut/python/dataset/test_random_color.py new file mode 100644 index 0000000000..9472b7e35a --- /dev/null +++ b/tests/ut/python/dataset/test_random_color.py @@ -0,0 +1,102 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +import matplotlib.pyplot as plt +from mindspore import log as logger +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.vision.py_transforms as F + +DATA_DIR = "../data/dataset/testImageNetData/train/" + + +def visualize(image_original, image_random_color): + """ + visualizes the image using DE op and Numpy op + """ + num = len(image_random_color) + for i in range(num): + plt.subplot(2, num, i + 1) + plt.imshow(image_original[i]) + plt.title("Original image") + + plt.subplot(2, num, i + num + 1) + plt.imshow(image_random_color[i]) + plt.title("DE Random Color image") + + plt.show() + + +def test_random_color(degrees=(0.1,1.9), plot=False): + """ + Test RandomColor + """ + logger.info("Test RandomColor") + + # Original Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_original = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.ToTensor()]) + + ds_original = ds.map(input_columns="image", + operations=transforms_original()) + + ds_original = ds_original.batch(512) + + for idx, (image,label) in enumerate(ds_original): + if idx == 0: + images_original = np.transpose(image, (0, 2,3,1)) + else: + images_original = np.append(images_original, + np.transpose(image, (0, 2,3,1)), + axis=0) + + # Random Color Adjusted Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_random_color = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.RandomColor(degrees=degrees), + F.ToTensor()]) + + ds_random_color = ds.map(input_columns="image", + operations=transforms_random_color()) + + ds_random_color = ds_random_color.batch(512) + + for idx, (image,label) in enumerate(ds_random_color): + if idx == 0: + images_random_color = np.transpose(image, (0, 2,3,1)) + else: + images_random_color = np.append(images_random_color, + np.transpose(image, (0, 2,3,1)), + axis=0) + + num_samples = images_original.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = np.mean((images_random_color[i]-images_original[i])**2) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + if plot: + visualize(images_original, images_random_color) + + +if __name__ == "__main__": + test_random_color() + test_random_color(plot=True) + test_random_color(degrees=(0.5,1.5), plot=True) diff --git a/tests/ut/python/dataset/test_random_sharpness.py b/tests/ut/python/dataset/test_random_sharpness.py new file mode 100644 index 0000000000..949a658597 --- /dev/null +++ b/tests/ut/python/dataset/test_random_sharpness.py @@ -0,0 +1,102 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +import matplotlib.pyplot as plt +from mindspore import log as logger +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.vision.py_transforms as F + +DATA_DIR = "../data/dataset/testImageNetData/train/" + + +def visualize(image_original, image_random_sharpness): + """ + visualizes the image using DE op and Numpy op + """ + num = len(image_random_sharpness) + for i in range(num): + plt.subplot(2, num, i + 1) + plt.imshow(image_original[i]) + plt.title("Original image") + + plt.subplot(2, num, i + num + 1) + plt.imshow(image_random_sharpness[i]) + plt.title("DE Random Sharpness image") + + plt.show() + + +def test_random_sharpness(degrees=(0.1,1.9), plot=False): + """ + Test RandomSharpness + """ + logger.info("Test RandomSharpness") + + # Original Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_original = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.ToTensor()]) + + ds_original = ds.map(input_columns="image", + operations=transforms_original()) + + ds_original = ds_original.batch(512) + + for idx, (image,label) in enumerate(ds_original): + if idx == 0: + images_original = np.transpose(image, (0, 2,3,1)) + else: + images_original = np.append(images_original, + np.transpose(image, (0, 2,3,1)), + axis=0) + + # Random Sharpness Adjusted Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_random_sharpness = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.RandomSharpness(degrees=degrees), + F.ToTensor()]) + + ds_random_sharpness = ds.map(input_columns="image", + operations=transforms_random_sharpness()) + + ds_random_sharpness = ds_random_sharpness.batch(512) + + for idx, (image,label) in enumerate(ds_random_sharpness): + if idx == 0: + images_random_sharpness = np.transpose(image, (0, 2,3,1)) + else: + images_random_sharpness = np.append(images_random_sharpness, + np.transpose(image, (0, 2,3,1)), + axis=0) + + num_samples = images_original.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = np.mean((images_random_sharpness[i]-images_original[i])**2) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + if plot: + visualize(images_original, images_random_sharpness) + + +if __name__ == "__main__": + test_random_sharpness() + test_random_sharpness(plot=True) + test_random_sharpness(degrees=(0.5,1.5), plot=True) diff --git a/tests/ut/python/dataset/test_uniform_augment.py b/tests/ut/python/dataset/test_uniform_augment.py new file mode 100644 index 0000000000..ce0490336e --- /dev/null +++ b/tests/ut/python/dataset/test_uniform_augment.py @@ -0,0 +1,107 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +import matplotlib.pyplot as plt +from mindspore import log as logger +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.vision.py_transforms as F + +DATA_DIR = "../data/dataset/testImageNetData/train/" + +def visualize(image_original, image_ua): + """ + visualizes the image using DE op and Numpy op + """ + num = len(image_ua) + for i in range(num): + plt.subplot(2, num, i + 1) + plt.imshow(image_original[i]) + plt.title("Original image") + + plt.subplot(2, num, i + num + 1) + plt.imshow(image_ua[i]) + plt.title("DE UniformAugment image") + + plt.show() + + +def test_uniform_augment(plot=False, num_ops=2): + """ + Test UniformAugment + """ + logger.info("Test UniformAugment") + + # Original Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_original = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.ToTensor()]) + + ds_original = ds.map(input_columns="image", + operations=transforms_original()) + + ds_original = ds_original.batch(512) + + for idx, (image,label) in enumerate(ds_original): + if idx == 0: + images_original = np.transpose(image, (0, 2,3,1)) + else: + images_original = np.append(images_original, + np.transpose(image, (0, 2,3,1)), + axis=0) + + # UniformAugment Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transform_list = [F.RandomRotation(45), + F.RandomColor(), + F.RandomSharpness(), + F.Invert(), + F.AutoContrast(), + F.Equalize()] + + transforms_ua = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.UniformAugment(transforms=transform_list, num_ops=num_ops), + F.ToTensor()]) + + ds_ua = ds.map(input_columns="image", + operations=transforms_ua()) + + ds_ua = ds_ua.batch(512) + + for idx, (image,label) in enumerate(ds_ua): + if idx == 0: + images_ua = np.transpose(image, (0, 2,3,1)) + else: + images_ua = np.append(images_ua, + np.transpose(image, (0, 2,3,1)), + axis=0) + + num_samples = images_original.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = np.mean((images_ua[i]-images_original[i])**2) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + if plot: + visualize(images_original, images_ua) + + +if __name__ == "__main__": + test_uniform_augment(num_ops=1) +