You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_random_auto_contrast.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing RandomAutoContrast op in DE
  17. """
  18. import numpy as np
  19. import mindspore.dataset as ds
  20. import mindspore.dataset.vision.c_transforms as c_vision
  21. from mindspore import log as logger
  22. from util import visualize_list, visualize_image, diff_mse
  23. image_file = "../data/dataset/testImageNetData/train/class1/1_1.jpg"
  24. data_dir = "../data/dataset/testImageNetData/train/"
  25. def test_random_auto_contrast_pipeline(plot=False):
  26. """
  27. Test RandomAutoContrast pipeline
  28. """
  29. logger.info("Test RandomAutoContrast pipeline")
  30. # Original Images
  31. data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
  32. transforms_original = [c_vision.Decode(), c_vision.Resize(size=[224, 224])]
  33. ds_original = data_set.map(operations=transforms_original, input_columns="image")
  34. ds_original = ds_original.batch(512)
  35. for idx, (image, _) in enumerate(ds_original):
  36. if idx == 0:
  37. images_original = image.asnumpy()
  38. else:
  39. images_original = np.append(images_original,
  40. image.asnumpy(),
  41. axis=0)
  42. # Randomly Automatically Contrasted Images
  43. data_set1 = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
  44. transform_random_auto_contrast = [c_vision.Decode(),
  45. c_vision.Resize(size=[224, 224]),
  46. c_vision.RandomAutoContrast(prob=0.6)]
  47. ds_random_auto_contrast = data_set1.map(operations=transform_random_auto_contrast, input_columns="image")
  48. ds_random_auto_contrast = ds_random_auto_contrast.batch(512)
  49. for idx, (image, _) in enumerate(ds_random_auto_contrast):
  50. if idx == 0:
  51. images_random_auto_contrast = image.asnumpy()
  52. else:
  53. images_random_auto_contrast = np.append(images_random_auto_contrast,
  54. image.asnumpy(),
  55. axis=0)
  56. if plot:
  57. visualize_list(images_original, images_random_auto_contrast)
  58. num_samples = images_original.shape[0]
  59. mse = np.zeros(num_samples)
  60. for i in range(num_samples):
  61. mse[i] = diff_mse(images_random_auto_contrast[i], images_original[i])
  62. logger.info("MSE= {}".format(str(np.mean(mse))))
  63. def test_random_auto_contrast_eager():
  64. """
  65. Test RandomAutoContrast eager.
  66. """
  67. img = np.fromfile(image_file, dtype=np.uint8)
  68. logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
  69. img = c_vision.Decode()(img)
  70. img_auto_contrast = c_vision.AutoContrast(1.0, None)(img)
  71. img_random_auto_contrast = c_vision.RandomAutoContrast(1.0, None, 1.0)(img)
  72. logger.info("Image.type: {}, Image.shape: {}".format(type(img_auto_contrast), img_random_auto_contrast.shape))
  73. assert img_auto_contrast.all() == img_random_auto_contrast.all()
  74. def test_random_auto_contrast_comp(plot=False):
  75. """
  76. Test RandomAutoContrast op compared with AutoContrast op.
  77. """
  78. random_auto_contrast_op = c_vision.RandomAutoContrast(prob=1.0)
  79. auto_contrast_op = c_vision.AutoContrast()
  80. dataset1 = ds.ImageFolderDataset(data_dir, 1, shuffle=False, decode=True)
  81. for item in dataset1.create_dict_iterator(num_epochs=1, output_numpy=True):
  82. image = item['image']
  83. dataset1.map(operations=random_auto_contrast_op, input_columns=['image'])
  84. dataset2 = ds.ImageFolderDataset(data_dir, 1, shuffle=False, decode=True)
  85. dataset2.map(operations=auto_contrast_op, input_columns=['image'])
  86. for item1, item2 in zip(dataset1.create_dict_iterator(num_epochs=1, output_numpy=True),
  87. dataset2.create_dict_iterator(num_epochs=1, output_numpy=True)):
  88. image_random_auto_contrast = item1['image']
  89. image_auto_contrast = item2['image']
  90. mse = diff_mse(image_auto_contrast, image_random_auto_contrast)
  91. assert mse == 0
  92. logger.info("mse: {}".format(mse))
  93. if plot:
  94. visualize_image(image, image_random_auto_contrast, mse, image_auto_contrast)
  95. def test_random_auto_contrast_invalid_prob():
  96. """
  97. Test RandomAutoContrast Op with invalid prob parameter.
  98. """
  99. logger.info("test_random_auto_contrast_invalid_prob")
  100. dataset = ds.ImageFolderDataset(data_dir, 1, shuffle=False, decode=True)
  101. try:
  102. random_auto_contrast_op = c_vision.RandomAutoContrast(prob=1.5)
  103. dataset = dataset.map(operations=random_auto_contrast_op, input_columns=['image'])
  104. except ValueError as e:
  105. logger.info("Got an exception in DE: {}".format(str(e)))
  106. assert "Input prob is not within the required interval of [0.0, 1.0]." in str(e)
  107. def test_random_auto_contrast_invalid_ignore():
  108. """
  109. Test RandomAutoContrast Op with invalid ignore parameter.
  110. """
  111. logger.info("test_random_auto_contrast_invalid_ignore")
  112. try:
  113. data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
  114. data_set = data_set.map(operations=[c_vision.Decode(),
  115. c_vision.Resize((224, 224)),
  116. lambda img: np.array(img[:, :, 0])], input_columns=["image"])
  117. # invalid ignore
  118. data_set = data_set.map(operations=c_vision.RandomAutoContrast(ignore=255.5), input_columns="image")
  119. except TypeError as error:
  120. logger.info("Got an exception in DE: {}".format(str(error)))
  121. assert "Argument ignore with value 255.5 is not of type" in str(error)
  122. try:
  123. data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
  124. data_set = data_set.map(operations=[c_vision.Decode(), c_vision.Resize((224, 224)),
  125. lambda img: np.array(img[:, :, 0])], input_columns=["image"])
  126. # invalid ignore
  127. data_set = data_set.map(operations=c_vision.RandomAutoContrast(ignore=(10, 100)), input_columns="image")
  128. except TypeError as error:
  129. logger.info("Got an exception in DE: {}".format(str(error)))
  130. assert "Argument ignore with value (10,100) is not of type" in str(error)
  131. def test_random_auto_contrast_invalid_cutoff():
  132. """
  133. Test RandomAutoContrast Op with invalid cutoff parameter.
  134. """
  135. logger.info("test_random_auto_contrast_invalid_cutoff")
  136. try:
  137. data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
  138. data_set = data_set.map(operations=[c_vision.Decode(),
  139. c_vision.Resize((224, 224)),
  140. lambda img: np.array(img[:, :, 0])], input_columns=["image"])
  141. # invalid cutoff
  142. data_set = data_set.map(operations=c_vision.RandomAutoContrast(cutoff=-10.0), input_columns="image")
  143. except ValueError as error:
  144. logger.info("Got an exception in DE: {}".format(str(error)))
  145. assert "Input cutoff is not within the required interval of [0, 50)." in str(error)
  146. try:
  147. data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
  148. data_set = data_set.map(operations=[c_vision.Decode(),
  149. c_vision.Resize((224, 224)),
  150. lambda img: np.array(img[:, :, 0])], input_columns=["image"])
  151. # invalid cutoff
  152. data_set = data_set.map(operations=c_vision.RandomAutoContrast(cutoff=120.0), input_columns="image")
  153. except ValueError as error:
  154. logger.info("Got an exception in DE: {}".format(str(error)))
  155. assert "Input cutoff is not within the required interval of [0, 50)." in str(error)
  156. def test_random_auto_contrast_one_channel():
  157. """
  158. Feature: RandomAutoContrast
  159. Description: test with one channel images
  160. Expectation: raise errors as expected
  161. """
  162. logger.info("test_random_auto_contrast_one_channel")
  163. c_op = c_vision.RandomAutoContrast()
  164. try:
  165. data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
  166. data_set = data_set.map(operations=[c_vision.Decode(), c_vision.Resize((224, 224)),
  167. lambda img: np.array(img[:, :, 0])], input_columns=["image"])
  168. data_set = data_set.map(operations=c_op, input_columns="image")
  169. except RuntimeError as e:
  170. logger.info("Got an exception in DE: {}".format(str(e)))
  171. assert "image shape is incorrect, expected num of channels is 3." in str(e)
  172. def test_random_auto_contrast_four_dim():
  173. """
  174. Feature: RandomAutoContrast
  175. Description: test with four dimension images
  176. Expectation: raise errors as expected
  177. """
  178. logger.info("test_random_auto_contrast_four_dim")
  179. c_op = c_vision.RandomAutoContrast()
  180. try:
  181. data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
  182. data_set = data_set.map(operations=[c_vision.Decode(), c_vision.Resize((224, 224)),
  183. lambda img: np.array(img[2, 200, 10, 32])], input_columns=["image"])
  184. data_set = data_set.map(operations=c_op, input_columns="image")
  185. except ValueError as e:
  186. logger.info("Got an exception in DE: {}".format(str(e)))
  187. assert "image shape is not <H,W,C>" in str(e)
  188. def test_random_auto_contrast_invalid_input():
  189. """
  190. Feature: RandomAutoContrast
  191. Description: test with images in uint32 type
  192. Expectation: raise errors as expected
  193. """
  194. logger.info("test_random_invert_invalid_input")
  195. c_op = c_vision.RandomAutoContrast()
  196. try:
  197. data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
  198. data_set = data_set.map(operations=[c_vision.Decode(), c_vision.Resize((224, 224)),
  199. lambda img: np.array(img[2, 32, 3], dtype=uint32)], input_columns=["image"])
  200. data_set = data_set.map(operations=c_op, input_columns="image")
  201. except TypeError as e:
  202. logger.info("Got an exception in DE: {}".format(str(e)))
  203. assert "Cannot convert from OpenCV type, unknown CV type" in str(e)
  204. if __name__ == "__main__":
  205. test_random_auto_contrast_pipeline(plot=True)
  206. test_random_auto_contrast_eager()
  207. test_random_auto_contrast_comp(plot=True)
  208. test_random_auto_contrast_invalid_prob()
  209. test_random_auto_contrast_invalid_ignore()
  210. test_random_auto_contrast_invalid_cutoff()
  211. test_random_auto_contrast_one_channel()
  212. test_random_auto_contrast_four_dim()
  213. test_random_auto_contrast_invalid_input()