You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_sampler_chain.py 9.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import pytest
  16. import mindspore.dataset as ds
  17. import mindspore.dataset.transforms.c_transforms as c_transforms
  18. from mindspore import log as logger
  19. from util import save_and_check_md5
# When True, save_and_check_md5 regenerates the golden baseline files
# instead of comparing against them (passed as generate_golden= below).
GENERATE_GOLDEN = False
  21. def test_numpyslices_sampler_no_chain():
  22. """
  23. Test NumpySlicesDataset with sampler, no chain
  24. """
  25. logger.info("test_numpyslices_sampler_no_chain")
  26. # Create NumpySlicesDataset with sampler, no chain
  27. np_data = [1, 2, 3, 4]
  28. sampler = ds.SequentialSampler(start_index=1, num_samples=2)
  29. data1 = ds.NumpySlicesDataset(np_data, sampler=sampler)
  30. # Verify dataset size
  31. data1_size = data1.get_dataset_size()
  32. logger.info("dataset size is: {}".format(data1_size))
  33. assert data1_size == 2
  34. # Verify number of rows
  35. assert sum([1 for _ in data1]) == 2
  36. # Verify dataset contents
  37. res = []
  38. for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
  39. logger.info("item: {}".format(item))
  40. res.append(item)
  41. logger.info("dataset: {}".format(res))
  42. def test_numpyslices_sampler_chain():
  43. """
  44. Test NumpySlicesDataset sampler chain
  45. """
  46. logger.info("test_numpyslices_sampler_chain")
  47. # Create NumpySlicesDataset with sampler chain
  48. # Use 1 statement to add child sampler
  49. np_data = [1, 2, 3, 4]
  50. sampler = ds.SequentialSampler(start_index=1, num_samples=2)
  51. sampler = sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2))
  52. data1 = ds.NumpySlicesDataset(np_data, sampler=sampler)
  53. # Verify dataset size
  54. data1_size = data1.get_dataset_size()
  55. logger.info("dataset size is: {}".format(data1_size))
  56. assert data1_size == 4
  57. # Verify number of rows
  58. assert sum([1 for _ in data1]) == 4
  59. # Verify dataset contents
  60. res = []
  61. for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
  62. logger.info("item: {}".format(item))
  63. res.append(item)
  64. logger.info("dataset: {}".format(res))
  65. def test_numpyslices_sampler_chain2():
  66. """
  67. Test NumpySlicesDataset sampler chain
  68. """
  69. logger.info("test_numpyslices_sampler_chain2")
  70. # Create NumpySlicesDataset with sampler chain
  71. # Use 2 statements to add child sampler
  72. np_data = [1, 2, 3, 4]
  73. sampler = ds.SequentialSampler(start_index=1, num_samples=2)
  74. child_sampler = ds.SequentialSampler(start_index=1, num_samples=2)
  75. sampler.add_child(child_sampler)
  76. data1 = ds.NumpySlicesDataset(np_data, sampler=sampler)
  77. # Verify dataset size
  78. data1_size = data1.get_dataset_size()
  79. logger.info("dataset size is: {}".format(data1_size))
  80. # FIXME: Uncomment the following assert when code issue is resolved; at runtime, data1_size is 2 not 4
  81. # assert data1_size == 4
  82. # Verify number of rows
  83. # FIXME: Uncomment the following assert when code issue is resolved; at runtime, number of rows is 2 not 4
  84. # assert sum([1 for _ in data1]) == 4
  85. # Verify dataset contents
  86. # FIXME: Uncomment the following test code when runtime code issue is resolved
  87. # res = []
  88. # for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
  89. # logger.info("item: {}".format(item))
  90. # res.append(item)
  91. # logger.info("dataset: {}".format(res))
  92. def test_numpyslices_sampler_chain_batch():
  93. """
  94. Test NumpySlicesDataset sampler chaining, with batch
  95. """
  96. logger.info("test_numpyslices_sampler_chain_batch")
  97. # Create NumpySlicesDataset with sampler chain
  98. np_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
  99. sampler = ds.SequentialSampler(start_index=1, num_samples=3)
  100. sampler = sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2))
  101. data1 = ds.NumpySlicesDataset(np_data, sampler=sampler)
  102. data1 = data1.batch(batch_size=3, drop_remainder=False)
  103. # Verify dataset size
  104. data1_size = data1.get_dataset_size()
  105. logger.info("dataset size is: {}".format(data1_size))
  106. assert data1_size == 4
  107. # Verify number of rows
  108. assert sum([1 for _ in data1]) == 4
  109. # Verify dataset contents
  110. res = []
  111. for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
  112. logger.info("item: {}".format(item))
  113. res.append(item)
  114. logger.info("dataset: {}".format(res))
def test_sampler_chain_errors():
    """
    Test error cases for sampler chains: chaining onto a consumed chain,
    attaching multiple children, and conflicting sampler/shuffle arguments.
    """
    logger.info("test_sampler_chain_errors")
    error_msg_1 = "'NoneType' object has no attribute 'add_child'"
    # Test add child sampler within child sampler.
    # NOTE(review): the reassignment below implies add_child() returns None,
    # so `sampler` becomes None and the next add_child raises AttributeError
    # — confirm against the Sampler.add_child API.
    sampler = ds.SequentialSampler(start_index=1, num_samples=2)
    sampler = sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2))
    with pytest.raises(AttributeError, match=error_msg_1):
        sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2))
    # error_msg_2 = "'NoneType' object has no attribute 'add_child'"
    # Test add second and nested child sampler
    sampler = ds.SequentialSampler(start_index=1, num_samples=2)
    child_sampler = ds.SequentialSampler(start_index=1, num_samples=2)
    sampler.add_child(child_sampler)
    child_sampler2 = ds.SequentialSampler(start_index=1, num_samples=2)
    sampler.add_child(child_sampler2)
    # FIXME - no error is raised; uncomment after code issue is resolved
    # with pytest.raises(AttributeError, match=error_msg_2):
    #     sampler.add_child(child_sampler2)
    # np_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    # data1 = ds.NumpySlicesDataset(np_data, sampler=sampler)
    error_msg_3 = "Conflicting arguments during sampler assignments."
    # Test conflicting arguments (sampler and shuffle=False) for sampler (no chain)
    np_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    sampler = ds.SequentialSampler(start_index=1, num_samples=3)
    with pytest.raises(ValueError, match=error_msg_3):
        ds.NumpySlicesDataset(np_data, shuffle=False, sampler=sampler)
    # error_msg_4 = "Conflicting arguments during sampler assignments."
    # Test conflicting arguments (sampler and shuffle=False) for sampler chaining
    np_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    sampler = ds.SequentialSampler(start_index=1, num_samples=3)
    sampler = sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2))
    # FIXME - no error is raised; uncomment after code issue is resolved
    # with pytest.raises(ValueError, match=error_msg_4):
    #     ds.NumpySlicesDataset(np_data, shuffle=False, sampler=sampler)
  152. def test_manifest_sampler_chain_repeat():
  153. """
  154. Test ManifestDataset sampler chain DistributedSampler->SequentialSampler, with repeat
  155. """
  156. logger.info("test_manifest_sampler_chain_batch")
  157. manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
  158. # Create sampler chain DistributedSampler->SequentialSampler
  159. sampler = ds.DistributedSampler(num_shards=1, shard_id=0, shuffle=False, num_samples=5)
  160. child_sampler = ds.SequentialSampler()
  161. sampler.add_child(child_sampler)
  162. # Create ManifestDataset with sampler chain
  163. data1 = ds.ManifestDataset(manifest_file, sampler=sampler)
  164. data1 = data1.repeat(count=2)
  165. # Verify dataset size
  166. data1_size = data1.get_dataset_size()
  167. logger.info("dataset size is: {}".format(data1_size))
  168. assert data1_size == 10
  169. # Verify number of rows
  170. assert sum([1 for _ in data1]) == 10
  171. # Verify dataset contents
  172. filename = "sampler_chain_manifest_repeat_result.npz"
  173. save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)
  174. def test_manifest_sampler_chain_batch_repeat():
  175. """
  176. Test ManifestDataset sampler chain DistributedSampler->SequentialSampler, with batch then repeat
  177. """
  178. logger.info("test_manifest_sampler_chain_batch_repeat")
  179. manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
  180. # Create sampler chain DistributedSampler->SequentialSampler
  181. sampler = ds.DistributedSampler(num_shards=1, shard_id=0, shuffle=False, num_samples=5)
  182. child_sampler = ds.SequentialSampler()
  183. sampler.add_child(child_sampler)
  184. # Create ManifestDataset with sampler chain
  185. data1 = ds.ManifestDataset(manifest_file, decode=True, sampler=sampler)
  186. one_hot_encode = c_transforms.OneHot(3)
  187. data1 = data1.map(operations=one_hot_encode, input_columns=["label"])
  188. data1 = data1.batch(batch_size=5, drop_remainder=False)
  189. data1 = data1.repeat(count=2)
  190. # Verify dataset size
  191. data1_size = data1.get_dataset_size()
  192. logger.info("dataset size is: {}".format(data1_size))
  193. assert data1_size == 2
  194. # Verify number of rows
  195. # FIXME: Uncomment the following assert when code issue is resolved
  196. # assert sum([1 for _ in data1]) == 2
if __name__ == '__main__':
    # Run every test in this file when executed directly (outside pytest).
    test_numpyslices_sampler_no_chain()
    test_numpyslices_sampler_chain()
    test_numpyslices_sampler_chain2()
    test_numpyslices_sampler_chain_batch()
    test_sampler_chain_errors()
    test_manifest_sampler_chain_repeat()
    test_manifest_sampler_chain_batch_repeat()