You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_datasets_celeba.py 7.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import mindspore.dataset as ds
  15. import mindspore.dataset.vision.c_transforms as vision
  16. from mindspore import log as logger
  17. from mindspore.dataset.vision import Inter
# Relative path handed to ds.CelebADataset as the dataset directory by every test below.
DATA_DIR = "../data/dataset/testCelebAData/"
  19. def test_celeba_dataset_label():
  20. """
  21. Test CelebA dataset with labels
  22. """
  23. logger.info("Test CelebA labels")
  24. data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True)
  25. expect_labels = [
  26. [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1,
  27. 0, 0, 1],
  28. [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
  29. 0, 0, 1],
  30. [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
  31. 0, 0, 1],
  32. [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1,
  33. 0, 0, 1]]
  34. count = 0
  35. for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  36. logger.info("----------image--------")
  37. logger.info(item["image"])
  38. logger.info("----------attr--------")
  39. logger.info(item["attr"])
  40. for index in range(len(expect_labels[count])):
  41. assert item["attr"][index] == expect_labels[count][index]
  42. count = count + 1
  43. assert count == 4
  44. def test_celeba_dataset_op():
  45. """
  46. Test CelebA dataset with decode
  47. """
  48. logger.info("Test CelebA with decode")
  49. data = ds.CelebADataset(DATA_DIR, decode=True, num_shards=1, shard_id=0)
  50. crop_size = (80, 80)
  51. resize_size = (24, 24)
  52. # define map operations
  53. data = data.repeat(2)
  54. center_crop = vision.CenterCrop(crop_size)
  55. resize_op = vision.Resize(resize_size, Inter.LINEAR) # Bilinear mode
  56. data = data.map(operations=center_crop, input_columns=["image"])
  57. data = data.map(operations=resize_op, input_columns=["image"])
  58. count = 0
  59. for item in data.create_dict_iterator(num_epochs=1):
  60. logger.info("----------image--------")
  61. logger.info(item["image"])
  62. count = count + 1
  63. assert count == 8
  64. def test_celeba_dataset_ext():
  65. """
  66. Test CelebA dataset with extension
  67. """
  68. logger.info("Test CelebA extension option")
  69. ext = [".JPEG"]
  70. data = ds.CelebADataset(DATA_DIR, decode=True, extensions=ext)
  71. expect_labels = [
  72. [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1,
  73. 0, 1, 0, 1, 0, 0, 1],
  74. [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1,
  75. 0, 1, 0, 1, 0, 0, 1]]
  76. count = 0
  77. for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  78. logger.info("----------image--------")
  79. logger.info(item["image"])
  80. logger.info("----------attr--------")
  81. logger.info(item["attr"])
  82. for index in range(len(expect_labels[count])):
  83. assert item["attr"][index] == expect_labels[count][index]
  84. count = count + 1
  85. assert count == 2
  86. def test_celeba_dataset_distribute():
  87. """
  88. Test CelebA dataset with distributed options
  89. """
  90. logger.info("Test CelebA with sharding")
  91. data = ds.CelebADataset(DATA_DIR, decode=True, num_shards=2, shard_id=0)
  92. count = 0
  93. for item in data.create_dict_iterator(num_epochs=1):
  94. logger.info("----------image--------")
  95. logger.info(item["image"])
  96. logger.info("----------attr--------")
  97. logger.info(item["attr"])
  98. count = count + 1
  99. assert count == 2
  100. def test_celeba_get_dataset_size():
  101. """
  102. Test CelebA dataset get dataset size
  103. """
  104. logger.info("Test CelebA get dataset size")
  105. data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True)
  106. size = data.get_dataset_size()
  107. assert size == 4
  108. data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, usage="train")
  109. size = data.get_dataset_size()
  110. assert size == 2
  111. data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, usage="valid")
  112. size = data.get_dataset_size()
  113. assert size == 1
  114. data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, usage="test")
  115. size = data.get_dataset_size()
  116. assert size == 1
  117. def test_celeba_dataset_exception_file_path():
  118. """
  119. Test CelebA dataset with bad file path
  120. """
  121. logger.info("Test CelebA with bad file path")
  122. def exception_func(item):
  123. raise Exception("Error occur!")
  124. try:
  125. data = ds.CelebADataset(DATA_DIR, shuffle=False)
  126. data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
  127. for _ in data.create_dict_iterator():
  128. pass
  129. assert False
  130. except RuntimeError as e:
  131. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  132. try:
  133. data = ds.CelebADataset(DATA_DIR, shuffle=False)
  134. data = data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
  135. data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
  136. for _ in data.create_dict_iterator():
  137. pass
  138. assert False
  139. except RuntimeError as e:
  140. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  141. try:
  142. data = ds.CelebADataset(DATA_DIR, shuffle=False)
  143. data = data.map(operations=exception_func, input_columns=["attr"], num_parallel_workers=1)
  144. for _ in data.create_dict_iterator():
  145. pass
  146. assert False
  147. except RuntimeError as e:
  148. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  149. def test_celeba_sampler_exception():
  150. """
  151. Test CelebA with bad sampler input
  152. """
  153. logger.info("Test CelebA with bad sampler input")
  154. try:
  155. data = ds.CelebADataset(DATA_DIR, sampler="")
  156. for _ in data.create_dict_iterator():
  157. pass
  158. assert False
  159. except TypeError as e:
  160. assert "Unsupported sampler object of type (<class 'str'>)" in str(e)
  161. if __name__ == '__main__':
  162. test_celeba_dataset_label()
  163. test_celeba_dataset_op()
  164. test_celeba_dataset_ext()
  165. test_celeba_dataset_distribute()
  166. test_celeba_get_dataset_size()
  167. test_celeba_dataset_exception_file_path()
  168. test_celeba_sampler_exception()