Merge pull request !524 from alashkari/ua-opstags/v0.2.0-alpha
| @@ -1312,3 +1312,177 @@ class HsvToRgb: | |||||
| rgb_imgs (numpy.ndarray), Numpy RGB image with same shape of hsv_imgs. | rgb_imgs (numpy.ndarray), Numpy RGB image with same shape of hsv_imgs. | ||||
| """ | """ | ||||
| return util.hsv_to_rgbs(hsv_imgs, self.is_hwc) | return util.hsv_to_rgbs(hsv_imgs, self.is_hwc) | ||||
| class RandomColor: | |||||
| """ | |||||
| Adjust the color of the input PIL image by a random degree. | |||||
| Args: | |||||
| degrees (sequence): Range of random color adjustment degrees. | |||||
| It should be in (min, max) format (default=(0.1,1.9)). | |||||
| Examples: | |||||
| >>> py_transforms.ComposeOp([py_transforms.Decode(), | |||||
| >>> py_transforms.RandomColor(0.5,1.5), | |||||
| >>> py_transforms.ToTensor()]) | |||||
| """ | |||||
| def __init__(self, degrees=(0.1, 1.9)): | |||||
| self.degrees = degrees | |||||
| def __call__(self, img): | |||||
| """ | |||||
| Call method. | |||||
| Args: | |||||
| img (PIL Image): Image to be color adjusted. | |||||
| Returns: | |||||
| img (PIL Image), Color adjusted image. | |||||
| """ | |||||
| return util.random_color(img, self.degrees) | |||||
| class RandomSharpness: | |||||
| """ | |||||
| Adjust the sharpness of the input PIL image by a random degree. | |||||
| Args: | |||||
| degrees (sequence): Range of random sharpness adjustment degrees. | |||||
| It should be in (min, max) format (default=(0.1,1.9)). | |||||
| Examples: | |||||
| >>> py_transforms.ComposeOp([py_transforms.Decode(), | |||||
| >>> py_transforms.RandomColor(0.5,1.5), | |||||
| >>> py_transforms.ToTensor()]) | |||||
| """ | |||||
| def __init__(self, degrees=(0.1, 1.9)): | |||||
| self.degrees = degrees | |||||
| def __call__(self, img): | |||||
| """ | |||||
| Call method. | |||||
| Args: | |||||
| img (PIL Image): Image to be sharpness adjusted. | |||||
| Returns: | |||||
| img (PIL Image), Color adjusted image. | |||||
| """ | |||||
| return util.random_sharpness(img, self.degrees) | |||||
| class AutoContrast: | |||||
| """ | |||||
| Automatically maximize the contrast of the input PIL image. | |||||
| Examples: | |||||
| >>> py_transforms.ComposeOp([py_transforms.Decode(), | |||||
| >>> py_transforms.AutoContrast(), | |||||
| >>> py_transforms.ToTensor()]) | |||||
| """ | |||||
| def __call__(self, img): | |||||
| """ | |||||
| Call method. | |||||
| Args: | |||||
| img (PIL Image): Image to be augmented with AutoContrast. | |||||
| Returns: | |||||
| img (PIL Image), Augmented image. | |||||
| """ | |||||
| return util.auto_contrast(img) | |||||
| class Invert: | |||||
| """ | |||||
| Invert colors of input PIL image. | |||||
| Examples: | |||||
| >>> py_transforms.ComposeOp([py_transforms.Decode(), | |||||
| >>> py_transforms.Invert(), | |||||
| >>> py_transforms.ToTensor()]) | |||||
| """ | |||||
| def __call__(self, img): | |||||
| """ | |||||
| Call method. | |||||
| Args: | |||||
| img (PIL Image): Image to be color Inverted. | |||||
| Returns: | |||||
| img (PIL Image), Color inverted image. | |||||
| """ | |||||
| return util.invert_color(img) | |||||
| class Equalize: | |||||
| """ | |||||
| Equalize the histogram of input PIL image. | |||||
| Examples: | |||||
| >>> py_transforms.ComposeOp([py_transforms.Decode(), | |||||
| >>> py_transforms.Equalize(), | |||||
| >>> py_transforms.ToTensor()]) | |||||
| """ | |||||
| def __call__(self, img): | |||||
| """ | |||||
| Call method. | |||||
| Args: | |||||
| img (PIL Image): Image to be equalized. | |||||
| Returns: | |||||
| img (PIL Image), Equalized image. | |||||
| """ | |||||
| return util.equalize(img) | |||||
| class UniformAugment: | |||||
| """ | |||||
| Uniformly select and apply a number of transforms sequentially from | |||||
| a list of transforms. Randomly assigns a probability to each transform for | |||||
| each image to decide whether apply it or not. | |||||
| Args: | |||||
| transforms (list): List of transformations to be chosen from to apply. | |||||
| num_ops (int, optional): number of transforms to sequentially apply (default=2). | |||||
| Examples: | |||||
| >>> transforms_list = [py_transforms.CenterCrop(64), | |||||
| >>> py_transforms.RandomColor(), | |||||
| >>> py_transforms.RandomSharpness(), | |||||
| >>> py_transforms.RandomRotation(30)] | |||||
| >>> py_transforms.ComposeOp([py_transforms.Decode(), | |||||
| >>> py_transforms.UniformAugment(transforms_list), | |||||
| >>> py_transforms.ToTensor()]) | |||||
| """ | |||||
| def __init__(self, transforms, num_ops=2): | |||||
| self.transforms = transforms | |||||
| self.num_ops = num_ops | |||||
| def __call__(self, img): | |||||
| """ | |||||
| Call method. | |||||
| Args: | |||||
| img (PIL Image): Image to be applied transformation. | |||||
| Returns: | |||||
| img (PIL Image), Transformed image. | |||||
| """ | |||||
| return util.uniform_augment(img, self.transforms, self.num_ops) | |||||
| @@ -1408,3 +1408,160 @@ def hsv_to_rgbs(np_hsv_imgs, is_hwc): | |||||
| if batch_size == 0: | if batch_size == 0: | ||||
| return hsv_to_rgb(np_hsv_imgs, is_hwc) | return hsv_to_rgb(np_hsv_imgs, is_hwc) | ||||
| return np.array([hsv_to_rgb(img, is_hwc) for img in np_hsv_imgs]) | return np.array([hsv_to_rgb(img, is_hwc) for img in np_hsv_imgs]) | ||||
| def random_color(img, degrees): | |||||
| """ | |||||
| Adjust the color of the input PIL image by a random degree. | |||||
| Args: | |||||
| img (PIL Image): Image to be color adjusted. | |||||
| degrees (sequence): Range of random color adjustment degrees. | |||||
| It should be in (min, max) format (default=(0.1,1.9)). | |||||
| Returns: | |||||
| img (PIL Image), Color adjusted image. | |||||
| """ | |||||
| if not is_pil(img): | |||||
| raise TypeError('img should be PIL Image. Got {}'.format(type(img))) | |||||
| if isinstance(degrees, (list, tuple)): | |||||
| if len(degrees) != 2: | |||||
| raise ValueError("Degrees must be a sequence length 2.") | |||||
| if degrees[0] < 0: | |||||
| raise ValueError("Degree value must be non-negative.") | |||||
| if degrees[0] > degrees[1]: | |||||
| raise ValueError("Degrees should be in (min,max) format. Got (max,min).") | |||||
| else: | |||||
| raise TypeError("Degrees must be a sequence in (min,max) format.") | |||||
| v = (degrees[1] - degrees[0]) * random.random() + degrees[0] | |||||
| return ImageEnhance.Color(img).enhance(v) | |||||
| def random_sharpness(img, degrees): | |||||
| """ | |||||
| Adjust the sharpness of the input PIL image by a random degree. | |||||
| Args: | |||||
| img (PIL Image): Image to be sharpness adjusted. | |||||
| degrees (sequence): Range of random sharpness adjustment degrees. | |||||
| It should be in (min, max) format (default=(0.1,1.9)). | |||||
| Returns: | |||||
| img (PIL Image), Sharpness adjusted image. | |||||
| """ | |||||
| if not is_pil(img): | |||||
| raise TypeError('img should be PIL Image. Got {}'.format(type(img))) | |||||
| if isinstance(degrees, (list, tuple)): | |||||
| if len(degrees) != 2: | |||||
| raise ValueError("Degrees must be a sequence length 2.") | |||||
| if degrees[0] < 0: | |||||
| raise ValueError("Degree value must be non-negative.") | |||||
| if degrees[0] > degrees[1]: | |||||
| raise ValueError("Degrees should be in (min,max) format. Got (max,min).") | |||||
| else: | |||||
| raise TypeError("Degrees must be a sequence in (min,max) format.") | |||||
| v = (degrees[1] - degrees[0]) * random.random() + degrees[0] | |||||
| return ImageEnhance.Sharpness(img).enhance(v) | |||||
| def auto_contrast(img): | |||||
| """ | |||||
| Automatically maximize the contrast of the input PIL image. | |||||
| Args: | |||||
| img (PIL Image): Image to be augmented with AutoContrast. | |||||
| Returns: | |||||
| img (PIL Image), Augmented image. | |||||
| """ | |||||
| if not is_pil(img): | |||||
| raise TypeError('img should be PIL Image. Got {}'.format(type(img))) | |||||
| return ImageOps.autocontrast(img) | |||||
| def invert_color(img): | |||||
| """ | |||||
| Invert colors of input PIL image. | |||||
| Args: | |||||
| img (PIL Image): Image to be color inverted. | |||||
| Returns: | |||||
| img (PIL Image), Color inverted image. | |||||
| """ | |||||
| if not is_pil(img): | |||||
| raise TypeError('img should be PIL Image. Got {}'.format(type(img))) | |||||
| return ImageOps.invert(img) | |||||
| def equalize(img): | |||||
| """ | |||||
| Equalize the histogram of input PIL image. | |||||
| Args: | |||||
| img (PIL Image): Image to be equalized | |||||
| Returns: | |||||
| img (PIL Image), Equalized image. | |||||
| """ | |||||
| if not is_pil(img): | |||||
| raise TypeError('img should be PIL Image. Got {}'.format(type(img))) | |||||
| return ImageOps.equalize(img) | |||||
| def uniform_augment(img, transforms, num_ops): | |||||
| """ | |||||
| Uniformly select and apply a number of transforms sequentially from | |||||
| a list of transforms. Randomly assigns a probability to each transform for | |||||
| each image to decide whether apply it or not. | |||||
| Args: | |||||
| img: Image to be applied transformation. | |||||
| transforms (list): List of transformations to be chosen from to apply. | |||||
| num_ops (int): number of transforms to sequentially aaply. | |||||
| Returns: | |||||
| img, Transformed image. | |||||
| """ | |||||
| if transforms is None: | |||||
| raise ValueError("transforms is not provided.") | |||||
| if not isinstance(transforms, list): | |||||
| raise ValueError("The transforms needs to be a list.") | |||||
| if not isinstance(num_ops, int): | |||||
| raise ValueError("Number of operations should be a positive integer.") | |||||
| if num_ops < 1: | |||||
| raise ValueError("Number of operators should equal or greater than one.") | |||||
| for _ in range(num_ops): | |||||
| AugmentOp = random.choice(transforms) | |||||
| pr = random.random() | |||||
| if random.random() < pr: | |||||
| img = AugmentOp(img.copy()) | |||||
| transforms.remove(AugmentOp) | |||||
| return img | |||||
| @@ -0,0 +1,101 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| import numpy as np | |||||
| import matplotlib.pyplot as plt | |||||
| from mindspore import log as logger | |||||
| import mindspore.dataset.engine as de | |||||
| import mindspore.dataset.transforms.vision.py_transforms as F | |||||
| DATA_DIR = "../data/dataset/testImageNetData/train/" | |||||
| def visualize(image_original, image_auto_contrast): | |||||
| """ | |||||
| visualizes the image using DE op and Numpy op | |||||
| """ | |||||
| num = len(image_auto_contrast) | |||||
| for i in range(num): | |||||
| plt.subplot(2, num, i + 1) | |||||
| plt.imshow(image_original[i]) | |||||
| plt.title("Original image") | |||||
| plt.subplot(2, num, i + num + 1) | |||||
| plt.imshow(image_auto_contrast[i]) | |||||
| plt.title("DE AutoContrast image") | |||||
| plt.show() | |||||
| def test_auto_contrast(plot=False): | |||||
| """ | |||||
| Test AutoContrast | |||||
| """ | |||||
| logger.info("Test AutoContrast") | |||||
| # Original Images | |||||
| ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||||
| transforms_original = F.ComposeOp([F.Decode(), | |||||
| F.Resize((224,224)), | |||||
| F.ToTensor()]) | |||||
| ds_original = ds.map(input_columns="image", | |||||
| operations=transforms_original()) | |||||
| ds_original = ds_original.batch(512) | |||||
| for idx, (image,label) in enumerate(ds_original): | |||||
| if idx == 0: | |||||
| images_original = np.transpose(image, (0, 2,3,1)) | |||||
| else: | |||||
| images_original = np.append(images_original, | |||||
| np.transpose(image, (0, 2,3,1)), | |||||
| axis=0) | |||||
| # AutoContrast Images | |||||
| ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||||
| transforms_auto_contrast = F.ComposeOp([F.Decode(), | |||||
| F.Resize((224,224)), | |||||
| F.AutoContrast(), | |||||
| F.ToTensor()]) | |||||
| ds_auto_contrast = ds.map(input_columns="image", | |||||
| operations=transforms_auto_contrast()) | |||||
| ds_auto_contrast = ds_auto_contrast.batch(512) | |||||
| for idx, (image,label) in enumerate(ds_auto_contrast): | |||||
| if idx == 0: | |||||
| images_auto_contrast = np.transpose(image, (0, 2,3,1)) | |||||
| else: | |||||
| images_auto_contrast = np.append(images_auto_contrast, | |||||
| np.transpose(image, (0, 2,3,1)), | |||||
| axis=0) | |||||
| num_samples = images_original.shape[0] | |||||
| mse = np.zeros(num_samples) | |||||
| for i in range(num_samples): | |||||
| mse[i] = np.mean((images_auto_contrast[i]-images_original[i])**2) | |||||
| logger.info("MSE= {}".format(str(np.mean(mse)))) | |||||
| if plot: | |||||
| visualize(images_original, images_auto_contrast) | |||||
| if __name__ == "__main__": | |||||
| test_auto_contrast(plot=True) | |||||
| @@ -0,0 +1,101 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| import numpy as np | |||||
| import matplotlib.pyplot as plt | |||||
| from mindspore import log as logger | |||||
| import mindspore.dataset.engine as de | |||||
| import mindspore.dataset.transforms.vision.py_transforms as F | |||||
| DATA_DIR = "../data/dataset/testImageNetData/train/" | |||||
| def visualize(image_original, image_equalize): | |||||
| """ | |||||
| visualizes the image using DE op and Numpy op | |||||
| """ | |||||
| num = len(image_equalize) | |||||
| for i in range(num): | |||||
| plt.subplot(2, num, i + 1) | |||||
| plt.imshow(image_original[i]) | |||||
| plt.title("Original image") | |||||
| plt.subplot(2, num, i + num + 1) | |||||
| plt.imshow(image_equalize[i]) | |||||
| plt.title("DE Color Equalized image") | |||||
| plt.show() | |||||
| def test_equalize(plot=False): | |||||
| """ | |||||
| Test Equalize | |||||
| """ | |||||
| logger.info("Test Equalize") | |||||
| # Original Images | |||||
| ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||||
| transforms_original = F.ComposeOp([F.Decode(), | |||||
| F.Resize((224,224)), | |||||
| F.ToTensor()]) | |||||
| ds_original = ds.map(input_columns="image", | |||||
| operations=transforms_original()) | |||||
| ds_original = ds_original.batch(512) | |||||
| for idx, (image,label) in enumerate(ds_original): | |||||
| if idx == 0: | |||||
| images_original = np.transpose(image, (0, 2,3,1)) | |||||
| else: | |||||
| images_original = np.append(images_original, | |||||
| np.transpose(image, (0, 2,3,1)), | |||||
| axis=0) | |||||
| # Color Equalized Images | |||||
| ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||||
| transforms_equalize = F.ComposeOp([F.Decode(), | |||||
| F.Resize((224,224)), | |||||
| F.Equalize(), | |||||
| F.ToTensor()]) | |||||
| ds_equalize = ds.map(input_columns="image", | |||||
| operations=transforms_equalize()) | |||||
| ds_equalize = ds_equalize.batch(512) | |||||
| for idx, (image,label) in enumerate(ds_equalize): | |||||
| if idx == 0: | |||||
| images_equalize = np.transpose(image, (0, 2,3,1)) | |||||
| else: | |||||
| images_equalize = np.append(images_equalize, | |||||
| np.transpose(image, (0, 2,3,1)), | |||||
| axis=0) | |||||
| num_samples = images_original.shape[0] | |||||
| mse = np.zeros(num_samples) | |||||
| for i in range(num_samples): | |||||
| mse[i] = np.mean((images_equalize[i]-images_original[i])**2) | |||||
| logger.info("MSE= {}".format(str(np.mean(mse)))) | |||||
| if plot: | |||||
| visualize(images_original, images_equalize) | |||||
| if __name__ == "__main__": | |||||
| test_equalize(plot=True) | |||||
| @@ -0,0 +1,100 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| import numpy as np | |||||
| import matplotlib.pyplot as plt | |||||
| from mindspore import log as logger | |||||
| import mindspore.dataset.engine as de | |||||
| import mindspore.dataset.transforms.vision.py_transforms as F | |||||
| DATA_DIR = "../data/dataset/testImageNetData/train/" | |||||
| def visualize(image_original, image_invert): | |||||
| """ | |||||
| visualizes the image using DE op and Numpy op | |||||
| """ | |||||
| num = len(image_invert) | |||||
| for i in range(num): | |||||
| plt.subplot(2, num, i + 1) | |||||
| plt.imshow(image_original[i]) | |||||
| plt.title("Original image") | |||||
| plt.subplot(2, num, i + num + 1) | |||||
| plt.imshow(image_invert[i]) | |||||
| plt.title("DE Color Inverted image") | |||||
| plt.show() | |||||
| def test_invert(plot=False): | |||||
| """ | |||||
| Test Invert | |||||
| """ | |||||
| logger.info("Test Invert") | |||||
| # Original Images | |||||
| ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||||
| transforms_original = F.ComposeOp([F.Decode(), | |||||
| F.Resize((224,224)), | |||||
| F.ToTensor()]) | |||||
| ds_original = ds.map(input_columns="image", | |||||
| operations=transforms_original()) | |||||
| ds_original = ds_original.batch(512) | |||||
| for idx, (image,label) in enumerate(ds_original): | |||||
| if idx == 0: | |||||
| images_original = np.transpose(image, (0, 2,3,1)) | |||||
| else: | |||||
| images_original = np.append(images_original, | |||||
| np.transpose(image, (0, 2,3,1)), | |||||
| axis=0) | |||||
| # Color Inverted Images | |||||
| ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||||
| transforms_invert = F.ComposeOp([F.Decode(), | |||||
| F.Resize((224,224)), | |||||
| F.Invert(), | |||||
| F.ToTensor()]) | |||||
| ds_invert = ds.map(input_columns="image", | |||||
| operations=transforms_invert()) | |||||
| ds_invert = ds_invert.batch(512) | |||||
| for idx, (image,label) in enumerate(ds_invert): | |||||
| if idx == 0: | |||||
| images_invert = np.transpose(image, (0, 2,3,1)) | |||||
| else: | |||||
| images_invert = np.append(images_invert, | |||||
| np.transpose(image, (0, 2,3,1)), | |||||
| axis=0) | |||||
| num_samples = images_original.shape[0] | |||||
| mse = np.zeros(num_samples) | |||||
| for i in range(num_samples): | |||||
| mse[i] = np.mean((images_invert[i]-images_original[i])**2) | |||||
| logger.info("MSE= {}".format(str(np.mean(mse)))) | |||||
| if plot: | |||||
| visualize(images_original, images_invert) | |||||
| if __name__ == "__main__": | |||||
| test_invert(plot=True) | |||||
| @@ -0,0 +1,102 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| import numpy as np | |||||
| import matplotlib.pyplot as plt | |||||
| from mindspore import log as logger | |||||
| import mindspore.dataset.engine as de | |||||
| import mindspore.dataset.transforms.vision.py_transforms as F | |||||
| DATA_DIR = "../data/dataset/testImageNetData/train/" | |||||
| def visualize(image_original, image_random_color): | |||||
| """ | |||||
| visualizes the image using DE op and Numpy op | |||||
| """ | |||||
| num = len(image_random_color) | |||||
| for i in range(num): | |||||
| plt.subplot(2, num, i + 1) | |||||
| plt.imshow(image_original[i]) | |||||
| plt.title("Original image") | |||||
| plt.subplot(2, num, i + num + 1) | |||||
| plt.imshow(image_random_color[i]) | |||||
| plt.title("DE Random Color image") | |||||
| plt.show() | |||||
| def test_random_color(degrees=(0.1,1.9), plot=False): | |||||
| """ | |||||
| Test RandomColor | |||||
| """ | |||||
| logger.info("Test RandomColor") | |||||
| # Original Images | |||||
| ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||||
| transforms_original = F.ComposeOp([F.Decode(), | |||||
| F.Resize((224,224)), | |||||
| F.ToTensor()]) | |||||
| ds_original = ds.map(input_columns="image", | |||||
| operations=transforms_original()) | |||||
| ds_original = ds_original.batch(512) | |||||
| for idx, (image,label) in enumerate(ds_original): | |||||
| if idx == 0: | |||||
| images_original = np.transpose(image, (0, 2,3,1)) | |||||
| else: | |||||
| images_original = np.append(images_original, | |||||
| np.transpose(image, (0, 2,3,1)), | |||||
| axis=0) | |||||
| # Random Color Adjusted Images | |||||
| ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||||
| transforms_random_color = F.ComposeOp([F.Decode(), | |||||
| F.Resize((224,224)), | |||||
| F.RandomColor(degrees=degrees), | |||||
| F.ToTensor()]) | |||||
| ds_random_color = ds.map(input_columns="image", | |||||
| operations=transforms_random_color()) | |||||
| ds_random_color = ds_random_color.batch(512) | |||||
| for idx, (image,label) in enumerate(ds_random_color): | |||||
| if idx == 0: | |||||
| images_random_color = np.transpose(image, (0, 2,3,1)) | |||||
| else: | |||||
| images_random_color = np.append(images_random_color, | |||||
| np.transpose(image, (0, 2,3,1)), | |||||
| axis=0) | |||||
| num_samples = images_original.shape[0] | |||||
| mse = np.zeros(num_samples) | |||||
| for i in range(num_samples): | |||||
| mse[i] = np.mean((images_random_color[i]-images_original[i])**2) | |||||
| logger.info("MSE= {}".format(str(np.mean(mse)))) | |||||
| if plot: | |||||
| visualize(images_original, images_random_color) | |||||
| if __name__ == "__main__": | |||||
| test_random_color() | |||||
| test_random_color(plot=True) | |||||
| test_random_color(degrees=(0.5,1.5), plot=True) | |||||
| @@ -0,0 +1,102 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| import numpy as np | |||||
| import matplotlib.pyplot as plt | |||||
| from mindspore import log as logger | |||||
| import mindspore.dataset.engine as de | |||||
| import mindspore.dataset.transforms.vision.py_transforms as F | |||||
| DATA_DIR = "../data/dataset/testImageNetData/train/" | |||||
| def visualize(image_original, image_random_sharpness): | |||||
| """ | |||||
| visualizes the image using DE op and Numpy op | |||||
| """ | |||||
| num = len(image_random_sharpness) | |||||
| for i in range(num): | |||||
| plt.subplot(2, num, i + 1) | |||||
| plt.imshow(image_original[i]) | |||||
| plt.title("Original image") | |||||
| plt.subplot(2, num, i + num + 1) | |||||
| plt.imshow(image_random_sharpness[i]) | |||||
| plt.title("DE Random Sharpness image") | |||||
| plt.show() | |||||
| def test_random_sharpness(degrees=(0.1,1.9), plot=False): | |||||
| """ | |||||
| Test RandomSharpness | |||||
| """ | |||||
| logger.info("Test RandomSharpness") | |||||
| # Original Images | |||||
| ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||||
| transforms_original = F.ComposeOp([F.Decode(), | |||||
| F.Resize((224,224)), | |||||
| F.ToTensor()]) | |||||
| ds_original = ds.map(input_columns="image", | |||||
| operations=transforms_original()) | |||||
| ds_original = ds_original.batch(512) | |||||
| for idx, (image,label) in enumerate(ds_original): | |||||
| if idx == 0: | |||||
| images_original = np.transpose(image, (0, 2,3,1)) | |||||
| else: | |||||
| images_original = np.append(images_original, | |||||
| np.transpose(image, (0, 2,3,1)), | |||||
| axis=0) | |||||
| # Random Sharpness Adjusted Images | |||||
| ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||||
| transforms_random_sharpness = F.ComposeOp([F.Decode(), | |||||
| F.Resize((224,224)), | |||||
| F.RandomSharpness(degrees=degrees), | |||||
| F.ToTensor()]) | |||||
| ds_random_sharpness = ds.map(input_columns="image", | |||||
| operations=transforms_random_sharpness()) | |||||
| ds_random_sharpness = ds_random_sharpness.batch(512) | |||||
| for idx, (image,label) in enumerate(ds_random_sharpness): | |||||
| if idx == 0: | |||||
| images_random_sharpness = np.transpose(image, (0, 2,3,1)) | |||||
| else: | |||||
| images_random_sharpness = np.append(images_random_sharpness, | |||||
| np.transpose(image, (0, 2,3,1)), | |||||
| axis=0) | |||||
| num_samples = images_original.shape[0] | |||||
| mse = np.zeros(num_samples) | |||||
| for i in range(num_samples): | |||||
| mse[i] = np.mean((images_random_sharpness[i]-images_original[i])**2) | |||||
| logger.info("MSE= {}".format(str(np.mean(mse)))) | |||||
| if plot: | |||||
| visualize(images_original, images_random_sharpness) | |||||
| if __name__ == "__main__": | |||||
| test_random_sharpness() | |||||
| test_random_sharpness(plot=True) | |||||
| test_random_sharpness(degrees=(0.5,1.5), plot=True) | |||||
| @@ -0,0 +1,107 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| import numpy as np | |||||
| import matplotlib.pyplot as plt | |||||
| from mindspore import log as logger | |||||
| import mindspore.dataset.engine as de | |||||
| import mindspore.dataset.transforms.vision.py_transforms as F | |||||
| DATA_DIR = "../data/dataset/testImageNetData/train/" | |||||
| def visualize(image_original, image_ua): | |||||
| """ | |||||
| visualizes the image using DE op and Numpy op | |||||
| """ | |||||
| num = len(image_ua) | |||||
| for i in range(num): | |||||
| plt.subplot(2, num, i + 1) | |||||
| plt.imshow(image_original[i]) | |||||
| plt.title("Original image") | |||||
| plt.subplot(2, num, i + num + 1) | |||||
| plt.imshow(image_ua[i]) | |||||
| plt.title("DE UniformAugment image") | |||||
| plt.show() | |||||
| def test_uniform_augment(plot=False, num_ops=2): | |||||
| """ | |||||
| Test UniformAugment | |||||
| """ | |||||
| logger.info("Test UniformAugment") | |||||
| # Original Images | |||||
| ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||||
| transforms_original = F.ComposeOp([F.Decode(), | |||||
| F.Resize((224,224)), | |||||
| F.ToTensor()]) | |||||
| ds_original = ds.map(input_columns="image", | |||||
| operations=transforms_original()) | |||||
| ds_original = ds_original.batch(512) | |||||
| for idx, (image,label) in enumerate(ds_original): | |||||
| if idx == 0: | |||||
| images_original = np.transpose(image, (0, 2,3,1)) | |||||
| else: | |||||
| images_original = np.append(images_original, | |||||
| np.transpose(image, (0, 2,3,1)), | |||||
| axis=0) | |||||
| # UniformAugment Images | |||||
| ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||||
| transform_list = [F.RandomRotation(45), | |||||
| F.RandomColor(), | |||||
| F.RandomSharpness(), | |||||
| F.Invert(), | |||||
| F.AutoContrast(), | |||||
| F.Equalize()] | |||||
| transforms_ua = F.ComposeOp([F.Decode(), | |||||
| F.Resize((224,224)), | |||||
| F.UniformAugment(transforms=transform_list, num_ops=num_ops), | |||||
| F.ToTensor()]) | |||||
| ds_ua = ds.map(input_columns="image", | |||||
| operations=transforms_ua()) | |||||
| ds_ua = ds_ua.batch(512) | |||||
| for idx, (image,label) in enumerate(ds_ua): | |||||
| if idx == 0: | |||||
| images_ua = np.transpose(image, (0, 2,3,1)) | |||||
| else: | |||||
| images_ua = np.append(images_ua, | |||||
| np.transpose(image, (0, 2,3,1)), | |||||
| axis=0) | |||||
| num_samples = images_original.shape[0] | |||||
| mse = np.zeros(num_samples) | |||||
| for i in range(num_samples): | |||||
| mse[i] = np.mean((images_ua[i]-images_original[i])**2) | |||||
| logger.info("MSE= {}".format(str(np.mean(mse)))) | |||||
| if plot: | |||||
| visualize(images_original, images_ua) | |||||
| if __name__ == "__main__": | |||||
| test_uniform_augment(num_ops=1) | |||||