You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_datasets_yes_no.py 6.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore.dataset as ds
  18. import mindspore.dataset.audio.transforms as audio
  19. from mindspore import log as logger
# Directory handed to ds.YesNoDataset by every test in this module.
# NOTE(review): the path is relative, so it presumably has to be resolved
# from the directory the tests are launched in -- confirm against the runner.
DATA_DIR = "../data/dataset/testYesNoData/"
  21. def test_yes_no_basic():
  22. """
  23. Feature: YesNo Dataset
  24. Description: Read all files
  25. Expectation: Output the amount of file
  26. """
  27. logger.info("Test YesNoDataset Op")
  28. data = ds.YesNoDataset(DATA_DIR)
  29. num_iter = 0
  30. for _ in data.create_dict_iterator(num_epochs=1):
  31. num_iter += 1
  32. assert num_iter == 3
  33. def test_yes_no_num_samples():
  34. """
  35. Feature: YesNo Dataset
  36. Description: Test num_samples
  37. Expectation: Get certain number of samples
  38. """
  39. data = ds.YesNoDataset(DATA_DIR, num_samples=2)
  40. num_iter = 0
  41. for _ in data.create_dict_iterator(num_epochs=1):
  42. num_iter += 1
  43. assert num_iter == 2
  44. def test_yes_no_repeat():
  45. """
  46. Feature: YesNo Dataset
  47. Description: Test repeat
  48. Expectation: Output the amount of file
  49. """
  50. data = ds.YesNoDataset(DATA_DIR, num_samples=2)
  51. data = data.repeat(5)
  52. num_iter = 0
  53. for _ in data.create_dict_iterator(num_epochs=1):
  54. num_iter += 1
  55. assert num_iter == 10
  56. def test_yes_no_dataset_size():
  57. """
  58. Feature: YesNo Dataset
  59. Description: Test dataset_size
  60. Expectation: Output the size of dataset
  61. """
  62. data = ds.YesNoDataset(DATA_DIR, shuffle=False)
  63. assert data.get_dataset_size() == 3
  64. def test_yes_no_sequential_sampler():
  65. """
  66. Feature: YesNo Dataset
  67. Description: Use SequentialSampler to sample data.
  68. Expectation: The number of samplers returned by dict_iterator is equal to the requested number of samples.
  69. """
  70. logger.info("Test YesNoDataset Op with SequentialSampler")
  71. num_samples = 2
  72. sampler = ds.SequentialSampler(num_samples=num_samples)
  73. data1 = ds.YesNoDataset(DATA_DIR, sampler=sampler)
  74. data2 = ds.YesNoDataset(DATA_DIR, shuffle=False, num_samples=num_samples)
  75. sample_rate_list1, sample_rate_list2 = [], []
  76. num_iter = 0
  77. for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1),
  78. data2.create_dict_iterator(num_epochs=1)):
  79. sample_rate_list1.append(item1["sample_rate"])
  80. sample_rate_list2.append(item2["sample_rate"])
  81. num_iter += 1
  82. np.testing.assert_array_equal(sample_rate_list1, sample_rate_list2)
  83. assert num_iter == num_samples
  84. def test_yes_no_exception():
  85. """
  86. Feature: Error tests
  87. Description: Throw error messages when certain errors occur
  88. Expectation: Output error message
  89. """
  90. logger.info("Test error cases for YesNoDataset")
  91. error_msg_1 = "sampler and shuffle cannot be specified at the same time"
  92. with pytest.raises(RuntimeError, match=error_msg_1):
  93. ds.YesNoDataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3))
  94. error_msg_2 = "sampler and sharding cannot be specified at the same time"
  95. with pytest.raises(RuntimeError, match=error_msg_2):
  96. ds.YesNoDataset(DATA_DIR, sampler=ds.PKSampler(3),
  97. num_shards=2, shard_id=0)
  98. error_msg_3 = "num_shards is specified and currently requires shard_id as well"
  99. with pytest.raises(RuntimeError, match=error_msg_3):
  100. ds.YesNoDataset(DATA_DIR, num_shards=10)
  101. error_msg_4 = "shard_id is specified but num_shards is not"
  102. with pytest.raises(RuntimeError, match=error_msg_4):
  103. ds.YesNoDataset(DATA_DIR, shard_id=0)
  104. error_msg_5 = "Input shard_id is not within the required interval"
  105. with pytest.raises(ValueError, match=error_msg_5):
  106. ds.YesNoDataset(DATA_DIR, num_shards=5, shard_id=-1)
  107. with pytest.raises(ValueError, match=error_msg_5):
  108. ds.YesNoDataset(DATA_DIR, num_shards=5, shard_id=5)
  109. with pytest.raises(ValueError, match=error_msg_5):
  110. ds.YesNoDataset(DATA_DIR, num_shards=2, shard_id=5)
  111. error_msg_6 = "num_parallel_workers exceeds"
  112. with pytest.raises(ValueError, match=error_msg_6):
  113. ds.YesNoDataset(DATA_DIR, shuffle=False, num_parallel_workers=0)
  114. with pytest.raises(ValueError, match=error_msg_6):
  115. ds.YesNoDataset(DATA_DIR, shuffle=False, num_parallel_workers=256)
  116. with pytest.raises(ValueError, match=error_msg_6):
  117. ds.YesNoDataset(DATA_DIR, shuffle=False, num_parallel_workers=-2)
  118. error_msg_7 = "Argument shard_id"
  119. with pytest.raises(TypeError, match=error_msg_7):
  120. ds.YesNoDataset(DATA_DIR, num_shards=2, shard_id="0")
  121. def exception_func(item):
  122. raise Exception("Error occur!")
  123. error_msg_8 = "The corresponding data files"
  124. with pytest.raises(RuntimeError, match=error_msg_8):
  125. data = ds.YesNoDataset(DATA_DIR)
  126. data = data.map(operations=exception_func, input_columns=[
  127. "waveform"], num_parallel_workers=1)
  128. for _ in data.__iter__():
  129. pass
  130. with pytest.raises(RuntimeError, match=error_msg_8):
  131. data = ds.YesNoDataset(DATA_DIR)
  132. data = data.map(operations=exception_func, input_columns=[
  133. "sample_rate"], num_parallel_workers=1)
  134. for _ in data.__iter__():
  135. pass
  136. def test_yes_no_pipeline():
  137. """
  138. Feature: Pipeline test
  139. Description: Read a sample
  140. Expectation: The amount of each function are equal
  141. """
  142. # Original waveform
  143. dataset = ds.YesNoDataset(DATA_DIR, num_samples=1)
  144. band_biquad_op = audio.BandBiquad(8000, 200.0)
  145. # Filtered waveform by bandbiquad
  146. dataset = dataset.map(input_columns=["waveform"], operations=band_biquad_op, num_parallel_workers=2)
  147. num_iter = 0
  148. for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  149. num_iter += 1
  150. assert num_iter == 1
  151. if __name__ == '__main__':
  152. test_yes_no_basic()
  153. test_yes_no_num_samples()
  154. test_yes_no_repeat()
  155. test_yes_no_dataset_size()
  156. test_yes_no_sequential_sampler()
  157. test_yes_no_exception()
  158. test_yes_no_pipeline()