You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_dither.py 7.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore.dataset as ds
  18. import mindspore.dataset.audio.transforms as audio
  19. from mindspore import log as logger
  20. from mindspore.dataset.audio.utils import DensityFunction
  21. from util import visualize_audio, diff_mse
  22. def count_unequal_element(data_expected, data_me, rtol, atol):
  23. assert data_expected.shape == data_me.shape
  24. total_count = len(data_expected.flatten())
  25. error = np.abs(data_expected - data_me)
  26. greater = np.greater(error, atol + np.abs(data_expected) * rtol)
  27. loss_count = np.count_nonzero(greater)
  28. assert (loss_count / total_count) < rtol, \
  29. "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
  30. format(data_expected[greater], data_me[greater], error[greater])
  31. def test_dither_eager_noise_shaping_false():
  32. """
  33. Feature: Dither
  34. Description: test Dither in eager mode
  35. Expectation: the result is as expected
  36. """
  37. logger.info("test Dither in eager mode")
  38. # Original waveform
  39. waveform = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64)
  40. # Expect waveform
  41. expect_waveform = np.array([[0.99993896, 1.99990845, 2.99984741],
  42. [3.99975586, 4.99972534, 5.99966431]], dtype=np.float64)
  43. dither_op = audio.Dither(DensityFunction.TPDF, False)
  44. # Filtered waveform by Dither
  45. output = dither_op(waveform)
  46. count_unequal_element(expect_waveform, output, 0.0001, 0.0001)
  47. def test_dither_eager_noise_shaping_true():
  48. """
  49. Feature: Dither
  50. Description: test Dither in eager mode
  51. Expectation: the result is as expected
  52. """
  53. logger.info("test Dither in eager mode")
  54. # Original waveform
  55. waveform = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float64)
  56. # Expect waveform
  57. expect_waveform = np.array([[0.9999, 1.9998, 2.9998],
  58. [3.9998, 4.9995, 5.9994],
  59. [6.9996, 7.9991, 8.9990]], dtype=np.float64)
  60. dither_op = audio.Dither(DensityFunction.TPDF, True)
  61. # Filtered waveform by Dither
  62. output = dither_op(waveform)
  63. count_unequal_element(expect_waveform, output, 0.0001, 0.0001)
  64. def test_dither_pipeline(plot=False):
  65. """
  66. Feature: Dither
  67. Description: test Dither in pipeline mode
  68. Expectation: the result is as expected
  69. """
  70. logger.info("test Dither in pipeline mode")
  71. # Original waveform
  72. waveform_tpdf = np.array([[0.4941969, 0.53911686, 0.4846254], [0.10841596, 0.029320478, 0.52353495],
  73. [0.23657, 0.087965, 0.43579]], dtype=np.float64)
  74. waveform_rpdf = np.array([[0.4941969, 0.53911686, 0.4846254], [0.10841596, 0.029320478, 0.52353495],
  75. [0.23657, 0.087965, 0.43579]], dtype=np.float64)
  76. waveform_gpdf = np.array([[0.4941969, 0.53911686, 0.4846254], [0.10841596, 0.029320478, 0.52353495],
  77. [0.23657, 0.087965, 0.43579]], dtype=np.float64)
  78. # Expect waveform
  79. expect_tpdf = np.array([[0.49417114, 0.53909302, 0.48461914],
  80. [0.10839844, 0.02932739, 0.52352905],
  81. [0.23654175, 0.08798218, 0.43579102]], dtype=np.float64)
  82. expect_rpdf = np.array([[0.4941, 0.5391, 0.4846],
  83. [0.1084, 0.0293, 0.5235],
  84. [0.2365, 0.0880, 0.4358]], dtype=np.float64)
  85. expect_gpdf = np.array([[0.4944, 0.5393, 0.4848],
  86. [0.1086, 0.0295, 0.5237],
  87. [0.2368, 0.0882, 0.4360]], dtype=np.float64)
  88. dataset_tpdf = ds.NumpySlicesDataset(waveform_tpdf, ["audio"], shuffle=False)
  89. dataset_rpdf = ds.NumpySlicesDataset(waveform_rpdf, ["audio"], shuffle=False)
  90. dataset_gpdf = ds.NumpySlicesDataset(waveform_gpdf, ["audio"], shuffle=False)
  91. # Filtered waveform by Dither of TPDF
  92. dither_tpdf = audio.Dither()
  93. dataset_tpdf = dataset_tpdf.map(input_columns=["audio"], operations=dither_tpdf, num_parallel_workers=2)
  94. # Filtered waveform by Dither of RPDF
  95. dither_rpdf = audio.Dither(DensityFunction.RPDF, False)
  96. dataset_rpdf = dataset_rpdf.map(input_columns=["audio"], operations=dither_rpdf, num_parallel_workers=2)
  97. # Filtered waveform by Dither of GPDF
  98. dither_gpdf = audio.Dither(DensityFunction.GPDF, False)
  99. dataset_gpdf = dataset_gpdf.map(input_columns=["audio"], operations=dither_gpdf, num_parallel_workers=2)
  100. i = 0
  101. for data1, data2, data3 in zip(dataset_tpdf.create_dict_iterator(output_numpy=True),
  102. dataset_rpdf.create_dict_iterator(output_numpy=True),
  103. dataset_gpdf.create_dict_iterator(output_numpy=True)):
  104. count_unequal_element(expect_tpdf[i, :], data1['audio'], 0.0001, 0.0001)
  105. dither_rpdf = data2['audio']
  106. dither_gpdf = data3['audio']
  107. mse_rpdf = diff_mse(dither_rpdf, expect_rpdf[i, :])
  108. logger.info("dither_rpdf_{}, mse: {}".format(i + 1, mse_rpdf))
  109. mse_gpdf = diff_mse(dither_gpdf, expect_gpdf[i, :])
  110. logger.info("dither_gpdf_{}, mse: {}".format(i + 1, mse_gpdf))
  111. i += 1
  112. if plot:
  113. visualize_audio(dither_rpdf, expect_rpdf[i, :])
  114. visualize_audio(dither_gpdf, expect_gpdf[i, :])
  115. def test_invalid_dither_input():
  116. """
  117. Feature: Dither
  118. Description: test param check of Dither
  119. Expectation: throw correct error and message
  120. """
  121. logger.info("test param check of Dither")
  122. def test_invalid_input(test_name, density_function, noise_shaping, error, error_msg):
  123. logger.info("Test Dither with bad input: {0}".format(test_name))
  124. with pytest.raises(error) as error_info:
  125. audio.Dither(density_function, noise_shaping)
  126. assert error_msg in str(error_info.value)
  127. test_invalid_input("invalid density function parameter value", "TPDF", False, TypeError,
  128. "Argument density_function with value TPDF is not of type"
  129. + " [<DensityFunction.TPDF: 'TPDF'>, <DensityFunction.RPDF: 'RPDF'>"
  130. + ", <DensityFunction.GPDF: 'GPDF'>], but got <class 'str'>.")
  131. test_invalid_input("invalid density_function parameter value", 6, False, TypeError,
  132. "Argument density_function with value 6 is not of type"
  133. + " [<DensityFunction.TPDF: 'TPDF'>, <DensityFunction.RPDF: 'RPDF'>"
  134. + ", <DensityFunction.GPDF: 'GPDF'>], but got <class 'int'>.")
  135. test_invalid_input("invalid noise_shaping parameter value", DensityFunction.GPDF, 1, TypeError,
  136. "Argument noise_shaping with value 1 is not of type [<class 'bool'>], but got <class 'int'>.")
  137. test_invalid_input("invalid noise_shaping parameter value", DensityFunction.RPDF, "true", TypeError,
  138. "Argument noise_shaping with value true is not of type [<class 'bool'>], but got <class 'str'>")
  139. if __name__ == '__main__':
  140. test_dither_eager_noise_shaping_false()
  141. test_dither_eager_noise_shaping_true()
  142. test_dither_pipeline(plot=False)
  143. test_invalid_dither_input()