You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_autocontrast.py 3.5 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import matplotlib.pyplot as plt
  16. import numpy as np
  17. import mindspore.dataset.engine as de
  18. import mindspore.dataset.transforms.vision.py_transforms as F
  19. from mindspore import log as logger
# Root of the image-folder dataset read by test_auto_contrast. The path is
# relative, so it is resolved against the current working directory.
DATA_DIR = "../data/dataset/testImageNetData/train/"
  21. def visualize(image_original, image_auto_contrast):
  22. """
  23. visualizes the image using DE op and Numpy op
  24. """
  25. num = len(image_auto_contrast)
  26. for i in range(num):
  27. plt.subplot(2, num, i + 1)
  28. plt.imshow(image_original[i])
  29. plt.title("Original image")
  30. plt.subplot(2, num, i + num + 1)
  31. plt.imshow(image_auto_contrast[i])
  32. plt.title("DE AutoContrast image")
  33. plt.show()
  34. def test_auto_contrast(plot=False):
  35. """
  36. Test AutoContrast
  37. """
  38. logger.info("Test AutoContrast")
  39. # Original Images
  40. ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
  41. transforms_original = F.ComposeOp([F.Decode(),
  42. F.Resize((224, 224)),
  43. F.ToTensor()])
  44. ds_original = ds.map(input_columns="image",
  45. operations=transforms_original())
  46. ds_original = ds_original.batch(512)
  47. for idx, (image, _) in enumerate(ds_original):
  48. if idx == 0:
  49. images_original = np.transpose(image, (0, 2, 3, 1))
  50. else:
  51. images_original = np.append(images_original,
  52. np.transpose(image, (0, 2, 3, 1)),
  53. axis=0)
  54. # AutoContrast Images
  55. ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
  56. transforms_auto_contrast = F.ComposeOp([F.Decode(),
  57. F.Resize((224, 224)),
  58. F.AutoContrast(),
  59. F.ToTensor()])
  60. ds_auto_contrast = ds.map(input_columns="image",
  61. operations=transforms_auto_contrast())
  62. ds_auto_contrast = ds_auto_contrast.batch(512)
  63. for idx, (image, _) in enumerate(ds_auto_contrast):
  64. if idx == 0:
  65. images_auto_contrast = np.transpose(image, (0, 2, 3, 1))
  66. else:
  67. images_auto_contrast = np.append(images_auto_contrast,
  68. np.transpose(image, (0, 2, 3, 1)),
  69. axis=0)
  70. num_samples = images_original.shape[0]
  71. mse = np.zeros(num_samples)
  72. for i in range(num_samples):
  73. mse[i] = np.mean((images_auto_contrast[i] - images_original[i]) ** 2)
  74. logger.info("MSE= {}".format(str(np.mean(mse))))
  75. if plot:
  76. visualize(images_original, images_auto_contrast)
  77. if __name__ == "__main__":
  78. test_auto_contrast(plot=True)