You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_datasets_lj_speech.py 6.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Test LJSpeech dataset operators
  17. """
  18. import numpy as np
  19. import pytest
  20. import mindspore.dataset as ds
  21. import mindspore.dataset.audio.transforms as audio
  22. from mindspore import log as logger
  23. DATA_DIR = "../data/dataset/testLJSpeechData/"
  def test_lj_speech_basic():
      """
      Feature: LJSpeechDataset
      Description: Load the test data in full, with num_samples, and with repeat
      Expectation: Each pipeline yields the expected number of rows
      """
      logger.info("Test LJSpeechDataset Op")

      def count_rows(dataset):
          """Exhaust a single-epoch dict iterator and return how many rows it produced."""
          return sum(1 for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True))

      # case 1: the whole dataset is expected to yield 3 rows
      whole = ds.LJSpeechDataset(DATA_DIR)
      assert count_rows(whole) == 3

      # case 2: num_samples=3 is expected to yield 3 rows
      limited = ds.LJSpeechDataset(DATA_DIR, num_samples=3)
      assert count_rows(limited) == 3

      # case 3: 3 samples repeated 5 times are expected to yield 15 rows
      repeated = ds.LJSpeechDataset(DATA_DIR, num_samples=3).repeat(5)
      assert count_rows(repeated) == 15
  def test_lj_speech_sequential_sampler():
      """
      Feature: LJSpeechDataset
      Description: Compare a SequentialSampler pipeline with a shuffle=False pipeline of the same num_samples
      Expectation: Both pipelines yield equal sample rates and exactly num_samples rows
      """
      logger.info("Test LJSpeechDataset Op with SequentialSampler")
      num_samples = 3
      sequential = ds.SequentialSampler(num_samples=num_samples)
      sampled = ds.LJSpeechDataset(DATA_DIR, sampler=sequential)
      unshuffled = ds.LJSpeechDataset(DATA_DIR, shuffle=False, num_samples=num_samples)

      # Walk both pipelines in lockstep, keeping only the sample rate of each row.
      rate_pairs = [
          (row_a["sample_rate"], row_b["sample_rate"])
          for row_a, row_b in zip(sampled.create_dict_iterator(num_epochs=1, output_numpy=True),
                                  unshuffled.create_dict_iterator(num_epochs=1, output_numpy=True))
      ]
      rates_sampled = [pair[0] for pair in rate_pairs]
      rates_unshuffled = [pair[1] for pair in rate_pairs]

      np.testing.assert_array_equal(rates_sampled, rates_unshuffled)
      assert len(rate_pairs) == num_samples
  def test_lj_speech_exception():
      """
      Feature: LJSpeechDataset
      Description: Invalid argument combinations and map operations that raise
      Expectation: The expected exception type and message are raised in every case
      """
      logger.info("Test error cases for LJSpeechDataset")

      shard_id_range_msg = "Input shard_id is not within the required interval"
      workers_msg = "num_parallel_workers exceeds"

      # Each entry: (expected exception type, message pattern, deferred construction).
      # The construction is wrapped in a lambda so it only runs inside pytest.raises.
      invalid_constructions = [
          (RuntimeError, "sampler and shuffle cannot be specified at the same time",
           lambda: ds.LJSpeechDataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3))),
          (RuntimeError, "sampler and sharding cannot be specified at the same time",
           lambda: ds.LJSpeechDataset(DATA_DIR, sampler=ds.PKSampler(3), num_shards=2, shard_id=0)),
          (RuntimeError, "num_shards is specified and currently requires shard_id as well",
           lambda: ds.LJSpeechDataset(DATA_DIR, num_shards=10)),
          (RuntimeError, "shard_id is specified but num_shards is not",
           lambda: ds.LJSpeechDataset(DATA_DIR, shard_id=0)),
          (ValueError, shard_id_range_msg,
           lambda: ds.LJSpeechDataset(DATA_DIR, num_shards=5, shard_id=-1)),
          (ValueError, shard_id_range_msg,
           lambda: ds.LJSpeechDataset(DATA_DIR, num_shards=5, shard_id=5)),
          (ValueError, shard_id_range_msg,
           lambda: ds.LJSpeechDataset(DATA_DIR, num_shards=2, shard_id=5)),
          (ValueError, workers_msg,
           lambda: ds.LJSpeechDataset(DATA_DIR, shuffle=False, num_parallel_workers=0)),
          (ValueError, workers_msg,
           lambda: ds.LJSpeechDataset(DATA_DIR, shuffle=False, num_parallel_workers=256)),
          (ValueError, workers_msg,
           lambda: ds.LJSpeechDataset(DATA_DIR, shuffle=False, num_parallel_workers=-2)),
          (TypeError, "Argument shard_id",
           lambda: ds.LJSpeechDataset(DATA_DIR, num_shards=2, shard_id="0")),
      ]
      for expected_error, message, construct in invalid_constructions:
          with pytest.raises(expected_error, match=message):
              construct()

      def exception_func(item):
          # Map operation that always fails, whatever row it receives.
          raise Exception("Error occur!")

      # A failing map operation on either column is expected to surface as a RuntimeError.
      failing_map_msg = "The corresponding data files"
      for column in ("waveform", "sample_rate"):
          with pytest.raises(RuntimeError, match=failing_map_msg):
              data = ds.LJSpeechDataset(DATA_DIR)
              data = data.map(operations=exception_func, input_columns=[column], num_parallel_workers=1)
              for _ in data.__iter__():
                  pass
  def test_lj_speech_pipeline():
      """
      Feature: LJSpeechDataset
      Description: Map the BandBiquad audio op over the waveform column of the dataset
      Expectation: The mapped pipeline yields 3 rows
      """
      # Unfiltered dataset
      source = ds.LJSpeechDataset(DATA_DIR)
      band_biquad_op = audio.BandBiquad(8000, 200.0)

      # Apply the band biquad filter to every waveform using two parallel workers
      filtered = source.map(operations=band_biquad_op, input_columns=["waveform"], num_parallel_workers=2)

      row_count = sum(1 for _ in filtered.create_dict_iterator(num_epochs=1, output_numpy=True))
      assert row_count == 3
  if __name__ == '__main__':
      # Run every test in this module, in definition order, when executed as a script.
      for run_case in (
              test_lj_speech_basic,
              test_lj_speech_sequential_sampler,
              test_lj_speech_exception,
              test_lj_speech_pipeline,
      ):
          run_case()