You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or number, can include dashes ('-') and can be up to 35 characters long.

test_melscale_fbanks.py 7.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. # Copyright 2022 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore.dataset.audio.utils as audio
  18. from mindspore import log as logger
  19. def count_unequal_element(data_expected, data_me, rtol, atol):
  20. assert data_expected.shape == data_me.shape
  21. total_count = len(data_expected.flatten())
  22. error = np.abs(data_expected - data_me)
  23. greater = np.greater(error, atol + np.abs(data_expected) * rtol)
  24. loss_count = np.count_nonzero(greater)
  25. assert (loss_count / total_count) < rtol, \
  26. "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
  27. format(data_expected[greater], data_me[greater], error[greater])
  28. def test_melscale_fbanks_normal():
  29. """
  30. Feature: melscale_fbanks.
  31. Description: Test normal operation with NormType.NONE and MelType.HTK.
  32. Expectation: The output data is the same as the result of torchaudio.functional.melscale_fbanks.
  33. """
  34. expect = np.array([[0.0000, 0.0000, 0.0000, 0.0000],
  35. [0.5502, 0.0000, 0.0000, 0.0000],
  36. [0.6898, 0.3102, 0.0000, 0.0000],
  37. [0.0000, 0.9366, 0.0634, 0.0000],
  38. [0.0000, 0.1924, 0.8076, 0.0000],
  39. [0.0000, 0.0000, 0.4555, 0.5445],
  40. [0.0000, 0.0000, 0.0000, 0.7247],
  41. [0.0000, 0.0000, 0.0000, 0.0000]], dtype=np.float64)
  42. output = audio.melscale_fbanks(8, 2, 50, 4, 100, audio.NormType.NONE, audio.MelType.HTK)
  43. count_unequal_element(expect, output, 0.0001, 0.0001)
  44. def test_melscale_fbanks_none_slaney():
  45. """
  46. Feature: melscale_fbanks.
  47. Description: Test normal operation with NormType.NONE and MelType.SLANEY.
  48. Expectation: The output data is the same as the result of torchaudio.functional.melscale_fbanks.
  49. """
  50. expect = np.array([[0.0000, 0.0000, 0.0000, 0.0000],
  51. [0.5357, 0.0000, 0.0000, 0.0000],
  52. [0.7202, 0.2798, 0.0000, 0.0000],
  53. [0.0000, 0.9762, 0.0238, 0.0000],
  54. [0.0000, 0.2321, 0.7679, 0.0000],
  55. [0.0000, 0.0000, 0.4881, 0.5119],
  56. [0.0000, 0.0000, 0.0000, 0.7440],
  57. [0.0000, 0.0000, 0.0000, 0.0000]], dtype=np.float64)
  58. output = audio.melscale_fbanks(8, 2, 50, 4, 100, audio.NormType.NONE, audio.MelType.SLANEY)
  59. count_unequal_element(expect, output, 0.0001, 0.0001)
  60. def test_melscale_fbanks_with_slaney_htk():
  61. """
  62. Feature: melscale_fbanks.
  63. Description: Test normal operation with NormType.SLANEY and MelType.HTK.
  64. Expectation: The output data is the same as the result of torchaudio.functional.melscale_fbanks.
  65. """
  66. output = audio.melscale_fbanks(10, 0, 50, 5, 100, audio.NormType.SLANEY, audio.MelType.HTK)
  67. expect = np.array([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
  68. [0.0843, 0.0000, 0.0000, 0.0000, 0.0000],
  69. [0.0776, 0.0447, 0.0000, 0.0000, 0.0000],
  70. [0.0000, 0.1158, 0.0055, 0.0000, 0.0000],
  71. [0.0000, 0.0344, 0.0860, 0.0000, 0.0000],
  72. [0.0000, 0.0000, 0.0741, 0.0454, 0.0000],
  73. [0.0000, 0.0000, 0.0000, 0.1133, 0.0053],
  74. [0.0000, 0.0000, 0.0000, 0.0355, 0.0822],
  75. [0.0000, 0.0000, 0.0000, 0.0000, 0.0760],
  76. [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]], dtype=np.float64)
  77. count_unequal_element(expect, output, 0.0001, 0.0001)
  78. def test_melscale_fbanks_with_slaney_slaney():
  79. """
  80. Feature: melscale_fbanks.
  81. Description: Test normal operation with NormType.SLANEY and MelType.SLANEY.
  82. Expectation: The output data is the same as the result of torchaudio.functional.melscale_fbanks.
  83. """
  84. output = audio.melscale_fbanks(8, 2, 50, 4, 100, audio.NormType.SLANEY, audio.MelType.SLANEY)
  85. expect = np.array([[0.0000, 0.0000, 0.0000, 0.0000],
  86. [0.0558, 0.0000, 0.0000, 0.0000],
  87. [0.0750, 0.0291, 0.0000, 0.0000],
  88. [0.0000, 0.1017, 0.0025, 0.0000],
  89. [0.0000, 0.0242, 0.0800, 0.0000],
  90. [0.0000, 0.0000, 0.0508, 0.0533],
  91. [0.0000, 0.0000, 0.0000, 0.0775],
  92. [0.0000, 0.0000, 0.0000, 0.0000]], dtype=np.float64)
  93. count_unequal_element(expect, output, 0.0001, 0.0001)
  94. def test_melscale_fbanks_invalid_input():
  95. """
  96. Feature: melscale_fbanks.
  97. Description: Test operation with invalid input.
  98. Expectation: Throw exception as expected.
  99. """
  100. def test_invalid_input(test_name, n_freqs, f_min, f_max, n_mels, sample_rate, norm, mel_type, error, error_msg):
  101. logger.info("Test melscale_fbanks with bad input: {0}".format(test_name))
  102. with pytest.raises(error) as error_info:
  103. audio.melscale_fbanks(n_freqs, f_min, f_max, n_mels, sample_rate, norm, mel_type)
  104. print(error_info)
  105. assert error_msg in str(error_info.value)
  106. test_invalid_input("invalid n_freqs parameter Value", 99999999999, 0, 50, 5, 100, audio.NormType.NONE,
  107. audio.MelType.HTK, ValueError, "n_freqs")
  108. test_invalid_input("invalid n_freqs parameter type", 10.5, 0, 50, 5, 100, audio.NormType.NONE, audio.MelType.HTK,
  109. TypeError, "n_freqs")
  110. test_invalid_input("invalid f_min parameter type", 10, None, 50, 5, 100, audio.NormType.NONE, audio.MelType.HTK,
  111. TypeError, "f_min")
  112. test_invalid_input("invalid f_max parameter type", 10, 0, None, 5, 100, audio.NormType.NONE, audio.MelType.HTK,
  113. TypeError, "f_max")
  114. test_invalid_input("invalid n_mels parameter type", 10, 0, 50, 10.1, 100, audio.NormType.NONE, audio.MelType.HTK,
  115. TypeError, "n_mels")
  116. test_invalid_input("invalid n_mels parameter Value", 20, 0, 50, 999999999999, 100, audio.NormType.NONE,
  117. audio.MelType.HTK, ValueError, "n_mels")
  118. test_invalid_input("invalid sample_rate parameter type", 10, 0, 50, 5, 100.1, audio.NormType.NONE,
  119. audio.MelType.HTK, TypeError, "sample_rate")
  120. test_invalid_input("invalid sample_rate parameter Value", 20, 0, 50, 5, 999999999999, audio.NormType.NONE,
  121. audio.MelType.HTK, ValueError, "sample_rate")
  122. test_invalid_input("invalid norm parameter type", 10, 0, 50, 5, 100, None, audio.MelType.HTK,
  123. TypeError, "norm")
  124. test_invalid_input("invalid norm parameter type", 10, 0, 50, 5, 100, audio.NormType.SLANEY, None,
  125. TypeError, "mel_type")
  126. if __name__ == "__main__":
  127. test_melscale_fbanks_normal()
  128. test_melscale_fbanks_none_slaney()
  129. test_melscale_fbanks_with_slaney_htk()
  130. test_melscale_fbanks_with_slaney_slaney()
  131. test_melscale_fbanks_invalid_input()