You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_datasets_celeba.py 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
# Copyright 2020 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import mindspore.dataset as ds
  15. import mindspore.dataset.vision.c_transforms as vision
  16. from mindspore import log as logger
  17. from mindspore.dataset.vision import Inter
  18. DATA_DIR = "../data/dataset/testCelebAData/"
  19. def test_celeba_dataset_label():
  20. data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True)
  21. expect_labels = [
  22. [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1,
  23. 0, 0, 1],
  24. [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
  25. 0, 0, 1],
  26. [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
  27. 0, 0, 1],
  28. [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1,
  29. 0, 0, 1]]
  30. count = 0
  31. for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  32. logger.info("----------image--------")
  33. logger.info(item["image"])
  34. logger.info("----------attr--------")
  35. logger.info(item["attr"])
  36. for index in range(len(expect_labels[count])):
  37. assert item["attr"][index] == expect_labels[count][index]
  38. count = count + 1
  39. assert count == 4
  40. def test_celeba_dataset_op():
  41. data = ds.CelebADataset(DATA_DIR, decode=True, num_shards=1, shard_id=0)
  42. crop_size = (80, 80)
  43. resize_size = (24, 24)
  44. # define map operations
  45. data = data.repeat(2)
  46. center_crop = vision.CenterCrop(crop_size)
  47. resize_op = vision.Resize(resize_size, Inter.LINEAR) # Bilinear mode
  48. data = data.map(operations=center_crop, input_columns=["image"])
  49. data = data.map(operations=resize_op, input_columns=["image"])
  50. count = 0
  51. for item in data.create_dict_iterator(num_epochs=1):
  52. logger.info("----------image--------")
  53. logger.info(item["image"])
  54. count = count + 1
  55. assert count == 8
  56. def test_celeba_dataset_ext():
  57. ext = [".JPEG"]
  58. data = ds.CelebADataset(DATA_DIR, decode=True, extensions=ext)
  59. expect_labels = [
  60. [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1,
  61. 0, 1, 0, 1, 0, 0, 1],
  62. [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1,
  63. 0, 1, 0, 1, 0, 0, 1]]
  64. count = 0
  65. for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  66. logger.info("----------image--------")
  67. logger.info(item["image"])
  68. logger.info("----------attr--------")
  69. logger.info(item["attr"])
  70. for index in range(len(expect_labels[count])):
  71. assert item["attr"][index] == expect_labels[count][index]
  72. count = count + 1
  73. assert count == 2
  74. def test_celeba_dataset_distribute():
  75. data = ds.CelebADataset(DATA_DIR, decode=True, num_shards=2, shard_id=0)
  76. count = 0
  77. for item in data.create_dict_iterator(num_epochs=1):
  78. logger.info("----------image--------")
  79. logger.info(item["image"])
  80. logger.info("----------attr--------")
  81. logger.info(item["attr"])
  82. count = count + 1
  83. assert count == 2
  84. def test_celeba_get_dataset_size():
  85. data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True)
  86. size = data.get_dataset_size()
  87. assert size == 4
  88. data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, usage="train")
  89. size = data.get_dataset_size()
  90. assert size == 2
  91. data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, usage="valid")
  92. size = data.get_dataset_size()
  93. assert size == 1
  94. data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, usage="test")
  95. size = data.get_dataset_size()
  96. assert size == 1
  97. if __name__ == '__main__':
  98. test_celeba_dataset_label()
  99. test_celeba_dataset_op()
  100. test_celeba_dataset_ext()
  101. test_celeba_dataset_distribute()
  102. test_celeba_get_dataset_size()