You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cache_map.py 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing cache operator with mappable datasets
  17. """
  18. import os
  19. import pytest
  20. import mindspore.dataset as ds
  21. import mindspore.dataset.vision.c_transforms as c_vision
  22. from mindspore import log as logger
  23. from util import save_and_check_md5
  24. DATA_DIR = "../data/dataset/testImageNetData/train/"
  25. GENERATE_GOLDEN = False
  26. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  27. def test_cache_map_basic1():
  28. """
  29. Test mappable leaf with cache op right over the leaf
  30. Repeat
  31. |
  32. Map(decode)
  33. |
  34. Cache
  35. |
  36. ImageFolder
  37. """
  38. logger.info("Test cache map basic 1")
  39. some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
  40. # This DATA_DIR only has 2 images in it
  41. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  42. decode_op = c_vision.Decode()
  43. ds1 = ds1.map(operations=decode_op, input_columns=["image"])
  44. ds1 = ds1.repeat(4)
  45. filename = "cache_map_01_result.npz"
  46. save_and_check_md5(ds1, filename, generate_golden=GENERATE_GOLDEN)
  47. logger.info("test_cache_map_basic1 Ended.\n")
  48. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  49. def test_cache_map_basic2():
  50. """
  51. Test mappable leaf with the cache op later in the tree above the map(decode)
  52. Repeat
  53. |
  54. Cache
  55. |
  56. Map(decode)
  57. |
  58. ImageFolder
  59. """
  60. logger.info("Test cache map basic 2")
  61. some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
  62. # This DATA_DIR only has 2 images in it
  63. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  64. decode_op = c_vision.Decode()
  65. ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
  66. ds1 = ds1.repeat(4)
  67. filename = "cache_map_02_result.npz"
  68. save_and_check_md5(ds1, filename, generate_golden=GENERATE_GOLDEN)
  69. logger.info("test_cache_map_basic2 Ended.\n")
  70. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  71. def test_cache_map_basic3():
  72. """
  73. Test a repeat under mappable cache
  74. Cache
  75. |
  76. Map(decode)
  77. |
  78. Repeat
  79. |
  80. ImageFolder
  81. """
  82. logger.info("Test cache basic 3")
  83. some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
  84. # This DATA_DIR only has 2 images in it
  85. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  86. decode_op = c_vision.Decode()
  87. ds1 = ds1.repeat(4)
  88. ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
  89. logger.info("ds1.dataset_size is ", ds1.get_dataset_size())
  90. num_iter = 0
  91. for _ in ds1.create_dict_iterator(num_epochs=1):
  92. logger.info("get data from dataset")
  93. num_iter += 1
  94. logger.info("Number of data in ds1: {} ".format(num_iter))
  95. assert num_iter == 8
  96. logger.info('test_cache_basic3 Ended.\n')
  97. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  98. def test_cache_map_basic4():
  99. """
  100. Test different rows result in core dump
  101. """
  102. logger.info("Test cache basic 4")
  103. some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
  104. # This DATA_DIR only has 2 images in it
  105. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  106. decode_op = c_vision.Decode()
  107. ds1 = ds1.repeat(4)
  108. ds1 = ds1.map(operations=decode_op, input_columns=["image"])
  109. logger.info("ds1.dataset_size is ", ds1.get_dataset_size())
  110. shape = ds1.output_shapes()
  111. logger.info(shape)
  112. num_iter = 0
  113. for _ in ds1.create_dict_iterator(num_epochs=1):
  114. logger.info("get data from dataset")
  115. num_iter += 1
  116. logger.info("Number of data in ds1: {} ".format(num_iter))
  117. assert num_iter == 8
  118. logger.info('test_cache_basic3 Ended.\n')
  119. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  120. def test_cache_map_failure1():
  121. """
  122. Test nested cache (failure)
  123. Repeat
  124. |
  125. Cache
  126. |
  127. Map(decode)
  128. |
  129. Cache
  130. |
  131. ImageFolder
  132. """
  133. logger.info("Test cache failure 1")
  134. some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
  135. # This DATA_DIR only has 2 images in it
  136. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  137. decode_op = c_vision.Decode()
  138. ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
  139. ds1 = ds1.repeat(4)
  140. try:
  141. num_iter = 0
  142. for _ in ds1.create_dict_iterator(num_epochs=1):
  143. num_iter += 1
  144. except RuntimeError as e:
  145. logger.info("Got an exception in DE: {}".format(str(e)))
  146. assert "Nested cache operations is not supported!" in str(e)
  147. assert num_iter == 0
  148. logger.info('test_cache_failure1 Ended.\n')
  149. if __name__ == '__main__':
  150. test_cache_map_basic1()
  151. logger.info("test_cache_map_basic1 success.")
  152. test_cache_map_basic2()
  153. logger.info("test_cache_map_basic2 success.")
  154. test_cache_map_basic3()
  155. logger.info("test_cache_map_basic3 success.")
  156. test_cache_map_basic4()
  157. logger.info("test_cache_map_basic3 success.")
  158. test_cache_map_failure1()
  159. logger.info("test_cache_map_failure1 success.")