You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_datasets_speech_commands.py 8.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Test SpeechCommands dataset operators
  17. """
  18. import pytest
  19. import numpy as np
  20. import mindspore.dataset as ds
  21. import mindspore.dataset.audio.transforms as audio
  22. from mindspore import log as logger
  23. DATA_DIR = "../data/dataset/testSpeechCommandsData/"
  24. def test_speech_commands_basic():
  25. """
  26. Feature: SpeechCommands Dataset
  27. Description: Read all files
  28. Expectation: Output the amount of files
  29. """
  30. logger.info("Test SpeechCommandsDataset Op.")
  31. # case 1: test loading whole dataset
  32. data1 = ds.SpeechCommandsDataset(DATA_DIR)
  33. num_iter1 = 0
  34. for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  35. num_iter1 += 1
  36. assert num_iter1 == 3
  37. # case 2: test num_samples
  38. data2 = ds.SpeechCommandsDataset(DATA_DIR, num_samples=3)
  39. num_iter2 = 0
  40. for _ in data2.create_dict_iterator(num_epochs=1, output_numpy=True):
  41. num_iter2 += 1
  42. assert num_iter2 == 3
  43. # case 3: test repeat
  44. data3 = ds.SpeechCommandsDataset(DATA_DIR, num_samples=2)
  45. data3 = data3.repeat(5)
  46. num_iter3 = 0
  47. for _ in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
  48. num_iter3 += 1
  49. assert num_iter3 == 10
  50. def test_speech_commands_sequential_sampler():
  51. """
  52. Feature: SpeechCommands Dataset
  53. Description: Use SequentialSampler to sample data.
  54. Expectation: The number of samplers returned by dict_iterator is equal to the requested number of samples.
  55. """
  56. logger.info("Test SpeechCommandsDataset with SequentialSampler.")
  57. num_samples = 2
  58. sampler = ds.SequentialSampler(num_samples=num_samples)
  59. data1 = ds.SpeechCommandsDataset(DATA_DIR, sampler=sampler)
  60. data2 = ds.SpeechCommandsDataset(DATA_DIR, shuffle=False, num_samples=num_samples)
  61. sample_rate_list1, sample_rate_list2 = [], []
  62. num_iter = 0
  63. for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
  64. data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
  65. sample_rate_list1.append(item1["sample_rate"])
  66. sample_rate_list2.append(item2["sample_rate"])
  67. num_iter += 1
  68. np.testing.assert_array_equal(sample_rate_list1, sample_rate_list2)
  69. assert num_iter == num_samples
  70. def test_speech_commands_exception():
  71. """
  72. Feature: SpeechCommands Dataset
  73. Description: Test error cases for SpeechCommandsDataset
  74. Expectation: Error message
  75. """
  76. logger.info("Test error cases for SpeechCommandsDataset.")
  77. error_msg_1 = "sampler and shuffle cannot be specified at the same time."
  78. with pytest.raises(RuntimeError, match=error_msg_1):
  79. ds.SpeechCommandsDataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3))
  80. error_msg_2 = "sampler and sharding cannot be specified at the same time."
  81. with pytest.raises(RuntimeError, match=error_msg_2):
  82. ds.SpeechCommandsDataset(DATA_DIR, sampler=ds.PKSampler(3), num_shards=2, shard_id=0)
  83. error_msg_3 = "num_shards is specified and currently requires shard_id as well."
  84. with pytest.raises(RuntimeError, match=error_msg_3):
  85. ds.SpeechCommandsDataset(DATA_DIR, num_shards=10)
  86. error_msg_4 = "shard_id is specified but num_shards is not."
  87. with pytest.raises(RuntimeError, match=error_msg_4):
  88. ds.SpeechCommandsDataset(DATA_DIR, shard_id=0)
  89. error_msg_5 = "Input shard_id is not within the required interval."
  90. with pytest.raises(ValueError, match=error_msg_5):
  91. ds.SpeechCommandsDataset(DATA_DIR, num_shards=5, shard_id=-1)
  92. with pytest.raises(ValueError, match=error_msg_5):
  93. ds.SpeechCommandsDataset(DATA_DIR, num_shards=5, shard_id=5)
  94. with pytest.raises(ValueError, match=error_msg_5):
  95. ds.SpeechCommandsDataset(DATA_DIR, num_shards=2, shard_id=5)
  96. error_msg_6 = "num_parallel_workers exceeds."
  97. with pytest.raises(ValueError, match=error_msg_6):
  98. ds.SpeechCommandsDataset(DATA_DIR, shuffle=False, num_parallel_workers=0)
  99. with pytest.raises(ValueError, match=error_msg_6):
  100. ds.SpeechCommandsDataset(DATA_DIR, shuffle=False, num_parallel_workers=256)
  101. with pytest.raises(ValueError, match=error_msg_6):
  102. ds.SpeechCommandsDataset(DATA_DIR, shuffle=False, num_parallel_workers=-2)
  103. error_msg_7 = "Argument shard_id."
  104. with pytest.raises(TypeError, match=error_msg_7):
  105. ds.SpeechCommandsDataset(DATA_DIR, num_shards=2, shard_id="0")
  106. def exception_func(item):
  107. raise Exception("Error occur!")
  108. error_msg_8 = "The corresponding data files."
  109. with pytest.raises(RuntimeError, match=error_msg_8):
  110. data = ds.SpeechCommandsDataset(DATA_DIR)
  111. data = data.map(operations=exception_func, input_columns=["waveform"], num_parallel_workers=1)
  112. for _ in data.__iter__():
  113. pass
  114. with pytest.raises(RuntimeError, match=error_msg_8):
  115. data = ds.SpeechCommandsDataset(DATA_DIR)
  116. data = data.map(operations=exception_func, input_columns=["sample_rate"], num_parallel_workers=1)
  117. for _ in data.__iter__():
  118. pass
  119. def test_speech_commands_usage():
  120. """
  121. Feature: SpeechCommands Dataset
  122. Description: Usage Test
  123. Expectation: Get the result of each function
  124. """
  125. logger.info("Test SpeechCommandsDataset usage flag.")
  126. def test_config(usage, speech_commands_path=DATA_DIR):
  127. try:
  128. data = ds.SpeechCommandsDataset(speech_commands_path, usage=usage)
  129. num_rows = 0
  130. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  131. num_rows += 1
  132. except (ValueError, TypeError, RuntimeError) as e:
  133. return str(e)
  134. return num_rows
  135. # test the usage of SpeechCommands
  136. assert test_config("test") == 1
  137. assert test_config("train") == 1
  138. assert test_config("valid") == 1
  139. assert test_config("all") == 3
  140. assert "usage is not within the valid set of ['train', 'test', 'valid', 'all']." in test_config("invalid")
  141. # change this directory to the folder that contains all SpeechCommands files
  142. all_speech_commands = None
  143. if all_speech_commands is not None:
  144. assert test_config("test", all_speech_commands) == 11005
  145. assert test_config("valid", all_speech_commands) == 9981
  146. assert test_config("train", all_speech_commands) == 84843
  147. assert test_config("all", all_speech_commands) == 105829
  148. assert ds.SpeechCommandsDataset(all_speech_commands, usage="test").get_dataset_size() == 11005
  149. assert ds.SpeechCommandsDataset(all_speech_commands, usage="valid").get_dataset_size() == 9981
  150. assert ds.SpeechCommandsDataset(all_speech_commands, usage="train").get_dataset_size() == 84843
  151. assert ds.SpeechCommandsDataset(all_speech_commands, usage="all").get_dataset_size() == 105829
  152. def test_speech_commands_pipeline():
  153. """
  154. Feature: Pipeline test
  155. Description: Read a sample
  156. Expectation: Test BandBiquad by pipeline
  157. """
  158. dataset = ds.SpeechCommandsDataset(DATA_DIR, num_samples=1)
  159. band_biquad_op = audio.BandBiquad(8000, 200.0)
  160. # Filtered waveform by bandbiquad
  161. dataset = dataset.map(input_columns=["waveform"], operations=band_biquad_op, num_parallel_workers=4)
  162. i = 0
  163. for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  164. i += 1
  165. assert i == 1
  166. if __name__ == '__main__':
  167. test_speech_commands_basic()
  168. test_speech_commands_sequential_sampler()
  169. test_speech_commands_exception()
  170. test_speech_commands_usage()
  171. test_speech_commands_pipeline()