You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_uniform_augment.py 3.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. import matplotlib.pyplot as plt
  17. from mindspore import log as logger
  18. import mindspore.dataset.engine as de
  19. import mindspore.dataset.transforms.vision.py_transforms as F
# Root directory of the image fixture read by ImageFolderDatasetV2 in the test
# below. It is a relative path, so it is resolved against the current working
# directory the tests are launched from.
DATA_DIR = "../data/dataset/testImageNetData/train/"
  21. def visualize(image_original, image_ua):
  22. """
  23. visualizes the image using DE op and Numpy op
  24. """
  25. num = len(image_ua)
  26. for i in range(num):
  27. plt.subplot(2, num, i + 1)
  28. plt.imshow(image_original[i])
  29. plt.title("Original image")
  30. plt.subplot(2, num, i + num + 1)
  31. plt.imshow(image_ua[i])
  32. plt.title("DE UniformAugment image")
  33. plt.show()
  34. def test_uniform_augment(plot=False, num_ops=2):
  35. """
  36. Test UniformAugment
  37. """
  38. logger.info("Test UniformAugment")
  39. # Original Images
  40. ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
  41. transforms_original = F.ComposeOp([F.Decode(),
  42. F.Resize((224,224)),
  43. F.ToTensor()])
  44. ds_original = ds.map(input_columns="image",
  45. operations=transforms_original())
  46. ds_original = ds_original.batch(512)
  47. for idx, (image,label) in enumerate(ds_original):
  48. if idx == 0:
  49. images_original = np.transpose(image, (0, 2,3,1))
  50. else:
  51. images_original = np.append(images_original,
  52. np.transpose(image, (0, 2,3,1)),
  53. axis=0)
  54. # UniformAugment Images
  55. ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
  56. transform_list = [F.RandomRotation(45),
  57. F.RandomColor(),
  58. F.RandomSharpness(),
  59. F.Invert(),
  60. F.AutoContrast(),
  61. F.Equalize()]
  62. transforms_ua = F.ComposeOp([F.Decode(),
  63. F.Resize((224,224)),
  64. F.UniformAugment(transforms=transform_list, num_ops=num_ops),
  65. F.ToTensor()])
  66. ds_ua = ds.map(input_columns="image",
  67. operations=transforms_ua())
  68. ds_ua = ds_ua.batch(512)
  69. for idx, (image,label) in enumerate(ds_ua):
  70. if idx == 0:
  71. images_ua = np.transpose(image, (0, 2,3,1))
  72. else:
  73. images_ua = np.append(images_ua,
  74. np.transpose(image, (0, 2,3,1)),
  75. axis=0)
  76. num_samples = images_original.shape[0]
  77. mse = np.zeros(num_samples)
  78. for i in range(num_samples):
  79. mse[i] = np.mean((images_ua[i]-images_original[i])**2)
  80. logger.info("MSE= {}".format(str(np.mean(mse))))
  81. if plot:
  82. visualize(images_original, images_ua)
  83. if __name__ == "__main__":
  84. test_uniform_augment(num_ops=1)