You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_invert.py 3.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. import matplotlib.pyplot as plt
  17. from mindspore import log as logger
  18. import mindspore.dataset.engine as de
  19. import mindspore.dataset.transforms.vision.py_transforms as F
  20. DATA_DIR = "../data/dataset/testImageNetData/train/"
  21. def visualize(image_original, image_invert):
  22. """
  23. visualizes the image using DE op and Numpy op
  24. """
  25. num = len(image_invert)
  26. for i in range(num):
  27. plt.subplot(2, num, i + 1)
  28. plt.imshow(image_original[i])
  29. plt.title("Original image")
  30. plt.subplot(2, num, i + num + 1)
  31. plt.imshow(image_invert[i])
  32. plt.title("DE Color Inverted image")
  33. plt.show()
  34. def test_invert(plot=False):
  35. """
  36. Test Invert
  37. """
  38. logger.info("Test Invert")
  39. # Original Images
  40. ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
  41. transforms_original = F.ComposeOp([F.Decode(),
  42. F.Resize((224,224)),
  43. F.ToTensor()])
  44. ds_original = ds.map(input_columns="image",
  45. operations=transforms_original())
  46. ds_original = ds_original.batch(512)
  47. for idx, (image,label) in enumerate(ds_original):
  48. if idx == 0:
  49. images_original = np.transpose(image, (0, 2,3,1))
  50. else:
  51. images_original = np.append(images_original,
  52. np.transpose(image, (0, 2,3,1)),
  53. axis=0)
  54. # Color Inverted Images
  55. ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
  56. transforms_invert = F.ComposeOp([F.Decode(),
  57. F.Resize((224,224)),
  58. F.Invert(),
  59. F.ToTensor()])
  60. ds_invert = ds.map(input_columns="image",
  61. operations=transforms_invert())
  62. ds_invert = ds_invert.batch(512)
  63. for idx, (image,label) in enumerate(ds_invert):
  64. if idx == 0:
  65. images_invert = np.transpose(image, (0, 2,3,1))
  66. else:
  67. images_invert = np.append(images_invert,
  68. np.transpose(image, (0, 2,3,1)),
  69. axis=0)
  70. num_samples = images_original.shape[0]
  71. mse = np.zeros(num_samples)
  72. for i in range(num_samples):
  73. mse[i] = np.mean((images_invert[i]-images_original[i])**2)
  74. logger.info("MSE= {}".format(str(np.mean(mse))))
  75. if plot:
  76. visualize(images_original, images_invert)
  77. if __name__ == "__main__":
  78. test_invert(plot=True)